@@ -193,7 +193,9 @@ def prepare_video_coords(
193193 # pixel_coords[:, 0, ...] selects Frame dimension.
194194 # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
195195 frame_coords = pixel_coords [:, 0 , ...]
196- frame_coords = jnp .clip (frame_coords + self .causal_offset - self .scale_factors [0 ], a_min = 0 )
196+ frame_coords = jnp .clip (
197+ frame_coords + self .causal_offset - self .scale_factors [0 ], min = 0
198+ )
197199 pixel_coords = pixel_coords .at [:, 0 , ...].set (frame_coords / fps )
198200
199201 return pixel_coords
@@ -210,12 +212,16 @@ def prepare_audio_coords(
210212 # 2. Start timestamps
211213 audio_scale_factor = self .scale_factors [0 ]
212214 grid_start_mel = grid_f * audio_scale_factor
213- grid_start_mel = jnp .clip (grid_start_mel + self .causal_offset - audio_scale_factor , a_min = 0 )
215+ grid_start_mel = jnp .clip (
216+ grid_start_mel + self .causal_offset - audio_scale_factor , min = 0
217+ )
214218 grid_start_s = grid_start_mel * self .hop_length / self .sampling_rate
215219
216220 # 3. End timestamps
217221 grid_end_mel = (grid_f + self .patch_size_t ) * audio_scale_factor
218- grid_end_mel = jnp .clip (grid_end_mel + self .causal_offset - audio_scale_factor , a_min = 0 )
222+ grid_end_mel = jnp .clip (
223+ grid_end_mel + self .causal_offset - audio_scale_factor , min = 0
224+ )
219225 grid_end_s = grid_end_mel * self .hop_length / self .sampling_rate
220226
221227 # Stack [num_patches, 2]
0 commit comments