Skip to content

Commit 5023bd5

Browse files
authored
Merge pull request #149 from AI-Hypercomputer/ajkv/multi-host-v7-training
Added multi-host training for v7
2 parents fdb73b4 + e97f4fb commit 5023bd5

5 files changed

Lines changed: 59 additions & 50 deletions

File tree

Dockerfile

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,26 @@ WORKDIR /app
77
# This tells Python to look in /app for the 'recml' package
88
ENV PYTHONPATH="${PYTHONPATH}:/app"
99

10-
# Install system tools if needed (e.g., git)
11-
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
10+
# This prevents the "MessageFactory" crash when using Protobuf
11+
ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
1212

13-
# Install the latest jax-tpu-embedding wheel
14-
COPY jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl ./
15-
RUN pip install ./jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl
13+
# This prevents the "Unable to register cuFFT/cuBLAS" log spam and initialization errors
14+
ENV CUDA_VISIBLE_DEVICES=-1
1615

17-
# Copy requirements.txt to current directory
18-
COPY requirements.txt ./
16+
# Install system tools
17+
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
1918

20-
# Install dependencies
19+
# Install standard requirements
20+
COPY requirements.txt ./
2121
RUN pip install --upgrade pip
2222
RUN pip install -r ./requirements.txt
2323

2424
# Force install the specific protobuf version
25-
RUN pip install "protobuf>=6.31.1" --no-deps
25+
RUN pip install "protobuf>=6.31.1"
26+
27+
# Install the latest jax-tpu-embedding wheel
28+
COPY jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl ./
29+
RUN pip install ./jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31_x86_64.whl
2630

2731
# Copy the current directory contents into the container
2832
COPY . /app

recml/core/training/jax_trainer.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def __init__(
398398
continuous_eval_timeout: int = 30,
399399
rng_seed: int = core.DEFAULT_RNG_SEED,
400400
rng_impl: str | None = None,
401+
enable_checkpointing: bool = True,
401402
):
402403
"""Initializes the instance.
403404
@@ -436,6 +437,7 @@ def __init__(
436437
rng_impl: The implementation of the PRNG key. By default this is set to
437438
None which means that the default implementation (generally
438439
partitionable threefry) will be used.
440+
enable_checkpointing: Whether to enable checkpointing. Defaults to True.
439441
"""
440442

