Skip to content

Commit 50987b1

Browse files
authored
[CI] Fix BnB tests (#13481)
* update
* update
* update
1 parent b8aebf4 commit 50987b1

8 files changed

Lines changed: 64 additions & 39 deletions

File tree

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,10 +445,14 @@ def __call__(
445445
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
446446
B, T, N, C = encoder_hidden_states.shape
447447

448+
# Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only
449+
# accepts 2D/3D inputs and would otherwise fail on this 4D activation.
450+
encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C]
451+
448452
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
449453

450454
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
451-
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
455+
key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv]
452456
value = value.view(B, T, N, attn.heads, -1)
453457

454458
query = attn.norm_q(query)

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def forward(
331331
)
332332
if i in self.config.vace_layers:
333333
control_hint, scale = control_hidden_states_list.pop()
334-
hidden_states = hidden_states + control_hint * scale
334+
hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
335335
else:
336336
# Prepare VACE hints
337337
control_hidden_states_list = []
@@ -346,7 +346,7 @@ def forward(
346346
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
347347
if i in self.config.vace_layers:
348348
control_hint, scale = control_hidden_states_list.pop()
349-
hidden_states = hidden_states + control_hint * scale
349+
hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
350350

351351
# 6. Output norm, projection & unpatchify
352352
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

tests/models/testing_utils/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def pretrained_model_kwargs(self) -> Dict[str, Any]:
205205
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
206206
return {}
207207

208+
@property
209+
def torch_dtype(self) -> torch.dtype:
210+
"""Compute dtype used to build dummy inputs and cast inputs where needed."""
211+
return torch.float32
212+
208213
@property
209214
def output_shape(self) -> Optional[tuple]:
210215
"""Expected output shape for output validation tests."""

tests/models/testing_utils/quantization.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -359,15 +359,7 @@ def _test_dequantize(self, config_kwargs):
359359
if isinstance(module, torch.nn.Linear):
360360
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
361361

362-
# Get model dtype from first parameter
363-
model_dtype = next(model.parameters()).dtype
364-
365362
inputs = self.get_dummy_inputs()
366-
# Cast inputs to model dtype
367-
inputs = {
368-
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
369-
for k, v in inputs.items()
370-
}
371363
output = model(**inputs, return_dict=False)[0]
372364
assert output is not None, "Model output is None after dequantization"
373365
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -575,33 +567,28 @@ def test_bnb_original_dtype(self):
575567

576568
@torch.no_grad()
577569
def test_bnb_keep_modules_in_fp32(self):
578-
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
579-
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
570+
fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
571+
if not fp32_modules:
572+
pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")
580573

581574
config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]
582575

583-
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
584-
self.model_class._keep_in_fp32_modules = ["proj_out"]
585-
586-
try:
587-
model = self._create_quantized_model(config_kwargs)
576+
model = self._create_quantized_model(config_kwargs)
577+
model.to(torch_device)
588578

589-
for name, module in model.named_modules():
590-
if isinstance(module, torch.nn.Linear):
591-
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
592-
assert module.weight.dtype == torch.float32, (
593-
f"Module {name} should be FP32 but is {module.weight.dtype}"
594-
)
595-
else:
596-
assert module.weight.dtype == torch.uint8, (
597-
f"Module {name} should be uint8 but is {module.weight.dtype}"
598-
)
579+
for name, module in model.named_modules():
580+
if isinstance(module, torch.nn.Linear):
581+
if any(fp32_name in name for fp32_name in fp32_modules):
582+
assert module.weight.dtype == torch.float32, (
583+
f"Module {name} should be FP32 but is {module.weight.dtype}"
584+
)
585+
else:
586+
assert module.weight.dtype == torch.uint8, (
587+
f"Module {name} should be uint8 but is {module.weight.dtype}"
588+
)
599589

600-
inputs = self.get_dummy_inputs()
601-
_ = model(**inputs)
602-
finally:
603-
if original_fp32_modules is not None:
604-
self.model_class._keep_in_fp32_modules = original_fp32_modules
590+
inputs = self.get_dummy_inputs()
591+
_ = model(**inputs)
605592

