Skip to content

Commit 381dcdd

Browse files
Merge pull request #3570 from AI-Hypercomputer:chengnuojin-compile-flags
PiperOrigin-RevId: 895595405
2 parents 2a57a30 + 0ee7fa5 commit 381dcdd

4 files changed

Lines changed: 43 additions & 1 deletion

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'
442442
# internal_compile allows bypassing open-source topology name mappings when using internal topologies directly via get_topology_desc.
443443
internal_compile: False
444444
internal_compile_num_devices: -1 # You must specify the number of devices when using internal_compile.
445+
compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536"
445446

446447
# Parallelism
447448
shard_mode: "auto" # can be either auto or explicit

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,7 @@ class LayoutAndSharding(BaseModel):
842842
shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
843843
internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
844844
internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
845+
compile_xla_flags: str = Field("", description="Compiler options for compilation only.")
845846

846847

847848
class DcnParallelism(BaseModel):

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def jit_and_compile(
144144
)
145145
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
146146
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
147-
compiled = lowered.compile()
147+
# Parse config.compile_xla_flags into a compiler-options dict; an empty string yields an empty dict.
148+
compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)
149+
compiled = lowered.compile(compiler_options=compiler_options)
148150
return compiled
149151

150152

src/maxtext/utils/max_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from functools import partial
2222
import os
2323
import socket
24+
import re
2425
import subprocess
2526
import time
2627
from typing import Any
@@ -52,6 +53,43 @@
5253
# pylint: disable=too-many-positional-arguments
5354

5455

56+
def parse_libtpu_flags_to_dict(flags_str: str) -> dict[str, str]:
  """
  Parses a whitespace-separated string of XLA flags into a dictionary of compilation options.

  Every token must have the exact form ``--key=value``: ``key`` is made of
  letters, digits and underscores, and ``value`` is any non-empty text after
  the first ``=``. Values are returned as strings, unmodified.

  Backslashes at the start or end of a token are treated as shell-style line
  continuations and dropped. Backslashes inside a token (for example in a
  regex-valued flag) are preserved as part of the value.

  This function is only for compilation usage.

  Args:
    flags_str: The flag string, e.g. ``"--flag_a=1 --flag_b=65536"``. May span
      several lines and may be empty.

  Returns:
    A dict mapping each flag name (without the leading ``--``) to its value.
    An empty, ``None`` or whitespace-only input yields ``{}``.

  Raises:
    ValueError: If a token is not of the form ``--key=value`` or if the same
      key appears more than once.
  """
  if not flags_str or not flags_str.strip():
    return {}

  # Remove only line-continuation backslashes, i.e. runs of backslashes that
  # touch whitespace or the string boundary. A blanket replace would also split
  # values that legitimately contain a backslash into two bogus tokens.
  cleaned_str = re.sub(r"(?<!\S)\\+|\\+(?!\S)", " ", flags_str)

  # Strict '--key=value' matcher for one isolated token.
  # Key is alphanumeric + underscores. Value is anything after the first '='.
  token_pattern = re.compile(r"^--([a-zA-Z0-9_]+)=(.+)$")

  options_dict = {}

  # str.split() with no argument splits on any run of whitespace, so single
  # spaces, repeated spaces and newlines are all handled the same way.
  for token in cleaned_str.split():
    match = token_pattern.match(token)
    if not match:
      # Fail fast so a typo is not silently ignored at compile time.
      raise ValueError(f"Invalid flag format detected: '{token}'. Expected format: '--key=value'")

    key, value = match.groups()

    # Reject duplicates rather than letting the last occurrence win silently.
    if key in options_dict:
      raise ValueError(f"Duplicate flag detected: '--{key}'")

    options_dict[key] = value

  return options_dict
91+
92+
5593
def with_memory_kind(t, memory_kind):
5694
return jax.tree_util.tree_map(lambda x: x.with_memory_kind(kind=memory_kind), t)
5795

0 commit comments

Comments
 (0)