"""HSTU experiment configuration using Fiddle and RecML with JaxTrainer."""

import dataclasses
import os
import sys
from typing import Mapping, Tuple

os.environ["KERAS_BACKEND"] = "jax"  # Must be set before importing keras.

import fiddle as fdl
import jax
import jax.numpy as jnp
import keras
import optax
import tensorflow as tf
import clu.metrics as clu_metrics
from absl import app
from absl import flags
from absl import logging

# Add the RecML folder to the system path.
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))

# RecML imports.
from recml.core.training import core
from recml.core.training import jax_trainer
from recml.core.training import partitioning
from recml.layers.keras import hstu
import recml

# Define command-line flags.
FLAGS = flags.FLAGS

flags.DEFINE_string("train_path", None, "Path or glob pattern to training data.")
flags.DEFINE_string("eval_path", None, "Path or glob pattern to evaluation data.")

flags.DEFINE_string("model_dir", "/tmp/hstu_model_jax", "Where to save the model.")
flags.DEFINE_integer("vocab_size", 5_000_000, "Vocabulary size.")
flags.DEFINE_integer("train_steps", 2000, "Total training steps.")

# Mark required flags.
flags.mark_flag_as_required("train_path")
flags.mark_flag_as_required("eval_path")
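
# Example invocation (the script name and data paths below are illustrative
# placeholders, not paths shipped with this repo):
#
#   python hstu_experiment.py \
#       --train_path="/data/hstu/train-*.tfrecord" \
#       --eval_path="/data/hstu/eval-*.tfrecord" \
#       --model_dir=/tmp/hstu_model_jax \
#       --vocab_size=5000000 \
#       --train_steps=2000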

@dataclasses.dataclass
class HSTUModelConfig:
  """Configuration for the HSTU model architecture."""

  vocab_size: int = 5_000_000
  max_sequence_length: int = 50
  model_dim: int = 64
  num_heads: int = 4
  num_layers: int = 4
  dropout: float = 0.5
  learning_rate: float = 1e-3

@dataclasses.dataclass
class TFRecordDataFactory(recml.Factory[tf.data.Dataset]):
  """Reusable data factory for TFRecord datasets."""

  path: str
  batch_size: int
  max_sequence_length: int
  feature_key: str = "sequence"
  target_key: str = "target"
  is_training: bool = True

  def make(self) -> tf.data.Dataset:
    """Builds the tf.data.Dataset."""
    if not self.path:
      raise ValueError("No path provided for the dataset factory.")

    dataset = tf.data.Dataset.list_files(self.path)
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE)

    def _parse_fn(serialized_example):
      features = {
          self.feature_key: tf.io.VarLenFeature(tf.int64),
          self.target_key: tf.io.FixedLenFeature([1], tf.int64),
      }
      parsed = tf.io.parse_single_example(serialized_example, features)

      # Densify, truncate sequences longer than max_sequence_length (so the
      # padding amount below is never negative), and right-pad with zeros.
      seq = tf.sparse.to_dense(parsed[self.feature_key])
      seq = seq[: self.max_sequence_length]
      padding_needed = self.max_sequence_length - tf.shape(seq)[0]
      seq = tf.pad(seq, [[0, padding_needed]])
      seq = tf.ensure_shape(seq, [self.max_sequence_length])
      seq = tf.cast(seq, tf.int32)

      target = tf.cast(parsed[self.target_key], tf.int32)
      return seq, target

    dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

    if self.is_training:
      dataset = dataset.repeat()

    return dataset.batch(self.batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )

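# For reference, a record matching `_parse_fn` above can be written like this
# (illustrative sketch; the values and output path are arbitrary):
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       "sequence": tf.train.Feature(
#           int64_list=tf.train.Int64List(value=[3, 17, 42])),
#       "target": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
#   }))
#   with tf.io.TFRecordWriter("/tmp/train-00000.tfrecord") as writer:
#     writer.write(example.SerializeToString())
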
class HSTUTask(jax_trainer.JaxTask):
  """JaxTask for the HSTU model."""

  def __init__(
      self,
      model_config: HSTUModelConfig,
      train_data_factory: recml.Factory[tf.data.Dataset],
      eval_data_factory: recml.Factory[tf.data.Dataset],
  ):
    self.config = model_config
    self.train_data_factory = train_data_factory
    self.eval_data_factory = eval_data_factory

  def create_datasets(self) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    return self.train_data_factory.make(), self.eval_data_factory.make()

  def _create_model(self) -> keras.Model:
    inputs = keras.Input(
        shape=(self.config.max_sequence_length,), dtype="int32", name="input_ids"
    )
    # Token id 0 is treated as the padding id, matching the zero-padding in
    # TFRecordDataFactory.
    padding_mask = keras.ops.cast(keras.ops.not_equal(inputs, 0), "int32")

    hstu_layer = hstu.HSTU(
        vocab_size=self.config.vocab_size,
        max_positions=self.config.max_sequence_length,
        model_dim=self.config.model_dim,
        num_heads=self.config.num_heads,
        num_layers=self.config.num_layers,
        dropout=self.config.dropout,
    )

    logits = hstu_layer(inputs, padding_mask=padding_mask)

    def get_last_token_logits(args):
      seq_logits, mask = args
      lengths = keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1)
      # Clamp to 0 so an all-padding row does not index position -1.
      last_indices = keras.ops.maximum(lengths - 1, 0)
      indices = keras.ops.expand_dims(keras.ops.expand_dims(last_indices, -1), -1)
      return keras.ops.squeeze(
          keras.ops.take_along_axis(seq_logits, indices, axis=1), axis=1
      )

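    # Illustrative shapes, assuming the HSTU layer emits per-position logits
    # of shape (batch, seq_len, vocab): `lengths` is (batch,), `indices` is
    # (batch, 1, 1), take_along_axis gathers (batch, 1, vocab), and the
    # squeeze yields (batch, vocab).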
    output_logits = keras.layers.Lambda(get_last_token_logits)([logits, padding_mask])
    # Cast the final activations to float32 for a numerically stable loss
    # under mixed precision.
    output_logits = keras.layers.Activation("linear", dtype="float32")(output_logits)

    model = keras.Model(inputs=inputs, outputs=output_logits)
    return model

  def create_state(self, batch, rng) -> jax_trainer.KerasState:
    inputs, _ = batch
    model = self._create_model()
    # Build the model to initialize variables.
    model.build(inputs.shape)

    optimizer = optax.adam(learning_rate=self.config.learning_rate)
    return jax_trainer.KerasState.create(model=model, tx=optimizer)

  def train_step(
      self, batch, state: jax_trainer.KerasState, rng: jax.Array
  ) -> Tuple[jax_trainer.KerasState, Mapping[str, clu_metrics.Metric]]:
    inputs, targets = batch

    def loss_fn(tvars):
      # training=True enables dropout; stateless_call also returns updated
      # non-trainable variables, which are not needed for the loss.
      logits, _ = state.model.stateless_call(
          tvars, state.ntvars, inputs, training=True
      )
      loss = optax.softmax_cross_entropy_with_integer_labels(
          logits, jnp.squeeze(targets)
      )
      return jnp.mean(loss), logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.tvars)
    state = state.update(grads=grads)

    metrics = self._compute_metrics(loss, logits, targets)
    return state, metrics

  def eval_step(
      self, batch, state: jax_trainer.KerasState
  ) -> Mapping[str, clu_metrics.Metric]:
    inputs, targets = batch
    logits, _ = state.model.stateless_call(
        state.tvars, state.ntvars, inputs, training=False
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits, jnp.squeeze(targets)
    )
    loss = jnp.mean(loss)
    return self._compute_metrics(loss, logits, targets)

  def _compute_metrics(self, loss, logits, targets):
    targets = jnp.squeeze(targets)
    metrics = {"loss": clu_metrics.Average.from_model_output(loss)}

    # Optional top-k hit-rate metrics, disabled by default. Note that top_k
    # lives in jax.lax, not jax.nn.
    # def get_acc(k):
    #   _, top_k_indices = jax.lax.top_k(logits, k)
    #   correct = jnp.sum(top_k_indices == targets[:, None], axis=-1)
    #   return jnp.mean(correct)

    # metrics["HR_10"] = clu_metrics.Average.from_model_output(get_acc(10))
    # metrics["HR_50"] = clu_metrics.Average.from_model_output(get_acc(50))
    # metrics["HR_200"] = clu_metrics.Average.from_model_output(get_acc(200))
    return metrics

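# Toy illustration of the disabled hit-rate metric above (not executed): with
# logits = [[0.1, 0.9, 0.3]] and target = [1], jax.lax.top_k(logits, 2)
# returns indices [[1, 2]]; the target appears among them, so HR@2 = 1.0.
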
def experiment() -> fdl.Config[recml.Experiment]:
  """Defines the experiment structure using Fiddle configs."""

  max_seq_len = 50
  batch_size = 128

  model_cfg = fdl.Config(
      HSTUModelConfig,
      vocab_size=5_000_000,
      max_sequence_length=max_seq_len,
      model_dim=64,
      num_layers=4,
      dropout=0.5,
  )

  train_data = fdl.Config(
      TFRecordDataFactory,
      path="",  # Placeholder; overridden from --train_path in main().
      batch_size=batch_size,
      max_sequence_length=max_seq_len,
      is_training=True,
  )

  eval_data = fdl.Config(
      TFRecordDataFactory,
      path="",  # Placeholder; overridden from --eval_path in main().
      batch_size=batch_size,
      max_sequence_length=max_seq_len,
      is_training=False,
  )

  task = fdl.Config(
      HSTUTask,
      model_config=model_cfg,
      train_data_factory=train_data,
      eval_data_factory=eval_data,
  )

  trainer = fdl.Config(
      jax_trainer.JaxTrainer,
      partitioner=fdl.Config(partitioning.DataParallelPartitioner),
      model_dir="/tmp/default_dir",  # Placeholder; overridden from --model_dir.
      train_steps=2000,
      steps_per_eval=10,
      steps_per_loop=10,
  )

  return fdl.Config(recml.Experiment, task=task, trainer=trainer)

def main(_):
  # Log which backend JAX selected (e.g. tpu, gpu, or cpu).
  logging.info(f"JAX backend: {jax.default_backend()}")

  config = experiment()

  logging.info(f"Setting train path to: {FLAGS.train_path}")
  config.task.train_data_factory.path = FLAGS.train_path

  logging.info(f"Setting eval path to: {FLAGS.eval_path}")
  config.task.eval_data_factory.path = FLAGS.eval_path

  config.task.model_config.vocab_size = FLAGS.vocab_size

  logging.info(f"Setting model dir to: {FLAGS.model_dir}")
  config.trainer.model_dir = FLAGS.model_dir
  config.trainer.train_steps = FLAGS.train_steps

  expt = fdl.build(config)

  logging.info("Starting experiment execution...")
  core.run_experiment(expt, core.Experiment.Mode.TRAIN_AND_EVAL)


if __name__ == "__main__":
  app.run(main)