
Commit dd8df22

Merge pull request #3550 from AI-Hypercomputer:aireen/fix_mm_sft
PiperOrigin-RevId: 893252738
2 parents: d370f95 + d104c42

5 files changed: 13 additions & 14 deletions

File tree

src/maxtext/input_pipeline/hf_data_processing.py
src/maxtext/input_pipeline/input_pipeline_utils.py
src/maxtext/layers/moe.py
src/maxtext/multimodal/processor_gemma3.py
src/maxtext/utils/maxtext_utils.py

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 2 additions & 2 deletions
@@ -146,15 +146,15 @@ def vision_sft_preprocessing_pipeline(
           query_column=text_columns[0],
           response_column=text_columns[1],
           max_target_length=config.max_target_length,
-          unk_id=pad_id,
+          pad_id=pad_id,
       )
   )
   # TODO(aireenmei, hengtaoguo): support packing
   operations.append(
       input_pipeline_utils.PadOrTrimToMaxLength(
           config.max_target_length,
           pad_id,
-          model_name=config.model_name,
+          config=config,
           max_num_images_per_example=config.max_num_images_per_example,
       )
   )
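The two renamed keyword arguments track the updated signatures in input_pipeline_utils.py (next file). A minimal sketch of the resulting pipeline construction after this commit, assuming config, text_columns, pad_id, and operations are in scope as in the surrounding function:

  operations.append(
      input_pipeline_utils.SFTPromptMaskingVision(
          query_column=text_columns[0],
          response_column=text_columns[1],
          max_target_length=config.max_target_length,
          pad_id=pad_id,  # was unk_id before this commit
      )
  )
  operations.append(
      input_pipeline_utils.PadOrTrimToMaxLength(
          config.max_target_length,
          pad_id,
          config=config,  # the full config replaces the bare model_name string
          max_num_images_per_example=config.max_num_images_per_example,
      )
  )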

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 7 additions & 7 deletions
@@ -313,15 +313,15 @@ def map(self, element):
 class SFTPromptMaskingVision(grain.MapTransform):
   """SFT prompt masking for multimodal"""
 
-  def __init__(self, query_column, response_column, max_target_length, unk_id):
+  def __init__(self, query_column, response_column, max_target_length, pad_id):
     self.query_column = query_column
     self.response_column = response_column
     self.max_target_length = max_target_length
-    self.unk_id = unk_id
+    self.pad_id = pad_id
 
   def map(self, element):
     inputs = np.concatenate((element[self.query_column], element[self.response_column]))
-    targets = np.concatenate((np.asarray([self.unk_id] * len(element[self.query_column])), element[self.response_column]))
+    targets = np.concatenate((np.asarray([self.pad_id] * len(element[self.query_column])), element[self.response_column]))
     return {
         "inputs": np.asarray(inputs[: self.max_target_length], dtype=np.int32),
         "targets": np.asarray(targets[: self.max_target_length], dtype=np.int32),
@@ -559,13 +559,13 @@ def __init__(
       self,
       max_length: int,
       pad_id: int = 0,
-      model_name: str | None = None,
+      config=None,
       add_true_length: bool = False,
       max_num_images_per_example: int = -1,
   ):
     self.max_length = max_length
     self.pad_id = pad_id
-    self.model_name = model_name
+    self.config = config
     self.add_true_length = add_true_length
     self.max_num_images_per_example = max_num_images_per_example
 
@@ -614,7 +614,7 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -
       raise ValueError("Input preprocessed_image must have pixel_values to pad images.")
 
     # Determine the maximum number of images/masks allowed.
-    image_offsets = mm_processor.get_image_offsets(self.model_name, preprocessed_image)
+    image_offsets = mm_processor.get_image_offsets(self.config, preprocessed_image)
     single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0]
 
     # Reserve space for at least one text token.
@@ -680,7 +680,7 @@ def map(
 
     for key, _ in element.items():
       if key == "images":
-        if self.model_name is None:
+        if self.config.model_name is None:
           raise ValueError("model_name must be provided when padding images")
 
         element["images"] = self._pad_image_and_mask(element["images"])

src/maxtext/layers/moe.py

Lines changed: 2 additions & 2 deletions
@@ -598,8 +598,8 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
     """Applies FFN activation function."""
     with jax.named_scope("ffn_act"):
       if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS:
-        layer_w0 = jnp.clip(layer_w0, a_min=None, a_max=self.config.mlp_activations_limit)
-        layer_w1 = jnp.clip(layer_w1, a_min=-self.config.mlp_activations_limit, a_max=self.config.mlp_activations_limit)
+        layer_w0 = jnp.clip(layer_w0, min=None, max=self.config.mlp_activations_limit)
+        layer_w1 = jnp.clip(layer_w1, min=-self.config.mlp_activations_limit, max=self.config.mlp_activations_limit)
       layer_act = self.activation_fn(layer_w0 * 1.702)
       glu = jnp.multiply(layer_w0, layer_act)
       intermediate_layer = jnp.multiply(glu, (layer_w1 + 1))
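This rename tracks jnp.clip's current keyword names: newer JAX releases dropped the a_min/a_max aliases in favor of min/max. A standalone sketch of the two clipping patterns, with 7.0 standing in for config.mlp_activations_limit:

  import jax.numpy as jnp

  x = jnp.array([-10.0, 0.5, 10.0])
  w0 = jnp.clip(x, min=None, max=7.0)  # upper bound only, as for layer_w0
  w1 = jnp.clip(x, min=-7.0, max=7.0)  # symmetric bound, as for layer_w1
  print(w0)  # [-10.    0.5   7. ]
  print(w1)  # [-7.   0.5  7. ]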

src/maxtext/multimodal/processor_gemma3.py

Lines changed: 1 addition & 2 deletions
@@ -77,10 +77,9 @@ def preprocess_mm_data_gemma3(images):
     images_out.append(img)
 
   processor_output = Gemma3PreprocessorOutput(
-      num_images=len(images),
+      num_images=len(images_in),
       pixel_values=np.stack(images_out, axis=0).astype(np.float32),  # (N, H, W, C)
   )
-  processor_output.num_images = len(images)
   return processor_output
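The count now comes from the list of images the loop actually consumed, and the redundant post-construction reassignment is gone. A minimal sketch of the invariant being preserved, with dummy arrays standing in for decoded image data:

  import numpy as np

  images_in = [np.zeros((4, 4, 3)), np.zeros((4, 4, 3))]  # stand-ins for decoded images
  images_out = [img.astype(np.float32) for img in images_in]
  pixel_values = np.stack(images_out, axis=0)  # (N, H, W, C)
  # num_images must agree with the leading axis of pixel_values.
  assert len(images_in) == pixel_values.shape[0]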

src/maxtext/utils/maxtext_utils.py

Lines changed: 1 addition & 1 deletion
@@ -827,7 +827,7 @@ def calculate_tflops_training_per_device(config, log=True):
   if config.use_multimodal:
     # Add vision layers TFLOPs for multimodal models
     mm_total_tflops, mm_learnable_weight_tflops, mm_attention_tflops = calculate_vision_encoder_tflops(config)
-    if log:
+    if log and mm_total_tflops > 0:
       print(
           f"{config.model_name} vision layers per train step:\n",
           f"Total TFLOPs: {mm_total_tflops:.2f} \n",
