Skip to content

Commit ffbc36b

Browse files
committed
updated code to run both HSTU and DLRM on v6
1 parent 1df0323 commit ffbc36b

2 files changed

Lines changed: 42 additions & 14 deletions

File tree

recml/examples/dlrm_experiment.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
import dataclasses
2020
from typing import Generic, Literal, TypeVar
2121

22+
import sys
23+
import os
24+
# Add the RecML folder to the system path
25+
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
26+
os.environ["KERAS_BACKEND"] = "jax"
27+
2228
from etils import epy
2329
import fiddle as fdl
2430
import flax.linen as nn

requirements.txt

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
absl-py==2.2.2
2+
aiofiles==25.1.0
3+
array-record==0.8.3
24
astroid==3.3.9
35
astunparse==1.6.3
46
attrs==25.3.0
@@ -16,7 +18,7 @@ etils==1.12.2
1618
fiddle==0.3.0
1719
filelock==3.18.0
1820
flatbuffers==25.2.10
19-
flax==0.10.5
21+
flax==0.12.2
2022
fsspec==2025.3.2
2123
gast==0.6.0
2224
google-pasta==0.2.0
@@ -31,18 +33,22 @@ immutabledict==4.2.1
3133
importlib-resources==6.5.2
3234
iniconfig==2.1.0
3335
isort==6.0.1
34-
jax==0.6.0
35-
jaxlib==0.6.0
36+
jax==0.8.2
37+
jax-tpu-embedding==0.1.0.dev20251208
38+
jaxlib==0.8.2
3639
jaxtyping==0.3.1
37-
jinja2==3.1.6
40+
Jinja2==3.1.6
3841
kagglehub==0.3.11
3942
keras==3.9.2
4043
keras-hub==0.20.0
4144
libclang==18.1.1
4245
libcst==1.7.0
43-
markdown==3.8
46+
libtpu==0.0.32
47+
# libtpu-nightly is usually installed directly via URL, but pinning it helps tracking
48+
# libtpu-nightly==0.1.dev20240617+default
49+
Markdown==3.8
4450
markdown-it-py==3.0.0
45-
markupsafe==3.0.2
51+
MarkupSafe==3.0.2
4652
mccabe==0.7.0
4753
mdurl==0.1.2
4854
ml-collections==1.1.0
@@ -54,23 +60,37 @@ nest-asyncio==1.6.0
5460
networkx==3.4.2
5561
nodeenv==1.9.1
5662
numpy==2.1.3
63+
nvidia-cublas-cu12==12.4.5.8
64+
nvidia-cuda-cupti-cu12==12.4.127
65+
nvidia-cuda-nvrtc-cu12==12.4.127
66+
nvidia-cuda-runtime-cu12==12.4.127
67+
nvidia-cudnn-cu12==9.1.0.70
68+
nvidia-cufft-cu12==11.2.1.3
69+
nvidia-curand-cu12==10.3.5.147
70+
nvidia-cusolver-cu12==11.6.1.9
71+
nvidia-cusparse-cu12==12.3.1.170
72+
nvidia-cusparselt-cu12==0.6.2
73+
nvidia-nccl-cu12==2.21.5
74+
nvidia-nvjitlink-cu12==12.4.127
75+
nvidia-nvtx-cu12==12.4.127
5776
opt-einsum==3.4.0
5877
optax==0.2.4
5978
optree==0.15.0
60-
orbax-checkpoint==0.11.12
79+
orbax-checkpoint==0.11.31
6180
packaging==24.2
6281
platformdirs==4.3.7
6382
pluggy==1.5.0
83+
portpicker==1.6.0
6484
pre-commit==4.2.0
6585
promise==2.3
66-
# protobuf==5.29.4
86+
# protobuf==6.33.4
6787
psutil==7.0.0
6888
pyarrow==19.0.1
69-
pygments==2.19.1
89+
Pygments==2.19.1
7090
pylint==3.3.6
7191
pytest==8.3.5
7292
pytest-env==1.1.5
73-
pyyaml==6.0.2
93+
PyYAML==6.0.2
7494
regex==2024.11.6
7595
requests==2.32.3
7696
rich==14.0.0
@@ -84,21 +104,23 @@ tensorboard==2.19.0
84104
tensorboard-data-server==0.7.2
85105
tensorflow==2.19.0
86106
tensorflow-datasets==4.9.8
107+
tensorflow-io-gcs-filesystem==0.37.1
87108
tensorflow-metadata==1.17.1
88109
tensorflow-text==2.19.0
89-
tensorstore==0.1.73
110+
tensorstore==0.1.80
90111
termcolor==3.0.1
91112
toml==0.10.2
92113
tomlkit==0.13.2
93114
toolz==1.0.0
94115
torch==2.6.0
95116
tqdm==4.67.1
96117
treescope==0.1.9
97-
typing-extensions==4.13.2
118+
triton==3.2.0
119+
typing_extensions==4.13.2
98120
urllib3==2.4.0
99121
virtualenv==20.30.0
100122
wadler-lindig==0.1.5
101-
werkzeug==3.1.3
123+
Werkzeug==3.1.3
102124
wheel==0.45.1
103125
wrapt==1.17.2
104-
zipp==3.21.0
126+
zipp==3.21.0

0 commit comments

Comments
 (0)