Skip to content

Commit ac003e6

Browse files
committed
added changes to recognize spmd_mode for unit tests
1 parent cdba7b8 commit ac003e6

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

recml/core/training/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import abc
1717
from collections.abc import Mapping, Sequence
18+
import contextlib
1819
import dataclasses
1920
import enum
2021
from typing import Any, Generic, TypeVar
@@ -24,6 +25,13 @@
2425
from recml.core.data import iterator
2526
import tensorflow as tf
2627

28+
# Patch jax.spmd_mode if it doesn't exist (removed in newer JAX versions).
29+
if not hasattr(jax, "spmd_mode"):
30+
@contextlib.contextmanager
31+
def _spmd_mode(*args, **kwargs):
32+
del args, kwargs
33+
yield
34+
jax.spmd_mode = _spmd_mode
2735

2836
# pylint: disable=logging-fstring-interpolation
2937

0 commit comments

Comments
 (0)