
Wan Animate Pipeline #367

Open
csgoogle wants to merge 3 commits into main from
sagarchapara/wananimate-pipeline

Conversation

@csgoogle
Collaborator

@csgoogle csgoogle commented Mar 28, 2026

Wan Animate Pipeline

This CL adds the Wan Animate pipeline.

  • Reused the existing Wan attention operator for face encoder cross attention.
  • Swept Flash Attention block-size configurations to identify the best inference setting.
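
A sweep like this typically boils down to a small timing harness around the pipeline. The sketch below is a hypothetical illustration only: build_pipeline and the candidate block sizes are placeholders, not the actual script or values used for this PR.

import itertools
import time

# Hypothetical sweep harness; `build_pipeline` and the candidate values are
# illustrative placeholders, not part of this PR.
CANDIDATE_BLOCK_Q = (256, 512, 1024)
CANDIDATE_BLOCK_KV = (256, 512, 1024)

def time_generation(block_q: int, block_kv: int) -> float:
    """Times one generation pass for the given Flash Attention block sizes,
    excluding the first (compile) pass."""
    pipeline = build_pipeline(flash_block_q=block_q, flash_block_kv=block_kv)
    pipeline.generate()  # warm-up pass so compilation is not measured
    start = time.perf_counter()
    pipeline.generate()
    return time.perf_counter() - start

results = {
    (bq, bkv): time_generation(bq, bkv)
    for bq, bkv in itertools.product(CANDIDATE_BLOCK_Q, CANDIDATE_BLOCK_KV)
}
best = min(results, key=results.get)
print(f"best (block_q, block_kv) = {best}: {results[best]:.2f}s")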

Links

Performance

  • compile_time: 292.73833787906915
  • generation_time: 157.68515427410603

Configuration

  • cp: 8 (v6e8)
  • cfg: 1.0
  • prev_segments: 5
  • resolution: 1280x720
  • fps: 24
  • generated_frames: 77
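
Assuming the timings above are in seconds, the run generates 77 frames at 24 fps, i.e. roughly 3.2 s of output video, which works out to about 2 s of wall-clock generation time per frame.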


@csgoogle csgoogle marked this pull request as ready for review April 6, 2026 16:33
@csgoogle csgoogle requested a review from entrpn as a code owner April 6, 2026 16:33
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch from 67233e9 to e281524 on April 13, 2026 08:49
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch from e281524 to 349d080 on April 13, 2026 09:10
Comment thread src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
Comment thread src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
Comment thread assets/wan_animate/src_face.mp4 Outdated
Collaborator


Could we move the assets to a public GCS path or use an existing hf dataset link?

Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py
Comment thread src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py
f"{_frame_summary('mask', mask_video)}"
)

animate_settings = _get_animate_inference_settings(config)
Collaborator


Could you also add lora support?

Collaborator


I wonder if there is a need for a separate generate script. Can we add this to the existing generate_wan.py file?

Collaborator Author


Animate needs a different LoRA for replacement; will add it in the next PR.

For the generate script, the inputs are different: we have a ref image, video and replacement image, video... so I think it's better to have a separate one.

@Perseus14
Collaborator

Please resolve conflicts and enable support for diagnostics and profiling as in this PR

csgoogle added 2 commits May 6, 2026 15:46
… into sagarchapara/wananimate-pipeline

# Conflicts:
#	.gitignore
#	src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py
Comment thread .gitignore
# RL pipelines may produce mp4 outputs
*.mp4
!assets/wan_animate/**/*.mp4
assets/wan_animate/
Collaborator


let's remove the folder assets/wan_animate altogether and remove this

max_logging.log(f"Saved video to {video_path}")

if getattr(config, "enable_profiler", False):
if max_utils.profiler_enabled(config):
Collaborator


Does this support both ml-diagnostics and enable-profiler options?
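
For context, a helper of that shape would usually treat either option as enabling profiling. The sketch below is hypothetical only and assumes the config exposes an enable_profiler flag and an ml-diagnostics toggle; the actual max_utils.profiler_enabled in the repo may be implemented differently.

# Hypothetical sketch; the real max_utils.profiler_enabled may differ.
# Assumes the config carries `enable_profiler` and an ml-diagnostics flag
# (here called `enable_ml_diagnostics`).
def profiler_enabled(config) -> bool:
    """Returns True if any profiling/diagnostics option is turned on."""
    return bool(
        getattr(config, "enable_profiler", False)
        or getattr(config, "enable_ml_diagnostics", False)
    )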

reference_image_path: "assets/wan_animate/src_ref.png"
pose_video_path: "assets/wan_animate/src_pose.mp4"
face_video_path: "assets/wan_animate/src_face.mp4"
reference_image_path: ""
Collaborator


Can we point to some default image path in huggingface?

@github-actions

github-actions Bot commented May 7, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


📋 Review Summary

The Pull Request introduces the Wan Animate pipeline, which includes the transformer model architecture, inference entry point, and necessary utilities. The implementation is comprehensive and follows the established patterns in the repository, including support for segment-based inference and parity with Diffusers.

🔍 General Feedback

  • Performance Optimization: The current implementation of the transformer re-encodes the face video frames during every denoising step. Since the face video is static throughout the inference process, this encoding can be pre-computed once per segment to significantly reduce redundant computation and speed up generation (a sketch of that hoisting follows this list).
  • Compilation Efficiency: The generation script performs two full inference passes. For high-resolution video generation, this double work is expensive. Consider reducing the number of steps in the first (compile) pass.
  • Robustness: Consider adding checks for optional inputs in the transformer to prevent potential runtime errors when face_pixel_values is not provided.
  • Code Quality: The reuse of the Wan attention operator and the integration with the existing configuration system is well-done. The use of nnx.scan for transformer blocks ensures memory efficiency during inference.
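
As a rough illustration of the pre-computation suggested in the first item above (hypothetical names throughout; the real pipeline and transformer signatures in this PR may differ), the per-segment flow could look like:

import jax

# Hypothetical sketch of hoisting the face/motion encoding out of the
# denoising loop; encode_face_motion and scheduler_step are illustrative
# names, not this PR's actual API.
def run_segment(transformer, scheduler_step, latents, face_pixel_values, timesteps):
    # Encode the (static) face video once per segment rather than once per denoising step.
    motion_vec = transformer.encode_face_motion(face_pixel_values)

    def denoise_step(carry, t):
        noise_pred = transformer(carry, timestep=t, motion_vec=motion_vec)
        return scheduler_step(noise_pred, t, carry), None

    latents, _ = jax.lax.scan(denoise_step, latents, timesteps)
    return latents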

num_inference_steps=config.num_inference_steps,
mode=mode,
)


🟡 The script performs two full inference passes. While the first pass is intended for compilation timing, it could be optimized by using a smaller number of steps (e.g., 1 or 2) or a dummy input to reduce total execution time, especially for high-resolution generations where inference is expensive.

Suggested change
# First pass (compile with minimal steps to save time)
compile_config = deepcopy(config)
compile_config.num_inference_steps = 1
_ = pipeline(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    background_video=background_video,
    mask_video=mask_video,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    segment_frame_length=animate_settings["segment_frame_length"],
    prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"],
    motion_encode_batch_size=animate_settings["motion_encode_batch_size"],
    guidance_scale=animate_settings["guidance_scale"],
    num_inference_steps=compile_config.num_inference_steps,
    mode=mode,
)


query = self.to_q(hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)

🟠 Performance: The face video and motion encoding are performed during every denoising step. Since these depend only on the static face video and are independent of the noisy latents and timestep, they should be pre-computed outside the denoising loop to significantly improve performance.

Suggested change
value = self.to_v(encoder_hidden_states)

# 4. Batched Face & Motion Encoding (Pre-compute this outside the denoising loop if possible)
if face_pixel_values is not None:
    _, face_channels, num_face_frames, face_height, face_width = face_pixel_values.shape

    # Rearrange from (B, C, T, H, W) to (B*T, C, H, W)
    face_pixel_values_reshaped = jnp.transpose(face_pixel_values, (0, 2, 1, 3, 4))
    face_pixel_values_reshaped = jnp.reshape(face_pixel_values_reshaped, (-1, face_channels, face_height, face_width))
    total_face_frames = face_pixel_values_reshaped.shape[0]

    motion_encode_batch_size = motion_encode_batch_size or self.motion_encoder_batch_size

    # Pad sequence if it doesn't divide evenly by encode_bs
    pad_len = (motion_encode_batch_size - (total_face_frames % motion_encode_batch_size)) % motion_encode_batch_size
    if pad_len > 0:
        pad_tensor = jnp.zeros(
            (pad_len, face_channels, face_height, face_width),
            dtype=face_pixel_values_reshaped.dtype,
        )
        face_pixel_values_reshaped = jnp.concatenate([face_pixel_values_reshaped, pad_tensor], axis=0)

    # Reshape into chunks for scan
    num_chunks = face_pixel_values_reshaped.shape[0] // motion_encode_batch_size
    face_chunks = jnp.reshape(
        face_pixel_values_reshaped,
        (
            num_chunks,
            motion_encode_batch_size,
            face_channels,
            face_height,
            face_width,
        ),
    )

    # Use jax.lax.scan to iterate over chunks to save memory
    def encode_chunk_fn(carry, chunk):
        encoded_chunk = self.motion_encoder(chunk)
        return carry, encoded_chunk

    _, motion_vec_chunks = jax.lax.scan(encode_chunk_fn, None, face_chunks)
    motion_vec = jnp.reshape(motion_vec_chunks, (-1, motion_vec_chunks.shape[-1]))

    # Remove padding if added
    if pad_len > 0:
        motion_vec = motion_vec[:-pad_len]

    motion_vec = jnp.reshape(motion_vec, (batch_size, num_face_frames, -1))

    # Apply face encoder
    motion_vec = self.face_encoder(motion_vec)
    pad_face = jnp.zeros_like(motion_vec[:, :1])
    motion_vec = jnp.concatenate([pad_face, motion_vec], axis=1)
else:
    motion_vec = None
