We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent cdba7b8 commit ac003e6Copy full SHA for ac003e6
1 file changed
recml/core/training/core.py
@@ -15,6 +15,7 @@
15
16
import abc
17
from collections.abc import Mapping, Sequence
18
+import contextlib
19
import dataclasses
20
import enum
21
from typing import Any, Generic, TypeVar
@@ -24,6 +25,13 @@
24
25
from recml.core.data import iterator
26
import tensorflow as tf
27
28
+# Patch jax.spmd_mode if it doesn't exist (removed in newer JAX versions).
29
+if not hasattr(jax, "spmd_mode"):
30
+ @contextlib.contextmanager
31
+ def _spmd_mode(*args, **kwargs):
32
+ del args, kwargs
33
+ yield
34
+ jax.spmd_mode = _spmd_mode
35
36
# pylint: disable=logging-fstring-interpolation
37
0 commit comments