
Commit 661c4ac

Merge branch 'main' into sanbao/bugs
2 parents 55b309c + 955bd86

15 files changed: 722 additions & 160 deletions

.github/workflows/UploadDockerImages.yml

Lines changed: 3 additions & 0 deletions

@@ -32,6 +32,9 @@ jobs:
       - name: build maxdiffusion jax ai image
         run: |
           bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
+      - name: build maxdiffusion w/ nightly jax ai image
+        run: |
+          bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_nightly MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest
       - name: build maxdiffusion jax nightly image
         run: |
           bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly

README.md

Lines changed: 6 additions & 4 deletions

@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2025/8/14`**: LTX-Video img2vid generation is now supported.
 - **`2025/7/29`**: LTX-Video text2vid generation is now supported.
 - **`2025/04/17`**: Flux Finetuning.
 - **`2025/02/12`**: Flux LoRA for inference.
@@ -42,7 +43,7 @@ MaxDiffusion supports
 * Load Multiple LoRA (SDXL inference).
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x,2.x.
-* LTX-Video text2vid (inference).
+* LTX-Video text2vid, img2vid (inference).
 
 
 # Table of Contents
@@ -177,13 +178,14 @@ To generate images, run the following command:
 ## LTX-Video
 - In the folder src/maxdiffusion/models/ltx_video/utils, run:
 ```bash
-python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../xora_v1.2-13B-balanced-128.json
+python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../ltxv-13B.json
 ```
 - In the repo folder, run:
 ```bash
-python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json"
+python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
 ```
-- Other generation parameters can be set in ltx_video.yml file.
+- Img2video generation:
+  Add the conditioning image path as conditioning_media_paths in the form ["IMAGE_PATH"], along with the other generation parameters, in the ltx_video.yml file, then follow the same instructions as above.
 ## Flux
 
 First make sure you have permissions to access the Flux repos in Huggingface.

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 6 additions & 0 deletions

@@ -40,6 +40,9 @@ weights_dtype: 'bfloat16'
 # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
 activations_dtype: 'bfloat16'
 
+# Replicates the VAE across devices instead of using the model's sharding annotations.
+replicate_vae: False
+
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
@@ -291,3 +294,6 @@ use_qwix_quantization: False # Whether to use qwix for quantization. If set to T
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
 quantization_calibration_method: "absmax"
 
+# Run evaluation every eval_every steps. -1 means don't evaluate.
+eval_every: -1
+eval_data_dir: ""
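The `eval_every` cadence added above (evaluate on a fixed step interval, with `-1` disabling evaluation) reduces to a simple step check. A minimal sketch of that semantics; the `should_eval` helper is hypothetical, not part of the config or the training code:

```python
def should_eval(step: int, eval_every: int) -> bool:
  # -1 (or any non-positive value) disables evaluation, matching the
  # "-1 means don't eval" comment in base_wan_14b.yml above.
  if eval_every <= 0:
    return False
  return step > 0 and step % eval_every == 0


print([s for s in range(1, 10) if should_eval(s, 3)])   # [3, 6, 9]
print([s for s in range(1, 10) if should_eval(s, -1)])  # []
```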

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 3 additions & 1 deletion

@@ -22,7 +22,7 @@ sampler: "from_checkpoint"
 
 # Generation parameters
 pipeline_type: multi-scale
-prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. "
+prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
 #negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 height: 512
 width: 512
@@ -35,6 +35,8 @@ stg_mode: "attention_values"
 decode_timestep: 0.05
 decode_noise_scale: 0.025
 seed: 10
+conditioning_media_paths: None #["IMAGE_PATH"]
+conditioning_start_frames: [0]
 
 
 first_pass:
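The `conditioning_media_paths: None #["IMAGE_PATH"]` default above only activates img2vid when it is an actual list of paths; `generate_ltx_video.py` in this commit checks this with an `isinstance(..., List)` test. A standalone sketch of that decision (the `resolve_conditioning_paths` name is hypothetical):

```python
from typing import List, Optional


def resolve_conditioning_paths(value) -> Optional[List[str]]:
  # Mirrors the commit's check in run():
  #   config.conditioning_media_paths if isinstance(..., List) else None
  # Anything that is not a real list (None, or the string "None" from YAML)
  # leaves img2vid conditioning disabled.
  return value if isinstance(value, list) else None


print(resolve_conditioning_paths(None))            # None
print(resolve_conditioning_paths("None"))          # None
print(resolve_conditioning_paths(["IMAGE_PATH"]))  # ['IMAGE_PATH']
```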

src/maxdiffusion/generate_ltx_video.py

Lines changed: 112 additions & 2 deletions

@@ -16,15 +16,19 @@
 
 import numpy as np
 from absl import app
-from typing import Sequence
+from typing import Sequence, List, Optional, Union
 from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
-from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
+from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
+import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
 from maxdiffusion import pyconfig, max_logging
+import torchvision.transforms.functional as TVF
 import imageio
 from datetime import datetime
 import os
 import time
 from pathlib import Path
+from PIL import Image
+import torch
 
 
 def calculate_padding(
@@ -44,6 +48,79 @@ def calculate_padding(
   return padding
 
 
+def load_image_to_tensor_with_resize_and_crop(
+    image_input: Union[str, Image.Image],
+    target_height: int = 512,
+    target_width: int = 768,
+    just_crop: bool = False,
+) -> torch.Tensor:
+  """Load and process an image into a tensor.
+
+  Args:
+      image_input: Either a file path (str) or a PIL Image object
+      target_height: Desired height of output tensor
+      target_width: Desired width of output tensor
+      just_crop: If True, only crop the image to the target size without resizing
+  """
+  if isinstance(image_input, str):
+    image = Image.open(image_input).convert("RGB")
+  elif isinstance(image_input, Image.Image):
+    image = image_input
+  else:
+    raise ValueError("image_input must be either a file path or a PIL Image object")
+
+  input_width, input_height = image.size
+  aspect_ratio_target = target_width / target_height
+  aspect_ratio_frame = input_width / input_height
+  if aspect_ratio_frame > aspect_ratio_target:
+    new_width = int(input_height * aspect_ratio_target)
+    new_height = input_height
+    x_start = (input_width - new_width) // 2
+    y_start = 0
+  else:
+    new_width = input_width
+    new_height = int(input_width / aspect_ratio_target)
+    x_start = 0
+    y_start = (input_height - new_height) // 2
+
+  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+  if not just_crop:
+    image = image.resize((target_width, target_height))
+
+  frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), [0,1]
+  frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
+  frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
+  frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
+  frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
+  frame_tensor = (frame_tensor / 127.5) - 1.0
+  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+  return frame_tensor.unsqueeze(0).unsqueeze(2)
+
+
+def prepare_conditioning(
+    conditioning_media_paths: List[str],
+    conditioning_strengths: List[float],
+    conditioning_start_frames: List[int],
+    height: int,
+    width: int,
+    padding: tuple[int, int, int, int],
+) -> Optional[List[ConditioningItem]]:
+  """Prepare conditioning items based on input media paths and their parameters."""
+  conditioning_items = []
+  for path, strength, start_frame in zip(conditioning_media_paths, conditioning_strengths, conditioning_start_frames):
+    num_input_frames = 1
+    media_tensor = load_media_file(
+        media_path=path,
+        height=height,
+        width=width,
+        max_frames=num_input_frames,
+        padding=padding,
+        just_crop=True,
+    )
+    conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
+  return conditioning_items
+
+
 def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
   # Remove non-letters and convert to lowercase
   clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
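The aspect-ratio branch in `load_image_to_tensor_with_resize_and_crop` above is plain center-crop arithmetic, so it can be checked in isolation. A sketch of the same math with no image libraries; `center_crop_box` is a hypothetical helper, not part of the commit:

```python
def center_crop_box(input_width, input_height, target_width, target_height):
  # Same arithmetic as load_image_to_tensor_with_resize_and_crop: trim the
  # dimension whose aspect ratio exceeds the target, centering the crop box.
  aspect_target = target_width / target_height
  aspect_frame = input_width / input_height
  if aspect_frame > aspect_target:
    # Frame is too wide: keep full height, trim width symmetrically.
    new_width = int(input_height * aspect_target)
    new_height = input_height
    x_start = (input_width - new_width) // 2
    y_start = 0
  else:
    # Frame is too tall: keep full width, trim height symmetrically.
    new_width = input_width
    new_height = int(input_width / aspect_target)
    x_start = 0
    y_start = (input_height - new_height) // 2
  return (x_start, y_start, x_start + new_width, y_start + new_height)


# A 1920x1080 source cropped for the default 768x512 (3:2) target trims width.
print(center_crop_box(1920, 1080, 768, 512))  # (150, 0, 1770, 1080)
```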
@@ -68,6 +145,19 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
   return "-".join(result)
 
 
+def load_media_file(
+    media_path: str,
+    height: int,
+    width: int,
+    max_frames: int,
+    padding: tuple[int, int, int, int],
+    just_crop: bool = False,
+) -> torch.Tensor:
+  media_tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width, just_crop=just_crop)
+  media_tensor = torch.nn.functional.pad(media_tensor, padding)
+  return media_tensor
+
+
 def get_unique_filename(
     base: str,
     ext: str,
@@ -97,6 +187,25 @@ def run(config):
   pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
   if config.pipeline_type == "multi-scale":
     pipeline = LTXMultiScalePipeline(pipeline)
+  conditioning_media_paths = config.conditioning_media_paths if isinstance(config.conditioning_media_paths, List) else None
+  conditioning_start_frames = config.conditioning_start_frames
+  conditioning_strengths = None
+  if conditioning_media_paths:
+    if not conditioning_strengths:
+      conditioning_strengths = [1.0] * len(conditioning_media_paths)
+  conditioning_items = (
+      prepare_conditioning(
+          conditioning_media_paths=conditioning_media_paths,
+          conditioning_strengths=conditioning_strengths,
+          conditioning_start_frames=conditioning_start_frames,
+          height=config.height,
+          width=config.width,
+          padding=padding,
+      )
+      if conditioning_media_paths
+      else None
+  )
+
   s0 = time.perf_counter()
   images = pipeline(
       height=height_padded,
@@ -106,6 +215,7 @@ def run(config):
       output_type="pt",
       config=config,
       enhance_prompt=enhance_prompt,
+      conditioning_items=conditioning_items,
       seed=config.seed,
   )
   max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")
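Before building `ConditioningItem`s, `run()` above defaults `conditioning_strengths` to 1.0 per conditioning path. That defaulting step in isolation (a sketch; `default_strengths` is a hypothetical helper, and tuples stand in for the real `ConditioningItem`):

```python
def default_strengths(paths, strengths=None):
  # Mirrors run(): if paths are given but no strengths, use 1.0 per path.
  if paths and not strengths:
    strengths = [1.0] * len(paths)
  return strengths


paths = ["IMAGE_PATH"]
strengths = default_strengths(paths)
start_frames = [0]
# prepare_conditioning zips these into ConditioningItem(media, start, strength).
items = list(zip(paths, start_frames, strengths))
print(items)  # [('IMAGE_PATH', 0, 1.0)]
```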

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 63 additions & 64 deletions

@@ -19,7 +19,7 @@
 import tensorflow.experimental.numpy as tnp
 from datasets import load_dataset, load_from_disk
 import jax
-from maxdiffusion import multihost_dataloading
+from maxdiffusion import multihost_dataloading, max_logging
 
 AUTOTUNE = tf.data.AUTOTUNE
 
@@ -78,92 +78,91 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
-
-def make_cached_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
-):
-  """
-  New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
-  latents, input_ids, prompt_embeds, and text_embeds.
-  """
-
-  def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
-
-  # This pipeline reads the sharded files and applies the parsing and preparation.
-  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
-
-  train_ds = (
-      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
-      .shuffle(global_batch_size * 10)
-      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-      .repeat(-1)
-      .prefetch(AUTOTUNE)
-  )
-
-  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
-  return train_iter
-
-
 # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
-def make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
+def _make_tfrecord_iterator(
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
 ):
-  """Iterator for TFRecord format. For Laion dataset,
-  check out preparation script
-  maxdiffusion/pedagogical_examples/to_tfrecords.py
-  """
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
+  # If is_training is True, load the training dataset; otherwise load the evaluation dataset.
 
   # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked.
   is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
 
-  if (
-      config.cache_latents_text_encoder_outputs
-      and is_dataset_dir_valid
-      and "load_tfrecord_cached" in config.get_keys()
-      and config.load_tfrecord_cached
-  ):
-    return make_cached_tfrecord_iterator(
-        config,
-        dataloading_host_index,
-        dataloading_host_count,
-        mesh,
-        global_batch_size,
-        feature_description,
-        prepare_sample_fn,
-    )
+  # Determine whether to use the "cached" dataset, which requires externally
+  # provided parsing functions, or the default one with its internal parsing logic.
+  make_cached_tfrecord_iterator = (
+      config.cache_latents_text_encoder_outputs
+      and is_dataset_dir_valid
+      and "load_tfrecord_cached" in config.get_keys()
+      and config.load_tfrecord_cached
+  )
 
   feature_description = {
       "moments": tf.io.FixedLenFeature([], tf.string),
       "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
   }
 
+  used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
+
   def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
+    return tf.io.parse_single_example(example, used_feature_description)
 
   def prepare_sample(features):
     moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
     clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
     return {"pixel_values": moments, "input_ids": clip_embeddings}
 
-  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
-  train_ds = (
-      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
-      .shuffle(global_batch_size * 10)
+  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
+  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+
+  # --- PADDING LOGIC FOR EVALUATION ---
+  if not is_training:
+    num_eval_samples = 0
+    for _ in ds:
+      num_eval_samples += 1
+
+    remainder = num_eval_samples % global_batch_size
+    if remainder != 0:
+      num_to_pad = global_batch_size - remainder
+      # Create a dataset of padding samples from the beginning
+      padding_ds = ds.take(num_to_pad)
+      # Add the padding samples to the end
+      ds = ds.concatenate(padding_ds)
+      max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
+
+  used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
+  ds = (
+      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
+      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
+      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
+  )
+  if is_training:
+    ds = (
+      ds.shuffle(global_batch_size * 10)
       .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
       .repeat(-1)
       .prefetch(AUTOTUNE)
-  )
+    )
+  # For Evaluation
+  else:
+    ds = (
+      ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
+      .prefetch(AUTOTUNE)
+    )
 
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
-  return train_iter
+  iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
+  return iter
+
+
+def make_tfrecord_iterator(
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, is_training
+):
+  """Iterator for TFRecord format. For Laion dataset,
+  check out preparation script
+  maxdiffusion/pedagogical_examples/to_tfrecords.py
+  """
+  # Currently only the TFRecord format supports evaluation. is_training selects the
+  # dataset path so existing call sites are unaffected.
+  # TODO: refactor to support evaluation on all dataset formats.
+  dataset_path = config.train_data_dir if is_training else config.eval_data_dir
+  return _make_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training)
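The evaluation-padding branch above tops the dataset up to a multiple of the global batch size by repeating samples taken from the beginning. The same logic on a plain Python list, independent of tf.data (the `pad_to_batch_multiple` helper is a sketch, not code from the commit):

```python
def pad_to_batch_multiple(samples, global_batch_size):
  # Mirrors the eval branch: if the sample count is not a multiple of the
  # global batch size, append copies taken from the start of the dataset,
  # analogous to ds.concatenate(ds.take(num_to_pad)).
  remainder = len(samples) % global_batch_size
  if remainder != 0:
    num_to_pad = global_batch_size - remainder
    samples = samples + samples[:num_to_pad]
  return samples


padded = pad_to_batch_multiple(list(range(10)), 4)
print(len(padded))   # 12
print(padded[-2:])   # [0, 1]
```

With `drop_remainder=False` this padding keeps every original eval sample while still letting the final batch fill up; the repeated samples are the cost of a uniform batch shape.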
