Commit 172ee51

Added training scripts for HSTU using Keras and JAX trainers
1 parent 61c08ca commit 172ee51

2 files changed

Lines changed: 494 additions & 0 deletions

File tree

recml/examples/train_hstu_jax.py

Lines changed: 271 additions & 0 deletions
@@ -0,0 +1,271 @@
"""HSTU experiment configuration using Fiddle and RecML with the JaxTrainer."""

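# Example invocation (illustrative paths; not part of this commit):
#
#   python recml/examples/train_hstu_jax.py \
#     --train_path=/data/train-*.tfrecord \
#     --eval_path=/data/eval-*.tfrecord
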
import dataclasses
import os
import sys
from typing import Mapping, Tuple

# The Keras backend must be set before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

from absl import app
from absl import flags
from absl import logging
import clu.metrics as clu_metrics
import fiddle as fdl
import jax
import jax.numpy as jnp
import keras
import optax
import tensorflow as tf

# Add the RecML folder to the system path.
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))

# RecML imports.
import recml
from recml.core.training import core
from recml.core.training import jax_trainer
from recml.core.training import partitioning
from recml.layers.keras import hstu

# Command-line flags.
FLAGS = flags.FLAGS

flags.DEFINE_string("train_path", None, "Path (or glob pattern) to training data.")
flags.DEFINE_string("eval_path", None, "Path (or glob pattern) to evaluation data.")
flags.DEFINE_string("model_dir", "/tmp/hstu_model_jax", "Where to save the model.")
flags.DEFINE_integer("vocab_size", 5_000_000, "Vocabulary size.")
flags.DEFINE_integer("train_steps", 2000, "Total number of training steps.")

flags.mark_flag_as_required("train_path")
flags.mark_flag_as_required("eval_path")

@dataclasses.dataclass
class HSTUModelConfig:
  """Configuration for the HSTU model architecture."""

  vocab_size: int = 5_000_000
  max_sequence_length: int = 50
  model_dim: int = 64
  num_heads: int = 4
  num_layers: int = 4
  dropout: float = 0.5
  learning_rate: float = 1e-3

@dataclasses.dataclass
class TFRecordDataFactory(recml.Factory[tf.data.Dataset]):
  """Reusable data factory for TFRecord datasets."""

  path: str
  batch_size: int
  max_sequence_length: int
  feature_key: str = "sequence"
  target_key: str = "target"
  is_training: bool = True

  def make(self) -> tf.data.Dataset:
    """Builds the tf.data.Dataset."""
    if not self.path:
      # tf.data has no Dataset.empty(); fail fast rather than train on nothing.
      raise ValueError("No path provided for dataset factory.")

    dataset = tf.data.Dataset.list_files(self.path)
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE)

    def _parse_fn(serialized_example):
      features = {
          self.feature_key: tf.io.VarLenFeature(tf.int64),
          self.target_key: tf.io.FixedLenFeature([1], tf.int64),
      }
      parsed = tf.io.parse_single_example(serialized_example, features)

      # Densify, truncate, and right-pad the sequence to a fixed length so
      # examples can be batched. Truncation keeps the padding non-negative.
      seq = tf.sparse.to_dense(parsed[self.feature_key])
      seq = seq[: self.max_sequence_length]
      padding_needed = self.max_sequence_length - tf.shape(seq)[0]
      seq = tf.pad(seq, [[0, padding_needed]])
      seq = tf.ensure_shape(seq, [self.max_sequence_length])
      seq = tf.cast(seq, tf.int32)

      target = tf.cast(parsed[self.target_key], tf.int32)
      return seq, target

    dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

    if self.is_training:
      dataset = dataset.repeat()

    return dataset.batch(self.batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )

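# Illustrative sketch (hypothetical helper; nothing in this script calls it):
# one way to serialize a record that _parse_fn above can consume, using the
# default feature keys.
def make_example_for_parse_fn(sequence: list[int], target: int) -> bytes:
  """Serializes a (sequence, target) pair matching the parse spec above."""
  feature = {
      "sequence": tf.train.Feature(int64_list=tf.train.Int64List(value=sequence)),
      "target": tf.train.Feature(int64_list=tf.train.Int64List(value=[target])),
  }
  return tf.train.Example(
      features=tf.train.Features(feature=feature)
  ).SerializeToString()
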
class HSTUTask(jax_trainer.JaxTask):
  """JaxTask that trains the HSTU model."""

  def __init__(
      self,
      model_config: HSTUModelConfig,
      train_data_factory: recml.Factory[tf.data.Dataset],
      eval_data_factory: recml.Factory[tf.data.Dataset],
  ):
    self.config = model_config
    self.train_data_factory = train_data_factory
    self.eval_data_factory = eval_data_factory

  def create_datasets(self) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    return self.train_data_factory.make(), self.eval_data_factory.make()

  def _create_model(self) -> keras.Model:
    inputs = keras.Input(
        shape=(self.config.max_sequence_length,), dtype="int32", name="input_ids"
    )
    # Token id 0 is treated as padding.
    padding_mask = keras.ops.cast(keras.ops.not_equal(inputs, 0), "int32")

    hstu_layer = hstu.HSTU(
        vocab_size=self.config.vocab_size,
        max_positions=self.config.max_sequence_length,
        model_dim=self.config.model_dim,
        num_heads=self.config.num_heads,
        num_layers=self.config.num_layers,
        dropout=self.config.dropout,
    )

    logits = hstu_layer(inputs, padding_mask=padding_mask)

    def get_last_token_logits(args):
      # Gather the logits at the last non-padded position of each sequence.
      seq_logits, mask = args
      lengths = keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1)
      last_indices = lengths - 1
      indices = keras.ops.expand_dims(keras.ops.expand_dims(last_indices, -1), -1)
      return keras.ops.squeeze(
          keras.ops.take_along_axis(seq_logits, indices, axis=1), axis=1
      )

    output_logits = keras.layers.Lambda(get_last_token_logits)([logits, padding_mask])
    # Cast the head output to float32 for a numerically stable loss.
    output_logits = keras.layers.Activation("linear", dtype="float32")(output_logits)

    model = keras.Model(inputs=inputs, outputs=output_logits)
    return model

  def create_state(self, batch, rng) -> jax_trainer.KerasState:
    inputs, _ = batch
    model = self._create_model()
    # Build the model to initialize variables.
    model.build(inputs.shape)

    optimizer = optax.adam(learning_rate=self.config.learning_rate)
    return jax_trainer.KerasState.create(model=model, tx=optimizer)

  def train_step(
      self, batch, state: jax_trainer.KerasState, rng: jax.Array
  ) -> Tuple[jax_trainer.KerasState, Mapping[str, clu_metrics.Metric]]:
    inputs, targets = batch

    def loss_fn(tvars):
      # stateless_call returns (outputs, non-trainable variables); only the
      # logits are needed for the loss.
      logits, _ = state.model.stateless_call(tvars, state.ntvars, inputs)
      loss = optax.softmax_cross_entropy_with_integer_labels(
          logits, jnp.squeeze(targets)
      )
      return jnp.mean(loss), logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.tvars)
    state = state.update(grads=grads)

    metrics = self._compute_metrics(loss, logits, targets)
    return state, metrics

  def eval_step(
      self, batch, state: jax_trainer.KerasState
  ) -> Mapping[str, clu_metrics.Metric]:
    inputs, targets = batch
    logits, _ = state.model.stateless_call(state.tvars, state.ntvars, inputs)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits, jnp.squeeze(targets)
    )
    loss = jnp.mean(loss)
    return self._compute_metrics(loss, logits, targets)

  def _compute_metrics(self, loss, logits, targets):
    targets = jnp.squeeze(targets)
    metrics = {"loss": clu_metrics.Average.from_model_output(loss)}

    # Optional hit-rate metrics, currently disabled; a runnable variant is
    # sketched after this class. Note that top-k lives in jax.lax, not jax.nn.
    # def get_acc(k):
    #   _, top_k_indices = jax.lax.top_k(logits, k)
    #   correct = jnp.sum(top_k_indices == targets[:, None], axis=-1)
    #   return jnp.mean(correct)

    # metrics["HR_10"] = clu_metrics.Average.from_model_output(get_acc(10))
    # metrics["HR_50"] = clu_metrics.Average.from_model_output(get_acc(50))
    # metrics["HR_200"] = clu_metrics.Average.from_model_output(get_acc(200))
    return metrics

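# Sketch (unused): a hit-rate@k matching the commented-out helper above, using
# jax.lax.top_k. Assumes logits of shape [batch, vocab] and integer targets of
# shape [batch].
def hit_rate_at_k(logits: jax.Array, targets: jax.Array, k: int) -> jax.Array:
  _, top_k_indices = jax.lax.top_k(logits, k)
  # A row is a hit if the target id appears anywhere in its top-k predictions.
  hits = jnp.any(top_k_indices == targets[:, None], axis=-1)
  return jnp.mean(hits.astype(jnp.float32))
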
def experiment() -> fdl.Config[recml.Experiment]:
  """Defines the experiment structure using Fiddle configs."""

  max_seq_len = 50
  batch_size = 128

  model_cfg = fdl.Config(
      HSTUModelConfig,
      vocab_size=5_000_000,
      max_sequence_length=max_seq_len,
      model_dim=64,
      num_layers=4,
      dropout=0.5,
  )

  train_data = fdl.Config(
      TFRecordDataFactory,
      path="",  # Placeholder; overridden from flags in main().
      batch_size=batch_size,
      max_sequence_length=max_seq_len,
      is_training=True,
  )

  eval_data = fdl.Config(
      TFRecordDataFactory,
      path="",  # Placeholder; overridden from flags in main().
      batch_size=batch_size,
      max_sequence_length=max_seq_len,
      is_training=False,
  )

  task = fdl.Config(
      HSTUTask,
      model_config=model_cfg,
      train_data_factory=train_data,
      eval_data_factory=eval_data,
  )

  trainer = fdl.Config(
      jax_trainer.JaxTrainer,
      partitioner=fdl.Config(partitioning.DataParallelPartitioner),
      model_dir="/tmp/default_dir",  # Placeholder; overridden from flags in main().
      train_steps=2000,
      steps_per_eval=10,
      steps_per_loop=10,
  )

  return fdl.Config(recml.Experiment, task=task, trainer=trainer)

def main(_):
  # Log which backend JAX selected (e.g. tpu, gpu, or cpu).
  logging.info("JAX backend: %s", jax.default_backend())

  config = experiment()

  # Override the placeholder config values with the command-line flags.
  logging.info("Setting train path to: %s", FLAGS.train_path)
  config.task.train_data_factory.path = FLAGS.train_path

  logging.info("Setting eval path to: %s", FLAGS.eval_path)
  config.task.eval_data_factory.path = FLAGS.eval_path

  config.task.model_config.vocab_size = FLAGS.vocab_size

  logging.info("Setting model dir to: %s", FLAGS.model_dir)
  config.trainer.model_dir = FLAGS.model_dir
  config.trainer.train_steps = FLAGS.train_steps

  expt = fdl.build(config)

  logging.info("Starting experiment execution...")
  core.run_experiment(expt, core.Experiment.Mode.TRAIN_AND_EVAL)


if __name__ == "__main__":
  app.run(main)
