Skip to content

Commit b6f346b

Browse files
tp5uiuc and claude committed
fix(rtx): add WAR to fall back grouped 3D deconvolutions to PyTorch
Grouped 3D transposed convolutions (ConvTranspose3d with groups > 1) crash on TensorRT-RTX. This adds a convolution_capability_validator that detects these ops and rejects them from TRT conversion, causing the partitioner to keep them in PyTorch while other ops remain on TRT. Also renames depthwise_bf16_validator to convolution_capability_validator to reflect its broader scope, and removes the blanket skip on all 3D deconv tests — non-grouped cases now run through TRT on RTX. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent df261d5 commit b6f346b

3 files changed

Lines changed: 84 additions & 17 deletions

File tree

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 30 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2754,39 +2754,55 @@ def aten_ops_le(
27542754
)
27552755

27562756

2757-
def depthwise_bf16_validator(
2757+
def convolution_capability_validator(
27582758
node: Node, settings: Optional[CompilationSettings] = None
27592759
) -> bool:
2760-
"""Reject depthwise conv/deconv with BF16 on TensorRT-RTX.
2760+
"""Reject unsupported convolution variants on TensorRT-RTX.
27612761
2762-
TensorRT-RTX does not support depthwise convolutions in BF16. Returning
2763-
False causes the partitioner to fall back to PyTorch for these specific
2764-
nodes, while all other convolutions remain on TRT.
2762+
Falls back to PyTorch for:
2763+
1. Depthwise convolutions in BF16 (no kernel support on TRT-RTX).
2764+
2. Grouped 3D deconvolutions (crash on TRT-RTX).
27652765
"""
27662766
if not ENABLED_FEATURES.tensorrt_rtx:
27672767
return True
2768-
# Check if the input tensor is BF16 (via FX node metadata)
2769-
input_node = node.args[0]
2770-
input_meta = getattr(input_node, "meta", {}).get("tensor_meta")
2771-
if input_meta is None or input_meta.dtype != torch.bfloat16:
2768+
2769+
if (input_meta := getattr(node.args[0], "meta", {}).get("tensor_meta")) is None:
27722770
return True
2771+
27732772
groups = args_bounds_check(node.args, 8)
2774-
if groups is not None and groups > 1:
2775-
weight_node = node.args[1]
2776-
weight_meta = getattr(weight_node, "meta", {}).get("tensor_meta")
2777-
if weight_meta is not None and groups == weight_meta.shape[0]:
2773+
is_grouped = groups is not None and groups > 1
2774+
is_transposed = bool(args_bounds_check(node.args, 6))
2775+
is_3d = input_meta.shape is not None and len(input_meta.shape) == 5
2776+
is_bf16 = input_meta.dtype == torch.bfloat16
2777+
2778+
# WAR: Grouped 3D deconvolutions crash on TRT-RTX (any dtype).
2779+
if is_transposed and is_grouped and is_3d:
2780+
_LOGGER.debug(
2781+
"Grouped 3D deconvolution '%s' (groups=%d) is not supported on "
2782+
"TensorRT-RTX. Falling back to PyTorch for this layer.",
2783+
node.name,
2784+
groups,
2785+
)
2786+
return False
2787+
2788+
# WAR: Depthwise convolutions in BF16 are not supported on TRT-RTX.
2789+
if is_bf16 and is_grouped:
2790+
if (
2791+
weight_meta := getattr(node.args[1], "meta", {}).get("tensor_meta")
2792+
) is not None and groups == weight_meta.shape[0]:
27782793
_LOGGER.debug(
27792794
"Depthwise convolution '%s' with BF16 is not supported on "
27802795
"TensorRT-RTX. Falling back to PyTorch for this layer.",
27812796
node.name,
27822797
)
27832798
return False
2799+
27842800
return True
27852801

27862802

27872803
@dynamo_tensorrt_converter(
27882804
torch.ops.aten.convolution.default,
2789-
capability_validator=depthwise_bf16_validator,
2805+
capability_validator=convolution_capability_validator,
27902806
supports_dynamic_shapes=True,
27912807
)
27922808
@enforce_tensor_types(

tests/py/dynamo/conversion/test_deconvolution_aten.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -227,9 +227,6 @@ def forward(self, x):
227227
),
228228
]
229229
)
230-
@unittest.skipIf(
231-
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx, "TensorRT-RTX has bug on deconv3d"
232-
)
233230
def test_deconv3d(
234231
self,
235232
_,
@@ -241,6 +238,9 @@ def test_deconv3d(
241238
bias=True,
242239
output_padding=0,
243240
):
241+
if groups > 1 and torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
242+
self.skipTest("Grouped 3D deconvolutions fall back to PyTorch on TRT-RTX")
243+
244244
class TestModule(torch.nn.Module):
245245
def __init__(self):
246246
super().__init__()

tests/py/dynamo/models/test_models.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -609,3 +609,54 @@ def forward(self, x):
609609

610610
# Clean up model env
611611
torch._dynamo.reset()
612+
613+
614+
@pytest.mark.unit
615+
@unittest.skipIf(
616+
not torchtrt.ENABLED_FEATURES.tensorrt_rtx,
617+
"Grouped 3D deconv fallback WAR is TensorRT-RTX specific",
618+
)
619+
def test_grouped_deconv3d_fallback(ir):
620+
"""Grouped 3D deconvolutions fall back to PyTorch on TRT-RTX.
621+
622+
The convolution_capability_validator rejects grouped ConvTranspose3d ops
623+
so that the partitioner keeps them in PyTorch while other ops run on TRT.
624+
"""
625+
626+
class MyModule(torch.nn.Module):
627+
def __init__(self):
628+
super().__init__()
629+
self.conv = torch.nn.Conv3d(3, 16, 3, padding=1)
630+
self.relu = torch.nn.ReLU()
631+
self.deconv = torch.nn.ConvTranspose3d(16, 16, 3, padding=1, groups=16)
632+
633+
def forward(self, x):
634+
out = self.conv(x)
635+
out = self.relu(out)
636+
out = self.deconv(out)
637+
return out
638+
639+
model = MyModule().eval().cuda()
640+
input = torch.randn((1, 3, 16, 16, 16), device="cuda")
641+
642+
compile_spec = {
643+
"inputs": [torchtrt.Input(input.shape, dtype=torch.float32)],
644+
"device": torchtrt.Device("cuda:0"),
645+
"enabled_precisions": {torch.float32},
646+
"ir": ir,
647+
"pass_through_build_failures": True,
648+
"min_block_size": 1,
649+
"cache_built_engines": False,
650+
"reuse_cached_engines": False,
651+
}
652+
653+
trt_mod = torchtrt.compile(model, **compile_spec)
654+
cos_sim = cosine_similarity(model(input), trt_mod(input))
655+
656+
assertions.assertTrue(
657+
cos_sim > COSINE_THRESHOLD,
658+
msg=f"Grouped 3D deconv fallback model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
659+
)
660+
661+
# Clean up model env
662+
torch._dynamo.reset()

0 commit comments

Comments (0)