Skip to content

Commit b4f9573

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
[JAX] Replace jnp.clip(..., a_min=..., a_max=...) with jnp.clip(..., min=..., max=...).
a_min and a_max are deprecated parameter names to jax.numpy.clip. PiperOrigin-RevId: 890629221
1 parent 2a74af1 commit b4f9573

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def prepare_video_coords(
193193
# pixel_coords[:, 0, ...] selects Frame dimension.
194194
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
195195
frame_coords = pixel_coords[:, 0, ...]
196-
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], a_min=0)
196+
frame_coords = jnp.clip(
197+
frame_coords + self.causal_offset - self.scale_factors[0], min=0
198+
)
197199
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
198200

199201
return pixel_coords
@@ -210,12 +212,16 @@ def prepare_audio_coords(
210212
# 2. Start timestamps
211213
audio_scale_factor = self.scale_factors[0]
212214
grid_start_mel = grid_f * audio_scale_factor
213-
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, a_min=0)
215+
grid_start_mel = jnp.clip(
216+
grid_start_mel + self.causal_offset - audio_scale_factor, min=0
217+
)
214218
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
215219

216220
# 3. End timestamps
217221
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
218-
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, a_min=0)
222+
grid_end_mel = jnp.clip(
223+
grid_end_mel + self.causal_offset - audio_scale_factor, min=0
224+
)
219225
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
220226

221227
# Stack [num_patches, 2]

src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def preprocess_conditions(
188188
if mask is not None:
189189
mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
190190
mask = jnp.array(np.asarray(mask), dtype=video.dtype)
191-
mask = jnp.clip((mask + 1) / 2, a_min=0, a_max=1)
191+
mask = jnp.clip((mask + 1) / 2, min=0, max=1)
192192
else:
193193
mask = jnp.ones_like(video)
194194

0 commit comments

Comments
 (0)