
Commit 7d6fda7

fix: prevent autocast crash in WAN model addcmul ops
Wraps the torch.addcmul calls in the WAN attention blocks in an autocast-disabled context to prevent the 'Unexpected floating ScalarType in at::autocast::prioritize' RuntimeError. The error occurs when upstream nodes (e.g. SAM3) leave CUDA autocast enabled: PyTorch 2.8's autocast promote dispatch for addcmul then hits an unhandled dtype in the prioritize function. The wrapper uses torch.is_autocast_enabled(device_type) (the non-deprecated API) and only disables autocast when it is actually active, so there is no overhead otherwise.
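For context, a minimal sketch (not part of the commit) of how the wrapper is meant to behave. The tensor shapes and the CPU autocast context below are illustrative stand-ins so the snippet runs anywhere; the reported crash involves CUDA autocast left enabled by an upstream node.

import torch

# Same shape as the wrapper this commit adds to comfy/ldm/wan/model.py:
# bypass autocast only when it is actually enabled for the tensor's device.
def _addcmul(x, y, z):
    device_type = x.device.type
    if torch.is_autocast_enabled(device_type):
        with torch.autocast(device_type=device_type, enabled=False):
            return torch.addcmul(x, y, z)
    return torch.addcmul(x, y, z)

x = torch.randn(2, 8, 16)
y = torch.randn(2, 8, 16)
z = torch.randn(2, 8, 16)

# Simulate an upstream node leaving an autocast context active.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = _addcmul(x, y, z)  # computed with autocast disabled

assert out.dtype == torch.float32  # result keeps the inputs' dtype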
1 parent 359559c commit 7d6fda7

1 file changed

Lines changed: 20 additions & 9 deletions


comfy/ldm/wan/model.py

@@ -174,6 +174,17 @@ def repeat_e(e, x):
         return torch.repeat_interleave(e, repeats + 1, dim=1)[:, :x.size(1)]


+def _addcmul(x, y, z):
+    """torch.addcmul wrapper that disables autocast to prevent
+    'Unexpected floating ScalarType in at::autocast::prioritize' when
+    upstream nodes (e.g. SAM3) leave CUDA autocast enabled."""
+    device_type = x.device.type
+    if torch.is_autocast_enabled(device_type):
+        with torch.autocast(device_type=device_type, enabled=False):
+            return torch.addcmul(x, y, z)
+    return torch.addcmul(x, y, z)
+
+
 class WanAttentionBlock(nn.Module):

     def __init__(self,
@@ -242,10 +253,10 @@ def forward(
         # self-attention
         x = x.contiguous() # otherwise implicit in LayerNorm
         y = self.self_attn(
-            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+            _addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs, transformer_options=transformer_options)

-        x = torch.addcmul(x, y, repeat_e(e[2], x))
+        x = _addcmul(x, y, repeat_e(e[2], x))
         del y

         # cross-attention & ffn
@@ -255,8 +266,8 @@ def forward(
             for p in patches["attn2_patch"]:
                 x = p({"x": x, "transformer_options": transformer_options})

-        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
-        x = torch.addcmul(x, y, repeat_e(e[5], x))
+        y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = _addcmul(x, y, repeat_e(e[5], x))
         return x

@@ -371,7 +382,7 @@ def forward(self, x, e):
         else:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)

-        x = (self.head(torch.addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
+        x = (self.head(_addcmul(repeat_e(e[0], x), self.norm(x), 1 + repeat_e(e[1], x))))
         return x

@@ -1453,17 +1464,17 @@ def forward(

         # self-attention
         y = self.self_attn(
-            torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+            _addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs, transformer_options=transformer_options)

-        x = torch.addcmul(x, y, repeat_e(e[2], x))
+        x = _addcmul(x, y, repeat_e(e[2], x))

         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
         if audio is not None:
             x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
-        y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
-        x = torch.addcmul(x, y, repeat_e(e[5], x))
+        y = self.ffn(_addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+        x = _addcmul(x, y, repeat_e(e[5], x))
         return x

 class DummyAdapterLayer(nn.Module):
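A quick sanity check, again not part of the commit, that the wrapper is a drop-in replacement for the torch.addcmul calls it substitutes: with autocast inactive it takes the plain torch.addcmul path, and with autocast active it returns the same values as torch.addcmul evaluated outside autocast. CPU autocast is used here only so the check runs without a GPU.

import torch

# _addcmul as added by this commit (repeated here so the snippet is self-contained).
def _addcmul(x, y, z):
    device_type = x.device.type
    if torch.is_autocast_enabled(device_type):
        with torch.autocast(device_type=device_type, enabled=False):
            return torch.addcmul(x, y, z)
    return torch.addcmul(x, y, z)

a, b, c = (torch.randn(4, 4) for _ in range(3))

# No autocast: identical to torch.addcmul.
assert torch.equal(_addcmul(a, b, c), torch.addcmul(a, b, c))

# Autocast active: still identical to the non-autocast result, i.e. a + b * c
# computed in the tensors' own dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    wrapped = _addcmul(a, b, c)
assert torch.equal(wrapped, torch.addcmul(a, b, c))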
