Skip to content

Commit 0eb3303

Browse files
committed
conversion done, fixing sharding issue
1 parent 6552f14 commit 0eb3303

19 files changed

Lines changed: 3487 additions & 91 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,13 @@ def load_state_if_possible(
213213
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
214214
try:
215215
if not enable_single_replica_ckpt_restoring:
216-
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
217-
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
216+
# item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
217+
# return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) #currently changed to this
218+
if checkpoint_item == " ":
219+
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
220+
else:
221+
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
222+
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) #currently changed to this
218223

219224
def map_to_pspec(data):
220225
pspec = data.sharding.spec

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ per_device_batch_size: 1
4848
compile_topology_num_slices: -1
4949
quantization_local_shard_count: -1
5050
jit_initializers: True
51+
enable_single_replica_ckpt_restoring: False
Lines changed: 171 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,211 @@
1+
from json import encoder
12
from absl import app
23
from typing import Sequence
34
import jax
5+
from flax import linen as nn
46
import json
7+
from flax.linen import partitioning as nn_partitioning
58
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
69
import os
710
import functools
811
import jax.numpy as jnp
9-
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
12+
from maxdiffusion import pyconfig
1013
from maxdiffusion.max_utils import (
1114
create_device_mesh,
1215
setup_initial_state,
16+
get_memory_allocations,
1317
)
14-
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
18+
from jax.sharding import Mesh, PartitionSpec as P
19+
import orbax.checkpoint as ocp
1520

1621

17-
def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
22+
def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
1823
print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
1924
print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
2025
print("latents.shape: ", latents.shape, latents.dtype)
2126
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
27+
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
28+
print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
29+
print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
30+
31+
32+
def loop_body(
33+
step,
34+
args,
35+
transformer,
36+
fractional_cords,
37+
prompt_embeds,
38+
segment_ids,
39+
encoder_attention_segment_ids
40+
):
41+
latents, state, noise_cond = args
42+
noise_pred = transformer.apply(
43+
{"params": state.params},
44+
hidden_states=latents,
45+
indices_grid=fractional_cords,
46+
encoder_hidden_states=prompt_embeds,
47+
timestep=noise_cond,
48+
segment_ids=segment_ids,
49+
encoder_attention_segment_ids=encoder_attention_segment_ids
50+
)
51+
import pdb; pdb.set_trace()
52+
return noise_pred, state, noise_cond #need to make changes here? latents need to be changed based on noise_pred, but needs scheduler, return noise_pred for now
53+
54+
55+
56+
def run_inference(
57+
states, transformer, config, mesh, latents, fractional_cords, prompt_embeds, timestep, segment_ids, encoder_attention_segment_ids
58+
):
59+
transformer_state = states["transformer"]
60+
loop_body_p = functools.partial(
61+
loop_body,
62+
transformer=transformer,
63+
fractional_cords=fractional_cords,
64+
prompt_embeds=prompt_embeds,
65+
segment_ids=segment_ids,
66+
encoder_attention_segment_ids=encoder_attention_segment_ids
67+
)
68+
## TODO: add vae decode step
69+
## TODO: add loop
70+
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
71+
latents, transformer_state, _ = jax.lax.fori_loop(0, 1, loop_body_p, (latents, transformer_state, timestep))
72+
return latents
73+
2274

2375
def run(config):
2476
key = jax.random.PRNGKey(0)
2577

2678
devices_array = create_device_mesh(config)
2779
mesh = Mesh(devices_array, config.mesh_axes)
2880

29-
batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
3081
base_dir = os.path.dirname(__file__)
3182

3283
##load in model config
3384
config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
3485
with open(config_path, "r") as f:
3586
model_config = json.load(f)
87+
relative_ckpt_path = model_config["ckpt_path"]
3688

89+
ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"]
90+
in_channels = model_config["in_channels"]
91+
for name in ignored_keys:
92+
if name in model_config:
93+
del model_config[name]
3794

38-
transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
39-
transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only = False)
95+
96+
transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh)
97+
transformer_param_shapes = transformer.init_weights(in_channels, model_config['caption_channels'], eval_only = True) #use this to test!
4098

