Skip to content

Commit 3ac667a

Browse files
Kevin WangGoogle-ML-Automation
authored andcommitted
Only pass in logical axis rules if present in config.
PiperOrigin-RevId: 898044030
1 parent 360cd5a commit 3ac667a

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxtext/layers/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
169169
"activation_embed",
170170
)
171171
)
172-
out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=self.config.logical_axis_rules)
172+
out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None))
173173

174174
out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None
175175

0 commit comments

Comments
 (0)