Skip to content

Commit 049ba3f

Browse files
committed
update a_min to min in jnp.clip
1 parent 08e4267 commit 049ba3f

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

tests/utils/forward_pass_logit_checker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ def main(config, test_args): # pylint: disable=W0621
341341
max_logging.log(msg)
342342

343343
if test_args.clip_logits_epsilon is not None:
344-
model_probabilities = jnp.clip(jax.nn.softmax(train_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon)
345-
golden_probabilities = jnp.clip(jax.nn.softmax(golden_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon)
344+
model_probabilities = jnp.clip(jax.nn.softmax(train_logits_slice, axis=-1), min=test_args.clip_logits_epsilon)
345+
golden_probabilities = jnp.clip(jax.nn.softmax(golden_logits_slice, axis=-1), min=test_args.clip_logits_epsilon)
346346
else:
347347
model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1)
348348
golden_probabilities = jax.nn.softmax(golden_logits_slice, axis=-1)

0 commit comments

Comments
 (0)