41-
key, split_key = jax.random.split(key)
4299
weights_init_fn = functools.partial(
43100
transformer.init_weights,
44-
split_key,
45-
batch_size,
46-
text_tokens,
47-
num_tokens,
48-
features,
49-
eval_only = False
101+
in_channels,
102+
model_config['caption_channels'],
103+
eval_only = True
50104
)
51105

52-
transformer_state, transformer_state_shardings = setup_initial_state(
106+
absolute_ckpt_path = os.path.abspath(relative_ckpt_path)
107+
108+
checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
109+
transformer_state, transformer_state_shardings = setup_initial_state(
53110
model=transformer,
54111
tx=None,
55112
config=config,
56113
mesh=mesh,
57114
weights_init_fn=weights_init_fn,
115+
checkpoint_manager=checkpoint_manager,
116+
checkpoint_item=" ",
58117
model_params=None,
59118
training=False,
60119
)
120+
121+
122+
123+
124+
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
125+
get_memory_allocations()
126+
127+
states = {}
128+
state_shardings = {}
129+
130+
state_shardings["transformer"] = transformer_state_shardings
131+
states["transformer"] = transformer_state
132+
133+
#create dummy inputs:
134+
example_inputs = {}
135+
batch_size, num_tokens = 4, 256
136+
input_shapes = {
137+
"latents": (batch_size, num_tokens, in_channels),
138+
"fractional_coords": (batch_size, 3, num_tokens),
139+
"prompt_embeds": (batch_size, 128, model_config["caption_channels"]),
140+
"timestep": (batch_size, 256), #TODO: add in the segment id stuff
141+
"segment_ids": (batch_size, 256),
142+
"encoder_attention_segment_ids": (batch_size, 128),
143+
}
144+
for name, shape in input_shapes.items():
145+
example_inputs[name] = jnp.ones(
146+
shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
147+
)
148+
149+
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
150+
latents = jax.device_put(example_inputs["latents"], data_sharding)
151+
prompt_embeds = jax.device_put(example_inputs["prompt_embeds"], data_sharding)
152+
fractional_coords = jax.device_put(example_inputs["fractional_coords"], data_sharding)
153+
noise_cond = jax.device_put(example_inputs["timestep"], data_sharding)
154+
segment_ids = jax.device_put(example_inputs["segment_ids"], data_sharding)
155+
encoder_attention_segment_ids = jax.device_put(example_inputs["encoder_attention_segment_ids"], data_sharding)
156+
157+
validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids)
158+
p_run_inference = jax.jit(
159+
functools.partial(
160+
run_inference,
161+
transformer=transformer,
162+
config=config,
163+
mesh=mesh,
164+
latents=latents,
165+
fractional_cords=fractional_coords,
166+
prompt_embeds=prompt_embeds,
167+
timestep = noise_cond,
168+
segment_ids=segment_ids,
169+
encoder_attention_segment_ids=encoder_attention_segment_ids
170+
),
171+
in_shardings=(state_shardings,),
172+
out_shardings=None,
173+
)
174+
noise_pred = p_run_inference(states).block_until_ready()
175+
print(noise_pred) #(4, 256, 128)
176+
177+
178+
179+
180+
181+
182+
183+
184+
185+
186+
187+
188+
189+
190+
61191

192+
193+
194+
195+
196+
197+
198+
199+
200+
201+
202+
203+
204+
205+
206+
207+
208+
62209

63210

64211
def main(argv: Sequence[str]) -> None:
@@ -71,3 +218,13 @@ def main(argv: Sequence[str]) -> None:
71218

72219

73220

221+
222+
###setup_initial_state, can optionally load from checkpoint
223+
224+
225+
226+
227+
228+
229+
230+
#end to end steps from ltx repo: pipeline_ltx_video.py

