Skip to content

Commit cdba7b8

Browse files
committed
fixed unit tests for paritioning
1 parent 0133de2 commit cdba7b8

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

recml/core/training/partitioning_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def test_data_parallelism(
4040
self, partitioner_cls: type[partitioning.Partitioner]
4141
):
4242
if partitioner_cls is partitioning.ModelParallelPartitioner:
43-
kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1}
43+
devs = np.array(jax.devices()).reshape(-1, 1)
44+
kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1, "devices": devs}
4445
else:
4546
kwargs = {}
4647
partitioner = partitioner_cls(**kwargs)
@@ -112,8 +113,12 @@ def _eval_step(
112113
)
113114

114115
def test_model_parallelism(self):
116+
devs = np.array(jax.devices()).reshape(1, -1)
117+
115118
partitioner = partitioning.ModelParallelPartitioner(
116-
axes=[("data", 1), ("model", jax.device_count())], dp_axes=1
119+
axes=[("data", 1), ("model", jax.device_count())],
120+
dp_axes=1,
121+
devices=devs
117122
)
118123

119124
inputs = np.zeros((128, 16), dtype=np.float32)

recml/examples/train_hstu_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def experiment() -> fdl.Config[recml.Experiment]:
198198
"""Defines the experiment structure using Fiddle configs"""
199199

200200
max_seq_len = 50
201-
batch_size = 128
201+
batch_size = 64
202202

203203
model_cfg = fdl.Config(
204204
HSTUModelConfig,

0 commit comments

Comments
 (0)