606593
def test_bnb_modules_to_not_convert(self):
607594
"""Test that modules_to_not_convert parameter works correctly."""

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,21 +159,36 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
159159

160160
return {
161161
"hidden_states": randn_tensor(
162-
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
162+
(batch_size, height * width, num_latent_channels),
163+
generator=self.generator,
164+
device=torch_device,
165+
dtype=self.torch_dtype,
163166
),
164167
"encoder_hidden_states": randn_tensor(
165-
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
168+
(batch_size, sequence_length, embedding_dim),
169+
generator=self.generator,
170+
device=torch_device,
171+
dtype=self.torch_dtype,
166172
),
167173
"pooled_projections": randn_tensor(
168-
(batch_size, embedding_dim), generator=self.generator, device=torch_device
174+
(batch_size, embedding_dim),
175+
generator=self.generator,
176+
device=torch_device,
177+
dtype=self.torch_dtype,
169178
),
170179
"img_ids": randn_tensor(
171-
(height * width, num_image_channels), generator=self.generator, device=torch_device
180+
(height * width, num_image_channels),
181+
generator=self.generator,
182+
device=torch_device,
183+
dtype=self.torch_dtype,
172184
),
173185
"txt_ids": randn_tensor(
174-
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
186+
(sequence_length, num_image_channels),
187+
generator=self.generator,
188+
device=torch_device,
189+
dtype=self.torch_dtype,
175190
),
176-
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
191+
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size),
177192
}
178193

179194

@@ -320,6 +335,10 @@ def pretrained_model_name_or_path(self):
320335
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
321336
"""BitsAndBytes quantization tests for Flux Transformer."""
322337

338+
@property
339+
def torch_dtype(self):
340+
return torch.float16
341+
323342

324343
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
325344
"""Quanto quantization tests for Flux Transformer."""

tests/models/transformers/test_models_transformer_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,13 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
9191
(batch_size, num_channels, num_frames, height, width),
9292
generator=self.generator,
9393
device=torch_device,
94+
dtype=self.torch_dtype,
9495
),
9596
"encoder_hidden_states": randn_tensor(
9697
(batch_size, sequence_length, text_encoder_embedding_dim),
9798
generator=self.generator,
9899
device=torch_device,
100+
dtype=self.torch_dtype,
99101
),
100102
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
101103
}

tests/models/transformers/test_models_transformer_wan_animate.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,27 +113,32 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
113113
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
114114
generator=self.generator,
115115
device=torch_device,
116+
dtype=self.torch_dtype,
116117
),
117118
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
118119
"encoder_hidden_states": randn_tensor(
119120
(batch_size, sequence_length, text_encoder_embedding_dim),
120121
generator=self.generator,
121122
device=torch_device,
123+
dtype=self.torch_dtype,
122124
),
123125
"encoder_hidden_states_image": randn_tensor(
124126
(batch_size, clip_seq_len, clip_dim),
125127
generator=self.generator,
126128
device=torch_device,
129+
dtype=self.torch_dtype,
127130
),
128131
"pose_hidden_states": randn_tensor(
129132
(batch_size, num_channels, num_frames, height, width),
130133
generator=self.generator,
131134
device=torch_device,
135+
dtype=self.torch_dtype,
132136
),
133137
"face_pixel_values": randn_tensor(
134138
(batch_size, 3, inference_segment_length, face_height, face_width),
135139
generator=self.generator,
136140
device=torch_device,
141+
dtype=self.torch_dtype,
137142
),
138143
}
139144

tests/models/transformers/test_models_transformer_wan_vace.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,19 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
9696
(batch_size, num_channels, num_frames, height, width),
9797
generator=self.generator,
9898
device=torch_device,
99+
dtype=self.torch_dtype,
99100
),
100101
"encoder_hidden_states": randn_tensor(
101102
(batch_size, sequence_length, text_encoder_embedding_dim),
102103
generator=self.generator,
103104
device=torch_device,
105+
dtype=self.torch_dtype,
104106
),
105107
"control_hidden_states": randn_tensor(
106108
(batch_size, vace_in_channels, num_frames, height, width),
107109
generator=self.generator,
108110
device=torch_device,
111+
dtype=self.torch_dtype,
109112
),
110113
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
111114
}

0 commit comments

Comments (0)