
Commit 1df0323

Updated DLRM experiment to work with V6 chip
1 parent f59d184 commit 1df0323

7 files changed

Lines changed: 82 additions & 33 deletions


recml/core/data/iterator.py

Lines changed: 10 additions & 3 deletions
@@ -21,6 +21,7 @@
 from etils import epath
 import numpy as np
 import tensorflow as tf
+import jax


 Iterator = clu_data.DatasetIterator
@@ -67,17 +68,22 @@ def _maybe_to_numpy(
     ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
       if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
         return x
+      # FIX: Check for attribute existence to avoid crashes on non-Tensor objects
       if hasattr(x, "_numpy"):
         numpy = x._numpy()  # pylint: disable=protected-access
-      else:
+      elif hasattr(x, "numpy"):
         numpy = x.numpy()
+      else:
+        return x  # Return as-is if it can't be converted
+
       if isinstance(numpy, np.ndarray):
         # `numpy` shares the same underlying buffer as the `x` Tensor.
         # Tensors are expected to be immutable, so we disable writes.
         numpy.setflags(write=False)
       return numpy

-    return tf.nest.map_structure(_maybe_to_numpy, batch)
+    # FIX: Use jax.tree.map instead of tf.nest.map_structure
+    return jax.tree.map(_maybe_to_numpy, batch)

   @property
   def element_spec(self) -> clu_data.ElementSpec:
@@ -109,7 +115,8 @@ def _to_element_spec(
       )
       return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))

-    element_spec = tf.nest.map_structure(_to_element_spec, batch)
+    # element_spec = tf.nest.map_structure(_to_element_spec, batch)
+    element_spec = jax.tree.map(_to_element_spec, batch)
     self._element_spec = element_spec
     return element_spec

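The replacement of `tf.nest.map_structure` with `jax.tree.map` keeps the per-leaf conversion identical; only the structure-mapping utility changes. Below is a minimal, hedged sketch of the same idea applied to a stand-alone conversion function over a nested batch; the `batch` layout, feature names, and the `to_numpy` helper are illustrative, not the actual RecML code.

# Sketch: mapping a per-leaf numpy conversion over a nested batch with
# jax.tree.map. tf.Tensor values are pytree leaves, so the conversion
# function is applied to each of them.
import jax
import numpy as np
import tensorflow as tf


def to_numpy(x):
  # Dense tf.Tensor -> read-only np.ndarray; everything else passes through.
  if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
    return x
  if hasattr(x, "numpy"):
    out = x.numpy()
    if isinstance(out, np.ndarray):
      out.setflags(write=False)
    return out
  return x


batch = {
    "dense_features": tf.ones((4, 13)),
    "label": tf.zeros((4, 1)),
}
numpy_batch = jax.tree.map(to_numpy, batch)
print({k: type(v) for k, v in numpy_batch.items()})

One behavioral difference worth keeping in mind: `jax.tree.map` treats `None` entries as empty subtrees and skips them, whereas `tf.nest.map_structure` passes `None` leaves to the mapped function, so batches containing `None` values may behave slightly differently after this change.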
recml/core/ops/embedding_ops.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class SparsecoreParams:
   """Embedding parameters."""

   feature_specs: Nested[FeatureSpec]
-  mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
+  mesh: jax.sharding.Mesh
   data_axes: Sequence[str | None]
   embedding_axes: Sequence[str | None]
   sharding_strategy: str
recml/core/training/mesh_context.py (new file)

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+"""Simple global context for mesh, replacing jax.experimental.maps."""
+
+_GLOBAL_MESH = None
+
+def set_global_mesh(mesh):
+  global _GLOBAL_MESH
+  _GLOBAL_MESH = mesh
+
+def get_global_mesh():
+  return _GLOBAL_MESH

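The new module is a process-level registry: the partitioners record the active mesh with `set_global_mesh`, and the sparsecore layer reads it back with `get_global_mesh`. A minimal usage sketch follows; it assumes the module lives at `recml/core/training/mesh_context.py` (the path implied by the imports in this commit), and the single-axis mesh and axis name are illustrative only.

# Sketch: registering and retrieving the global mesh via mesh_context.
import jax
from recml.core.training import mesh_context

mesh = jax.sharding.Mesh(jax.devices(), ("batch",))
mesh_context.set_global_mesh(mesh)

# Later, e.g. inside SparsecoreEmbed.get_mesh, the same mesh is recovered:
current = mesh_context.get_global_mesh()
assert current is mesh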
recml/core/training/partitioning.py

Lines changed: 29 additions & 17 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 """Utilities for partitioning."""

 import abc
@@ -21,6 +22,8 @@
 import flax.linen as nn
 import jax
 import numpy as np
+# FIXED: Use the public experimental module available in JAX 0.4.30
+from recml.core.training import mesh_context


 PyTree = Any
@@ -67,7 +70,8 @@ class DataParallelPartitioner(Partitioner):
   """Data parallel partitioner."""

   def __init__(self, data_axis: str = "batch"):
-    self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
+    devices = jax.devices()
+    self.mesh = jax.sharding.Mesh(devices, (data_axis,))
     self.data_sharding = jax.sharding.NamedSharding(
         self.mesh, jax.sharding.PartitionSpec(data_axis)
     )
@@ -107,8 +111,10 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    # FIXED: Use 'with self.mesh'
+    with self.mesh:
       if abstract_batch is not None:
+        mesh_context.set_global_mesh(self.mesh)
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
         self.state_sharding = jax.tree.map(
@@ -117,7 +123,8 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      # FIXED: Use 'with self.mesh'
+      with self.mesh:
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -130,15 +137,18 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)

-    with jax.sharding.use_mesh(self.mesh):
+    # FIXED: Use 'with self.mesh' and legacy bridge
+    with self.mesh:
+      mesh_context.set_global_mesh(self.mesh)
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
           **jit_kws,
       )

     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      # FIXED: Use 'with self.mesh'
+      with self.mesh:
         return step_fn(batch, state)

     return _wrapped_step
@@ -190,7 +200,8 @@ def __init__(
     if axis_sizes[0] == -1:
       axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])

-    self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
+    # self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
+    self.mesh = jax.sharding.Mesh(devices, axis_names)
     self.rules = rules
     self.aot_compile = aot_compile
     self.options = options
@@ -213,12 +224,6 @@ def __init__(
     self.abstract_batch = None
     self.abstract_state = None

-  @property
-  def mesh_context_manager(
-      self,
-  ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
-
   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
       return jax.make_array_from_process_local_data(self.data_sharding, x)
@@ -234,7 +239,10 @@ def partition_init(
           " model parallel partitioner."
       )

-    with self.mesh_context_manager(self.mesh):
+    # FIXED: Use 'with self.mesh' directly
+    with self.mesh:
+      # FIXED: Legacy bridge
+      mesh_context.set_global_mesh(self.mesh)
       abstract_state = jax.eval_shape(init_fn, abstract_batch)
       specs = nn.get_partition_spec(abstract_state)

@@ -247,7 +255,8 @@ def partition_init(
     compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding)

     def _init(batch: PyTree) -> State:
-      with self.mesh_context_manager(self.mesh):
+      # FIXED: Use 'with self.mesh' directly
+      with self.mesh:
         state = compiled_init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -265,7 +274,9 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
     else:
       jit_kws["out_shardings"] = None

-    with self.mesh_context_manager(self.mesh):
+    # FIXED: Use 'with self.mesh' directly and legacy bridge
+    with self.mesh:
+      mesh_context.set_global_mesh(self.mesh)
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
@@ -286,7 +297,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       )

     def _step(batch: PyTree, state: State) -> Any:
-      with self.mesh_context_manager(self.mesh):
+      # FIXED: Use 'with self.mesh' directly
+      with self.mesh:
         return step_fn(batch, state)

     return _step
@@ -302,4 +314,4 @@ def _maybe_unbox(x: Any) -> Any:
       _maybe_unbox,
       x,
       is_leaf=lambda k: isinstance(k, nn.Partitioned),
-  )
+  )

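The same pattern is applied throughout both partitioners: enter the concrete `Mesh` as a context manager instead of calling `jax.sharding.use_mesh`, and mirror the mesh into `mesh_context` so modules that previously queried the abstract mesh can still find it. A condensed, hedged sketch of that pattern is below; `loss_step` and the shardings are placeholders rather than the real RecML step functions, and the single-axis mesh is illustrative.

# Sketch: the "with mesh + legacy bridge" pattern used by the partitioners.
import jax
import numpy as np
from recml.core.training import mesh_context

mesh = jax.sharding.Mesh(np.array(jax.devices()), ("batch",))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("batch")
)


def loss_step(batch, state):
  # Placeholder step function.
  return state + batch.mean()


with mesh:  # replaces jax.sharding.use_mesh(mesh)
  mesh_context.set_global_mesh(mesh)  # legacy bridge for mesh consumers
  step_fn = jax.jit(loss_step, in_shardings=(data_sharding, None))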
recml/examples/dlrm_experiment_test.py

Lines changed: 8 additions & 2 deletions
@@ -13,6 +13,12 @@
 # limitations under the License.
 """Tests for the DLRM experiment."""

+import sys
+import os
+# Add the RecML folder to the system path
+sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
+os.environ["KERAS_BACKEND"] = "jax"
+
 from absl.testing import absltest
 import fiddle as fdl
 from fiddle import selectors
@@ -32,8 +38,8 @@ def test_dlrm_experiment(self):

     experiment = dlrm_experiment.experiment()

-    experiment.task.train_data.global_batch_size = 4
-    experiment.task.eval_data.global_batch_size = 4
+    experiment.task.train_data.global_batch_size = 128
+    experiment.task.eval_data.global_batch_size = 128
     experiment.trainer.train_steps = 12
     experiment.trainer.steps_per_loop = 4
     experiment.trainer.steps_per_eval = 4

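Setting `KERAS_BACKEND` at the top of the test module matters because Keras 3 selects its backend at import time, so the environment variable has to be in place before anything in the dependency chain imports Keras. A minimal sketch of that ordering, assuming Keras 3 is installed:

# Sketch: the backend env var must be set before Keras is first imported.
import os

os.environ["KERAS_BACKEND"] = "jax"

import keras  # noqa: E402  (must come after the env var is set)

print(keras.backend.backend())  # expected to report "jax"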
recml/layers/linen/sparsecore.py

Lines changed: 23 additions & 9 deletions
@@ -28,10 +28,12 @@
 from recml.core.ops import embedding_ops
 import tensorflow as tf

+from recml.core.training import mesh_context
+

 with epy.lazy_imports():
   # pylint: disable=g-import-not-at-top
-  from jax_tpu_embedding.sparsecore.lib.flax import embed
+  from jax_tpu_embedding.sparsecore.lib.flax.linen import embed
   from jax_tpu_embedding.sparsecore.lib.nn import embedding
   from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
   from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
@@ -369,16 +371,28 @@ class SparsecoreEmbed(nn.Module):
   sparsecore_config: SparsecoreConfig
   mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None

-  def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
-    if self.mesh is not None:
-      return self.mesh
-    abstract_mesh = jax.sharding.get_abstract_mesh()
-    if not abstract_mesh.shape_tuple:
+  # def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
+  #   if self.mesh is not None:
+  #     return self.mesh
+  #   abstract_mesh = jax.sharding.get_abstract_mesh()
+  #   if not abstract_mesh.shape_tuple:
+  #     raise ValueError(
+  #         'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
+  #         ' sure to set the mesh when calling the sparsecore module.'
+  #     )
+  #   return abstract_mesh
+
+  def get_mesh(self) -> jax.sharding.Mesh:
+    # Try to get the mesh from our custom global context
+    mesh = mesh_context.get_global_mesh()
+
+    if mesh is None:
       raise ValueError(
-          'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
-          ' sure to set the mesh when calling the sparsecore module.'
+          "No global mesh found. Make sure to call "
+          "`partitioning.partition_init` (which sets the mesh) "
+          "before initializing SparseCore."
       )
-    return abstract_mesh
+    return mesh

   def get_sharding_axis(
       self, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh

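With the abstract-mesh lookup commented out, `get_mesh` now resolves the mesh purely from the `mesh_context` registry, so the layer only works after a partitioner has populated it. A small sketch of the failure and success paths of that lookup; `resolve_mesh` is a stand-in for the real `SparsecoreEmbed.get_mesh`, and the mesh construction is illustrative.

# Sketch: the global-mesh lookup behavior introduced in this commit.
import jax
from recml.core.training import mesh_context


def resolve_mesh() -> jax.sharding.Mesh:
  mesh = mesh_context.get_global_mesh()
  if mesh is None:
    raise ValueError(
        "No global mesh found. Make sure to call "
        "`partitioning.partition_init` (which sets the mesh) "
        "before initializing SparseCore."
    )
  return mesh


# Before a partitioner has run, the lookup fails:
try:
  resolve_mesh()
except ValueError as e:
  print(e)

# After the mesh has been registered, the same lookup succeeds:
mesh_context.set_global_mesh(jax.sharding.Mesh(jax.devices(), ("batch",)))
print(resolve_mesh())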
requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ platformdirs==4.3.7
 pluggy==1.5.0
 pre-commit==4.2.0
 promise==2.3
-protobuf==5.29.4
+# protobuf==5.29.4
 psutil==7.0.0
 pyarrow==19.0.1
 pygments==2.19.1
