Commit 128ef01

Merge branch 'main' into video_training

2 parents: 28dbe57 + 80771b1

54 files changed

Lines changed: 10875 additions & 21 deletions


.github/workflows/UnitTests.yml

Lines changed: 6 additions & 2 deletions
@@ -35,7 +35,11 @@ jobs:
     name: "TPU test (${{ matrix.tpu-type }})"
     runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.12
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.12'
       - name: Install dependencies
         run: |
           pip install -e .
@@ -50,7 +54,7 @@ jobs:
           ruff check .
       - name: PyTest
         run: |
-          HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x
+          HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
 # add_pull_ready:
 #   if: github.ref != 'refs/heads/main'
 #   permissions:

docs/getting_started/first_run.md

Lines changed: 9 additions & 5 deletions
@@ -8,15 +8,19 @@ We recommend starting with a single host first and then moving to multihost.
 Local development is a convenient way to run MaxDiffusion on a single host. It doesn't scale to
 multiple hosts.
 
-1. [Create and SSH to a single-host TPU (v4-8).](https://cloud.google.com/tpu/docs/users-guide-tpu-vm#creating_a_cloud_tpu_vm_with_gcloud)
+1. [Create and SSH to a single-host TPU (v6-8).](https://cloud.google.com/tpu/docs/users-guide-tpu-vm#creating_a_cloud_tpu_vm_with_gcloud)
+   * You can find [here](https://cloud.google.com/tpu/docs/regions-zones) the list of zones that support v6 (Trillium) TPUs.
+   * We recommend the base VM image "v2-alpha-tpuv6e", which meets the version requirements: Ubuntu 22.04, Python 3.10, and TensorFlow >= 2.12.0.
+
 1. Clone MaxDiffusion in your TPU VM.
+   ```bash
+   git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
+   cd maxdiffusion
+   ```
 1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running:
    ```bash
-   If you are running on TPU:
    bash setup.sh MODE=stable DEVICE=tpu
-
-   If you are running on GPU:
-   bash setup.sh MODE=stable DEVICE=gpu
    ```
 
 ## Getting Starting: Multihost development

maxdiffusion_dependencies.Dockerfile

Lines changed: 3 additions & 3 deletions
@@ -1,12 +1,12 @@
-# Use Python 3.10-slim-bullseye as the base image
-FROM python:3.10-slim-bullseye
+# Use Python 3.12-slim-bullseye as the base image
+FROM python:3.12-slim-bullseye
 
 # Environment variable for no-cache-dir and pip root user warning
 ENV PIP_NO_CACHE_DIR=1
 ENV PIP_ROOT_USER_ACTION=ignore
 
 # Set environment variables for Google Cloud SDK and Python 3.10
-ENV PYTHON_VERSION=3.10
+ENV PYTHON_VERSION=3.12
 ENV CLOUD_SDK_VERSION=latest
 
 # Set DEBIAN_FRONTEND to noninteractive to avoid frontend errors

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -23,8 +23,10 @@ pytest==8.2.2
 tensorflow>=2.17.0
 tensorflow-datasets>=4.9.6
 ruff>=0.1.5,<=0.2
+git+https://github.com/Lightricks/LTX-Video
+git+https://github.com/zmelumian972/xla@torchax/jittable_module_callable#subdirectory=torchax
 opencv-python-headless==4.10.0.84
-orbax-checkpoint==0.10.3
+orbax-checkpoint
 tokenizers==0.21.0
 huggingface_hub>=0.30.2
 transformers==4.48.1

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ tensorflow>=2.17.0
 tensorflow-datasets>=4.9.6
 ruff>=0.1.5,<=0.2
 opencv-python-headless==4.10.0.84
-orbax-checkpoint==0.10.3
+orbax-checkpoint
 tokenizers==0.21.0
 huggingface_hub>=0.30.2
 transformers==4.48.1

setup.sh

Lines changed: 1 addition & 1 deletion
@@ -112,4 +112,4 @@ else
 fi
 
 # Install maxdiffusion
-pip3 install -U . || echo "Failed to install maxdiffusion" >&2
+pip3 install -U . || echo "Failed to install maxdiffusion" >&2

src/maxdiffusion/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -374,6 +374,7 @@
   _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
   _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
   _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
+  _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
   _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
   _import_structure["schedulers"].extend(
       [
@@ -453,6 +454,7 @@
   from .models.modeling_flax_utils import FlaxModelMixin
   from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
   from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+  from .models.ltx_video.transformers.transformer3d import Transformer3DModel
   from .models.vae_flax import FlaxAutoencoderKL
   from .pipelines import FlaxDiffusionPipeline
   from .schedulers import (

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 7 additions & 4 deletions
@@ -28,7 +28,7 @@
 from flax.training import train_state
 import orbax
 import orbax.checkpoint as ocp
-from orbax.checkpoint.logging import abstract_logger
+from orbax.checkpoint.logging import AbstractLogger
 from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions
 
 STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
@@ -43,7 +43,7 @@ def create_orbax_checkpoint_manager(
     checkpoint_type: str,
     dataset_type: str = "tf",
     use_async: bool = True,
-    orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
+    orbax_logger: Optional[AbstractLogger] = None,
 ):
   """
   Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.
@@ -213,8 +213,11 @@ def load_state_if_possible(
     max_logging.log(f"restoring from this run's directory latest step {latest_step}")
     try:
       if not enable_single_replica_ckpt_restoring:
-        item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
-        return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
+        if checkpoint_item == "ltxvid_transformer":
+          return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
+        else:
+          item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
+          return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
 
       def map_to_pspec(data):
         pspec = data.sharding.spec
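
The restore path now branches on `checkpoint_item`: `"ltxvid_transformer"` checkpoints go through `ocp.args.StandardRestore`, everything else keeps the named `Composite`-of-`PyTreeRestore` layout. A small sketch of that dispatch with stand-in dataclasses (actually running Orbax is out of scope here; the layout difference is an inference from the branch, not confirmed by the diff):

```python
from dataclasses import dataclass

# Stand-ins for orbax.checkpoint args; the real code uses ocp.args.*.
@dataclass
class StandardRestore:
    item: object

@dataclass
class PyTreeRestore:
    item: object

def build_restore_args(checkpoint_item: str, abstract_state: object):
    """Mirror the branch in load_state_if_possible: the LTX-Video transformer
    checkpoint is restored directly against the abstract state, while other
    items are wrapped under their name for a Composite restore."""
    if checkpoint_item == "ltxvid_transformer":
        return StandardRestore(abstract_state)
    return {checkpoint_item: PyTreeRestore(item=abstract_state)}
```

In the real code the dict form is splatted into `ocp.args.Composite(**item)` before being passed to `checkpoint_manager.restore`.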
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+#hardware
+hardware: 'tpu'
+skip_jax_distributed_system: False
+
+jax_cache_dir: ''
+weights_dtype: 'bfloat16'
+activations_dtype: 'bfloat16'
+
+
+run_name: ''
+output_dir: ''
+config_path: ''
+save_config_to_gcs: False
+
+#Checkpoints
+text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+frame_rate: 30
+max_sequence_length: 512
+sampler: "from_checkpoint"
+
+# Generation parameters
+pipeline_type: multi-scale
+prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. "
+#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+height: 512
+width: 512
+num_frames: 88
+flow_shift: 5.0
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
+prompt_enhancement_words_threshold: 120
+stg_mode: "attention_values"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+seed: 10
+
+
+first_pass:
+  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
+  stg_scale: [0, 0, 4, 4, 4, 2, 1]
+  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
+  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
+  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
+  num_inference_steps: 30
+  skip_final_inference_steps: 3
+  skip_initial_inference_steps: 0
+  cfg_star_rescale: True
+
+second_pass:
+  guidance_scale: [1]
+  stg_scale: [1]
+  rescaling_scale: [1]
+  guidance_timesteps: [1.0]
+  skip_block_list: [27]
+  num_inference_steps: 30
+  skip_initial_inference_steps: 17
+  skip_final_inference_steps: 0
+  cfg_star_rescale: True
+
+#parallelism
+mesh_axes: ['data', 'fsdp', 'tensor']
+logical_axis_rules: [
+  ['batch', 'data'],
+  ['activation_heads', 'fsdp'],
+  ['activation_batch', 'data'],
+  ['activation_kv', 'tensor'],
+  ['mlp', 'tensor'],
+  ['embed', 'fsdp'],
+  ['heads', 'tensor'],
+  ['norm', 'fsdp'],
+  ['conv_batch', ['data', 'fsdp']],
+  ['out_channels', 'tensor'],
+  ['conv_out', 'fsdp'],
+  ['conv_in', 'fsdp']
+]
+data_sharding: [['data', 'fsdp', 'tensor']]
+dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
+dcn_fsdp_parallelism: -1
+dcn_tensor_parallelism: 1
+ici_data_parallelism: 1
+ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
+ici_tensor_parallelism: 1
+
+allow_split_physical_axes: False
+learning_rate_schedule_steps: -1
+max_train_steps: 500
+pretrained_model_name_or_path: ''
+unet_checkpoint: ''
+dataset_name: 'diffusers/pokemon-gpt4-captions'
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+per_device_batch_size: 1
+compile_topology_num_slices: -1
+quantization_local_shard_count: -1
+jit_initializers: True
+enable_single_replica_ckpt_restoring: False
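
The `first_pass` section of the new config defines parallel per-stage schedules: entry `i` of `guidance_scale`, `stg_scale`, `rescaling_scale`, and `skip_block_list` appears to apply to the guidance stage bounded by `guidance_timesteps[i]` (an assumption from the structure; the pipeline code is what actually consumes them). A quick sanity-check sketch over those values:

```python
# Sanity checks for the first_pass schedule lists above. Assumption: the
# multi-scale pipeline indexes all lists together per guidance stage, so
# they must all have the same length.
first_pass = {
    "guidance_scale": [1, 1, 6, 8, 6, 1, 1],
    "stg_scale": [0, 0, 4, 4, 4, 2, 1],
    "rescaling_scale": [1, 1, 0.5, 0.5, 1, 1, 1],
    "guidance_timesteps": [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180],
    "skip_block_list": [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]],
}

lengths = {k: len(v) for k, v in first_pass.items()}
assert len(set(lengths.values())) == 1, f"mismatched schedule lengths: {lengths}"

# Guidance timesteps run from 1.0 toward 0 as denoising proceeds, so the
# stage boundaries should be strictly decreasing.
ts = first_pass["guidance_timesteps"]
assert all(a > b for a, b in zip(ts, ts[1:])), "timesteps not decreasing"
```

The same shape constraint would hold for `second_pass`, where every list has length 1.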
