
fix(rtx): add WAR to fall back grouped 3D deconvolutions to PyTorch#4188

Merged
zewenli98 merged 1 commit into pytorch:main from tp5uiuc:grouped-deconv3d-war
Apr 17, 2026

Conversation

@tp5uiuc
Contributor

@tp5uiuc tp5uiuc commented Apr 14, 2026

Description

Grouped 3D transposed convolutions (ConvTranspose3d with groups > 1) crash on TensorRT-RTX. This PR adds a workaround that detects these ops at partitioning time and falls them back to PyTorch, while all other ops remain on TRT.

Changes

aten_ops_converters.py

  • Renamed depthwise_bf16_validator to convolution_capability_validator to reflect its broader scope
  • Added check: if transposed=True AND groups > 1 AND input is 5D (3D spatial), reject from TRT on RTX
  • Existing BF16 depthwise convolution WAR preserved
  • Uses walrus operator and boolean variables (is_grouped, is_transposed, is_3d, is_bf16) for readability
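The validator logic described above can be sketched roughly as follows. This is an illustrative stand-in, not the actual torch_tensorrt code: the real validator receives an FX node rather than a flat argument list, and the parameter names here are invented for clarity.

```python
def convolution_capability_validator(
    transposed: bool,
    groups: int,
    input_rank: int,
    is_bf16_depthwise: bool,
    tensorrt_rtx: bool,
) -> bool:
    """Return True if TRT conversion is supported; False forces PyTorch fallback."""
    if not tensorrt_rtx:
        return True
    is_3d = input_rank == 5  # a 5D input tensor implies a 3D spatial convolution
    # WAR: grouped 3D deconvolutions crash on TensorRT-RTX, so reject them
    if (is_grouped := groups > 1) and transposed and is_3d:
        return False
    # Preserved WAR: BF16 depthwise convolutions are also rejected on RTX
    return not is_bf16_depthwise
```

Returning False from a capability validator causes the partitioner to leave the op in PyTorch while the rest of the graph still compiles to TRT.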

test_deconvolution_aten.py

  • Replaced blanket @unittest.skipIf(tensorrt_rtx) on all 3D deconv tests with a targeted self.skipTest() inside the test body for grouped cases only
  • Non-grouped 3D deconv cases (9 parametrized + 1 dynamic shape) now run through TRT on RTX

test_models.py

  • Added test_grouped_deconv3d_fallback: model-level test that verifies the full torch_tensorrt.compile → partitioner → PyTorch fallback path with accuracy checks for grouped 3D deconv
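The accuracy check in that model-level test compares the compiled (fallback) output against eager PyTorch via cosine similarity. A minimal stand-in for the metric, with an illustrative threshold that need not match the repo's actual value:

```python
import math

COSINE_THRESHOLD = 0.99  # illustrative threshold, not necessarily the repo's value

def cosine_similarity(a, b):
    """Cosine similarity over flattened outputs, used as the accuracy metric."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Compare an eager-PyTorch reference against a slightly perturbed compiled output
ref = [0.5, -1.2, 3.3, 0.0, 2.1]
out = [0.5001, -1.1999, 3.3002, 0.0001, 2.0998]
score = cosine_similarity(ref, out)
```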

Test results (L40S, TRT-RTX nightly)

Suite                              Result
deconv3d converter (non-grouped)   10 passed
deconv3d converter (grouped)       2 skipped (correctly rejected by validator)
grouped deconv3d fallback model    1 passed (accuracy verified via cosine similarity)
BF16 regression (14 tests)         14 passed

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@meta-cla meta-cla bot added the cla signed label Apr 14, 2026
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 14, 2026
@github-actions github-actions bot requested a review from narendasan April 14, 2026 21:05
@tp5uiuc
Contributor Author

tp5uiuc commented Apr 14, 2026

This PR is based on #4178 and should be merged after it.

@narendasan narendasan requested a review from zewenli98 April 14, 2026 23:44
@tp5uiuc tp5uiuc force-pushed the grouped-deconv3d-war branch from e0acfd4 to b6f346b Compare April 15, 2026 10:03
@tp5uiuc tp5uiuc marked this pull request as ready for review April 15, 2026 10:03
Review comment on the following diff hunk:

    bias=True,
    output_padding=0,
    ):
        if groups > 1 and torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx:
Contributor Author


This test skips the capability validation and directly uses the TRTInterpreter, so I have skipped the groups > 1 case for now. The newly added test_grouped_deconv3d_fallback in test_models.py covers the fallback-to-PyTorch route.

@tp5uiuc tp5uiuc force-pushed the grouped-deconv3d-war branch from b6f346b to 7049608 Compare April 15, 2026 18:52
Grouped 3D transposed convolutions (ConvTranspose3d with groups > 1)
crash on TensorRT-RTX. This adds a convolution_capability_validator
that detects these ops and rejects them from TRT conversion, causing
the partitioner to keep them in PyTorch while other ops remain on TRT.

Also renames depthwise_bf16_validator to convolution_capability_validator
to reflect its broader scope, and removes the blanket skip on all 3D
deconv tests — non-grouped cases now run through TRT on RTX.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
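The mechanism the commit message describes (validator rejection → partitioner keeps the op in PyTorch) can be sketched abstractly. Real partitioning operates on FX graph nodes with far more bookkeeping; the strings below are stand-ins for nodes.

```python
def partition(nodes, validator):
    """Split nodes into a TRT-convertible set and a PyTorch-fallback set."""
    trt, pytorch = [], []
    for node in nodes:
        # Nodes the capability validator rejects stay in PyTorch
        (trt if validator(node) else pytorch).append(node)
    return trt, pytorch

graph = ["conv3d", "grouped_deconv3d", "relu"]
trt_nodes, fallback_nodes = partition(graph, lambda n: n != "grouped_deconv3d")
```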
@tp5uiuc tp5uiuc force-pushed the grouped-deconv3d-war branch from 7049608 to 2eb0734 Compare April 16, 2026 17:43
Collaborator

@zewenli98 zewenli98 left a comment


LGTM. feel free to merge

@zewenli98 zewenli98 merged commit c11c0c3 into pytorch:main Apr 17, 2026
80 of 82 checks passed