Skip to content

Commit 237d318

Browse files
yiyixuxu and sayakpaul
authored
Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 54f008e commit 237d318

3 files changed

Lines changed: 4 additions & 10 deletions

File tree

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,6 @@ class HunyuanVideo15Downsample(nn.Module):
215215
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
216216
super().__init__()
217217
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
218-
assert out_channels % factor == 0
219-
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
220218
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
221219

222220
self.add_temporal_downsample = add_temporal_downsample
@@ -531,7 +529,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
531529

532530
hidden_states = self.mid_block(hidden_states)
533531

534-
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
535532
batch_size, _, frame, height, width = hidden_states.shape
536533
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
537534

@@ -546,7 +543,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
546543

547544
class HunyuanVideo15Decoder3D(nn.Module):
548545
r"""
549-
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
546+
Causal decoder for 3D video-like data used for HunyuanImage-1.5 Refiner.
550547
"""
551548

552549
def __init__(

src/diffusers/models/transformers/transformer_hunyuan_video15.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,7 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
184184
The dimension of the output embedding.
185185
"""
186186

187-
def __init__(
188-
self,
189-
embedding_dim: int,
190-
):
187+
def __init__(self, embedding_dim: int):
191188
super().__init__()
192189

193190
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -362,7 +359,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
362359
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
363360

364361
axes_grids = []
365-
for i in range(3):
362+
for i in range(len(rope_sizes)):
366363
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
367364
# original implementation creates it on CPU and then moves it to device. This results in numerical
368365
# differences in layerwise debugging outputs, but visually it is the same.

src/diffusers/pipelines/hunyuan_video1_5/image_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
3434
return crop_size_list
3535

3636

37-
# copied fromhttps://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
37+
# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
3838
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
3939
"""
4040
Get the closest ratio in the buckets.

0 commit comments

Comments (0)