src/maxdiffusion/max_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,11 @@ def setup_initial_state(
402402
config.enable_single_replica_ckpt_restoring,
403403
)
404404
if state:
405-
state = state[checkpoint_item]
405+
###!Edited
406+
if checkpoint_item == " ":
407+
state = state
408+
else:
409+
state = state[checkpoint_item]
406410
if not state:
407411
max_logging.log(f"Could not find the item in orbax, creating state...")
408412
init_train_state_partial = functools.partial(
Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,22 @@
1+
2+
import argparse
3+
import json
4+
from typing import Any, Dict, Optional
15
import os
26
import jax
37
import jax.numpy as jnp
4-
import json
8+
import jax.lib.xla_extension
9+
import flax
10+
from flax.training import train_state
11+
import torch
12+
import optax
13+
import orbax.checkpoint as ocp
14+
from safetensors.torch import load_file
515

16+
from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT
617

7-
from models.transformers.transformer3d import Transformer3DModel
8-
9-
# Load JSON config
1018
base_dir = os.path.dirname(__file__)
1119
config_path = os.path.join(base_dir, "xora_v1.2-13B-balanced-128.json")
1220
with open(config_path, "r") as f:
1321
model_config = json.load(f)
14-
15-
key = jax.random.PRNGKey(0)
16-
model = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
17-
18-
batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
19-
prompt_embeds = jax.random.normal(key, shape=(batch_size, text_tokens, features), dtype=jnp.bfloat16)
20-
fractional_coords = jax.random.normal(key, shape=(batch_size, 3, num_tokens), dtype=jnp.bfloat16)
21-
latents = jax.random.normal(key, shape=(batch_size, num_tokens, features), dtype=jnp.bfloat16)
22-
noise_cond = jax.random.normal(key, shape=(batch_size, 1), dtype=jnp.bfloat16)
23-
24-
model_params = model.init(
25-
hidden_states=latents,
26-
indices_grid=fractional_coords,
27-
encoder_hidden_states=prompt_embeds,
28-
timestep=noise_cond,
29-
rngs={"params": key}
30-
)
31-
32-
output = model.apply(
33-
model_params,
34-
hidden_states=latents,
35-
indices_grid=fractional_coords,
36-
encoder_hidden_states=prompt_embeds,
37-
timestep=noise_cond,
38-
)
39-
40-
print("done!")
22+
transformer = Transformer3DModel_PT.from_config(model_config)

src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from functools import partial
2+
import functools
23
import math
34
from typing import Any, Dict, Optional, Tuple
45
from enum import Enum, auto
5-
66
import jax
77
import jax.nn as jnn
88
import jax.numpy as jnp
@@ -604,7 +604,8 @@ def __call__(
604604
block_sizes = self.default_block_sizes(q, k, dtype)
605605

606606
scale_factor = 1 / math.sqrt(q.shape[-1])
607-
607+
608+
608609
def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
609610
s = (
610611
# flash attention expects segment ids to be float32
@@ -630,14 +631,27 @@ def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids):
630631
raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.")
631632
# Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim")
632633
# Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py.
634+
# qkvo_sharding_spec = jax.sharding.PartitionSpec(
635+
# ("data", "fsdp", "fsdp_transpose", "expert"),
636+
# ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
637+
# None,
638+
# None,
639+
# )
640+
# qkvo_sharding_spec = jax.sharding.PartitionSpec(
641+
# ("data", "fsdp", "fsdp_transpose", "expert"),
642+
# ("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
643+
# None,
644+
# None,
645+
# )
633646
qkvo_sharding_spec = jax.sharding.PartitionSpec(
634-
("data", "fsdp", "fsdp_transpose", "expert"),
635-
("tensor", "tensor_transpose", "sequence", "tensor_sequence"),
647+
None,
648+
None,
636649
None,
637650
None,
638651
)
639-
# Based on: ("activation_kv_batch", "activation_length")
640-
qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
652+
#Based on: ("activation_kv_batch", "activation_length")
653+
# qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence")
654+
qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None)
641655
wrapped_flash_attention = shard_map(
642656
partial_flash_attention,
643657
mesh=sharding_mesh,

0 commit comments

Comments
 (0)