441443
if not isinstance(steps_per_loop, int) or steps_per_loop < 1:
@@ -453,6 +455,7 @@ def __init__(
453455
self._max_checkpoints_to_keep = max_checkpoints_to_keep
454456
self._rng_impl = rng_impl
455457
self._rng_seed = rng_seed
458+
self._enable_checkpointing = enable_checkpointing
456459

457460
@functools.cached_property
458461
def checkpoint_manager(self) -> ocp.CheckpointManager:
@@ -467,14 +470,19 @@ def checkpoint_manager(self) -> ocp.CheckpointManager:
467470
save_on_steps.append(self._train_steps - 1)
468471

469472
save_on_steps = set(save_on_steps)
470-
471-
return ocp.CheckpointManager(
472-
directory=os.path.join(self._model_dir, core.CHECKPOINT_DIR),
473-
options=ocp.CheckpointManagerOptions(
474-
should_save_fn=lambda step, _: step in save_on_steps,
475-
max_to_keep=self._max_checkpoints_to_keep,
476-
),
477-
)
473+
474+
if self._enable_checkpointing:
475+
476+
return ocp.CheckpointManager(
477+
directory=os.path.join(self._model_dir, core.CHECKPOINT_DIR),
478+
options=ocp.CheckpointManagerOptions(
479+
should_save_fn=lambda step, _: step in save_on_steps,
480+
max_to_keep=self._max_checkpoints_to_keep,
481+
),
482+
)
483+
else:
484+
485+
return None
478486

479487
@functools.cached_property
480488
def train_summary_writer(self) -> metrics_tools.AsyncMultiWriter:
@@ -510,6 +518,9 @@ def _maybe_save_checkpoint(
510518
metrics: Mapping[str, Any] | None = None,
511519
):
512520
"""Saves a checkpoint and returns a bool indicating whether it was saved."""
521+
if not self._enable_checkpointing:
522+
return
523+
513524
items = {core.STATE_CHECKPOINT_KEY: ocp.args.StandardSave(state)}
514525
with self.report_progress.timed("checkpointing"):
515526
self.checkpoint_manager.save(
@@ -564,7 +575,7 @@ def _train_n_steps(
564575
state, metrics_update = train_step(inputs, state)
565576
metrics_accum.accumulate(metrics_update, step)
566577
self.report_progress(step)
567-
if step != start_step + num_steps - 1:
578+
if (step != start_step + num_steps - 1) and self._enable_checkpointing:
568579
self._maybe_save_checkpoint(step, state)
569580

570581
metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
@@ -651,6 +662,7 @@ def _eval_step(
651662

652663
if (
653664
check_for_checkpoints
665+
and self._enable_checkpointing
654666
and self.checkpoint_manager.latest_step() is not None
655667
):
656668
step_to_resume_from = self.checkpoint_manager.latest_step()
@@ -674,7 +686,7 @@ def _eval_step(
674686
def train(self, task: JaxTask) -> core.Logs:
675687
"""Trains the model."""
676688
train_iter, _, state, train_step, _, step = self.process_task(
677-
task, training=True, check_for_checkpoints=True
689+
task, training=True, check_for_checkpoints=False
678690
)
679691

680692
logging.info(
@@ -698,25 +710,27 @@ def train(self, task: JaxTask) -> core.Logs:
698710
f" {_format_output(train_metrics)}"
699711
)
700712
metrics[core.TRAIN_LOG_DIRNAME] = train_metrics
701-
702-
self._maybe_save_checkpoint(curr_step, state, metrics=metrics)
713+
if self._enable_checkpointing:
714+
self._maybe_save_checkpoint(curr_step, state, metrics=metrics)
703715
step = curr_step + 1
704716

705-
self.checkpoint_manager.wait_until_finished()
717+
if self._enable_checkpointing:
718+
self.checkpoint_manager.wait_until_finished()
706719

707720
if jax.process_index() == 0:
708721
self._write_marker_file()
709722
task.export_model(state, self._model_dir)
710723

711-
self.checkpoint_manager.close()
712-
del self.checkpoint_manager
724+
if self._enable_checkpointing:
725+
self.checkpoint_manager.close()
726+
del self.checkpoint_manager
713727

714728
return metrics
715729

716730
def evaluate(self, task: JaxTask) -> core.Logs:
717731
"""Evaluates the model."""
718732
_, eval_iters, state, _, eval_step, step = self.process_task(
719-
task, training=False, check_for_checkpoints=True
733+
task, training=False, check_for_checkpoints=False
720734
)
721735
eval_summary_writers = self._create_eval_summary_writers(eval_iters)
722736

@@ -749,7 +763,7 @@ def evaluate(self, task: JaxTask) -> core.Logs:
749763
def train_and_evaluate(self, task: JaxTask) -> core.Logs:
750764
"""Trains and evaluates the model."""
751765
train_iter, eval_iters, state, train_step, eval_step, step = (
752-
self.process_task(task, training=True, check_for_checkpoints=True)
766+
self.process_task(task, training=True, check_for_checkpoints=False)
753767
)
754768
eval_summary_writers = self._create_eval_summary_writers(eval_iters)
755769

@@ -794,18 +808,20 @@ def train_and_evaluate(self, task: JaxTask) -> core.Logs:
794808
f" {_format_output(eval_metrics)}"
795809
)
796810
metrics[_val_logdir(key)] = eval_metrics
797-
798-
self._maybe_save_checkpoint(curr_step, state, metrics=metrics)
811+
if self._enable_checkpointing:
812+
self._maybe_save_checkpoint(curr_step, state, metrics=metrics)
799813
step = curr_step + 1
800814

801-
self.checkpoint_manager.wait_until_finished()
815+
if self._enable_checkpointing:
816+
self.checkpoint_manager.wait_until_finished()
802817

803818
if jax.process_index() == 0:
804819
self._write_marker_file()
805820
task.export_model(state, self._model_dir)
806821

807-
self.checkpoint_manager.close()
808-
del self.checkpoint_manager
822+
if self._enable_checkpointing:
823+
self.checkpoint_manager.close()
824+
del self.checkpoint_manager
809825

810826
return metrics
811827

@@ -833,7 +849,8 @@ def timeout_fn() -> bool:
833849
timeout_fn=timeout_fn,
834850
):
835851
try:
836-
state = self._maybe_restore_checkpoint(state, step)
852+
if self._enable_checkpointing:
853+
state = self._maybe_restore_checkpoint(state, step)
837854
logging.info(f"eval | step: {step: 6d} | {steps_msg}")
838855
with self.report_progress.timed("eval"):
839856
for key, eval_iter in eval_iters.items():
@@ -930,3 +947,4 @@ def _format_output(output: Any, indent: int = 4, width: int = 80) -> str:
930947
return formatted
931948
lines = [" " * indent + line for line in lines]
932949
return "\n" + "\n".join(lines)
950+

recml/examples/dlrm_experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,5 +398,6 @@ def experiment() -> fdl.Config[recml.Experiment]:
398398
train_steps=1_000,
399399
steps_per_eval=100,
400400
steps_per_loop=100,
401+
enable_checkpointing=False
401402
)
402403
return fdl.Config(recml.Experiment, task=task, trainer=trainer)

recml/examples/dlrm_experiment_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
"""Tests for the DLRM experiment."""
1515

16+
import jax
17+
from absl import logging
1618
import sys
1719
import os
1820
# Add the RecML folder to the system path
@@ -43,6 +45,7 @@ def test_dlrm_experiment(self):
4345
experiment.trainer.train_steps = 12
4446
experiment.trainer.steps_per_loop = 4
4547
experiment.trainer.steps_per_eval = 4
48+
experiment.trainer.enable_checkpointing = False
4649

4750
for cfg in selectors.select(experiment, dlrm_experiment.SparseFeature):
4851
cfg.vocab_size = 200
@@ -53,4 +56,4 @@ def test_dlrm_experiment(self):
5356

5457

5558
if __name__ == "__main__":
56-
absltest.main()
59+
absltest.main()

requirements.txt

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ etils==1.12.2
1818
fiddle==0.3.0
1919
filelock==3.18.0
2020
flatbuffers==25.2.10
21-
flax==0.12.2
2221
fsspec==2025.3.2
2322
gast==0.6.0
2423
google-pasta==0.2.0
@@ -33,8 +32,6 @@ immutabledict==4.2.1
3332
importlib-resources==6.5.2
3433
iniconfig==2.1.0
3534
isort==6.0.1
36-
jax==0.8.2
37-
jaxlib==0.8.2
3835
jaxtyping==0.3.1
3936
Jinja2==3.1.6
4037
kagglehub==0.3.11
@@ -59,19 +56,6 @@ nest-asyncio==1.6.0
5956
networkx==3.4.2
6057
nodeenv==1.9.1
6158
numpy==2.1.3
62-
nvidia-cublas-cu12==12.4.5.8
63-
nvidia-cuda-cupti-cu12==12.4.127
64-
nvidia-cuda-nvrtc-cu12==12.4.127
65-
nvidia-cuda-runtime-cu12==12.4.127
66-
nvidia-cudnn-cu12==9.1.0.70
67-
nvidia-cufft-cu12==11.2.1.3
68-
nvidia-curand-cu12==10.3.5.147
69-
nvidia-cusolver-cu12==11.6.1.9
70-
nvidia-cusparse-cu12==12.3.1.170
71-
nvidia-cusparselt-cu12==0.6.2
72-
nvidia-nccl-cu12==2.21.5
73-
nvidia-nvjitlink-cu12==12.4.127
74-
nvidia-nvtx-cu12==12.4.127
7559
opt-einsum==3.4.0
7660
optax==0.2.4
7761
optree==0.15.0
@@ -82,7 +66,6 @@ pluggy==1.5.0
8266
portpicker==1.6.0
8367
pre-commit==4.2.0
8468
promise==2.3
85-
# protobuf==6.33.4
8669
psutil==7.0.0
8770
pyarrow==19.0.1
8871
Pygments==2.19.1

0 commit comments

Comments (0)