
Add sm87 to aarch64 CUDA wheel targets (Jetson Orin)#1939

Closed
neil-the-nowledgeable wants to merge 1 commit into bitsandbytes-foundation:main from neil-the-nowledgeable:fix/issue-1930

Conversation

@neil-the-nowledgeable
Contributor

Summary

Adds sm_87 (NVIDIA Jetson Orin family — Nano, NX, AGX) to the aarch64 CUDA wheel build targets. Without an explicit sm87 cubin, aarch64 wheels currently target sm75/sm80/sm90 only, and Jetson Orin users either get slow / unsupported paths or have to source-build bnb against -DCOMPUTE_CAPABILITY=87.

Closes #1930 (sm87 build target request)
Closes #1218 (Jetson Orin support)

Why explicit sm87 is needed

The closing comment on #1781 suggested that sm80 PTX should JIT to sm87 hardware via forward compatibility. That's not how the bnb build setup actually distributes PTX — see the CMake arch logic at CMakeLists.txt:226-230: only the latest capability gets a PTX entry, so a wheel built for sm75/sm80/sm90 ships PTX only for sm90. PTX forward compatibility is upward-only — sm87 hardware cannot JIT from sm90+ PTX. As a result, Jetson Orin users effectively get no GPU paths from current aarch64 wheels.

Adding sm87 to the build list emits an explicit sm_87 cubin alongside the existing arches, giving Jetson Orin users a working GPU path without a source build.
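For illustration, the arch-selection behavior described above can be sketched in a few lines of Python (this mimics the described CMake logic; the function name and structure are illustrative, not the actual build code):

```python
def gencode_flags(capabilities):
    """Mimic the described CMake behavior: emit a cubin (code=sm_*) for every
    requested capability, but a PTX entry (code=compute_*) only for the latest."""
    caps = sorted(capabilities)
    flags = [f"-gencode arch=compute_{c},code=sm_{c}" for c in caps]
    # Only the highest capability gets a PTX entry for forward JIT compatibility.
    flags.append(f"-gencode arch=compute_{caps[-1]},code=compute_{caps[-1]}")
    return flags

# Current aarch64 wheel: PTX only for sm90, and no sm_87 cubin.
print(gencode_flags([75, 80, 90])[-1])
# -> -gencode arch=compute_90,code=compute_90

# With sm87 added, an explicit sm_87 cubin is present.
print(any("code=sm_87" in f for f in gencode_flags([75, 80, 87, 90])))
# -> True
```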

Wheel size impact

Measured via measure_sm87_size.sh (two clean builds — baseline at origin/main, then with the build-list change applied) on a Linux aarch64 host with CUDA 12.6.68 (cuda_12.6.r12.6/compiler.34714021_0), source HEAD a57d8e2:

| build | libbitsandbytes_cuda126.so size (bytes) | size (MiB) |
|---|---|---|
| baseline (sm75;80;90) | 5,710,520 | 5.45 |
| with sm87 (sm75;80;87;90) | 7,353,064 | 7.01 |
| delta | +1,642,544 | +1.57 (+28.76%) |

This is the cubin-only delta; PTX entries for sm90+ are unchanged. The ~1.6 MB increase covers the explicit sm_87 cubin that Jetson Orin hardware needs in the absence of forward-compat-from-sm90+-PTX.

CUDA 12.8/13.0 build paths (the other arms in build-cuda.sh) include additional sm100/sm120/sm121 entries; the absolute sizes there will differ but the per-arch delta-from-adding-sm87 is expected to be similar to the 1.6 MB measured here.
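The delta arithmetic in the table checks out directly (a quick sanity check on the reported numbers, not part of the measurement script):

```python
# Sizes from the measured builds above, in bytes.
baseline = 5_710_520    # sm75;80;90
with_sm87 = 7_353_064   # sm75;80;87;90
delta = with_sm87 - baseline
MIB = 1024 * 1024

print(delta)                              # 1642544
print(round(delta / MIB, 2))              # 1.57
print(round(100 * delta / baseline, 2))   # 28.76
```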

Test added

tests/test_linear4bit_sm87_multishape_regression.py — pytest reproducer for the multi-shape Linear4bit cold-start fault on sm_87 (#1936). The recipe is exactly the failing recipe documented in #1936:

  • shape order: A=(4096, 32768), B=(4096, 128256), C=(3584, 152064) — monotonically increasing by output-feature product
  • quant_type="nf4", quant_storage=torch.bfloat16, compute_dtype=torch.bfloat16, compress_statistics=True
  • batch size 1
  • no del / empty_cache / synchronize between iterations
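The recipe above condenses to roughly the following sketch (the mapping of each shape tuple to (in_features, out_features) and the exact bnb calls are illustrative — the real test file is authoritative; the heavy path is guarded since it only fires on sm_87 hardware):

```python
# Shapes in the failing ABC order — strictly monotonically increasing by
# element product (the unique ordering among the six permutations that faults).
SHAPES = [(4096, 32768), (4096, 128256), (3584, 152064)]

def shape_products(shapes):
    return [m * n for m, n in shapes]

def run_recipe():
    """Run the cold-start recipe from the bullets above; returns 'skipped'
    anywhere other than sm_87 hardware."""
    try:
        import torch
        import bitsandbytes as bnb
    except ImportError:
        return "skipped"
    if not torch.cuda.is_available() or torch.cuda.get_device_capability() != (8, 7):
        return "skipped"
    for in_f, out_f in SHAPES:  # assumed (in_features, out_features) mapping
        layer = bnb.nn.Linear4bit(
            in_f, out_f,
            quant_type="nf4",
            quant_storage=torch.bfloat16,
            compute_dtype=torch.bfloat16,
            compress_statistics=True,
        ).cuda()
        x = torch.randn(1, in_f, device="cuda", dtype=torch.bfloat16)  # batch size 1
        layer(x)  # deliberately no del / empty_cache / synchronize between iterations
    return "ran"
```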

The test is skipped on all non-sm_87 hardware via a torch.cuda.get_device_capability() != (8, 7) check.

**Be aware (as per the test docstring):** the fault is overwhelmingly cold-start-specific. Warm-state samples = 0% reboot (N=29+); cold-start samples = ~78% reboot (N=9 cumulative across two physically separate Jetson Orin Nano Super 8 GB units). pytest cannot capture the failure mode (system-level reboot, not a Python exception or even SIGKILL) — absence of test output IS the regression signal in CI logs. CI runners targeting sm_87 should provide cold state (e.g., reboot before the test run) for the test to fire reliably.

#1936 follow-up (separate work)

The regression test verifies the recipe runs cleanly on the post-fix bnb. The actual #1936 close is a follow-up: git bisect v0.46.1..v0.49.2 against the regression test on a Jetson Orin to identify the responsible commit (or, if none, document the cold-start-race characterization). Will be handled in a separate comment on #1936.

Earlier framing on #1936 ("fixed in 0.49.2") was based on a warm-state sample and was retracted in a 2026-05-05 correction comment — bnb 0.49.2 also reboots at cold-start (N=1). The actual mechanism is a cold-start race in the bnb-NF4 dequant kernel + Tegra driver path; warmup, recipe-axis changes, and lower power modes all shift timing past the cold-start window without patching a specific source bug.

Verification

  • 13-test orthogonal-axis bisection on bnb 0.46.1 + sm_87 + MAXN_SUPER confirms the fault is highly recipe-specific. Six axes (shape order, quant_type, quant_storage, compute_dtype, double_quant, hygiene) are each independently sufficient to prevent the fault.
  • All 6 order permutations of {A, B, C} tested: only ABC (the unique strictly-monotonic-increasing-by-product order) reboots; ACB / BAC / BCA / CAB / CBA all pass cleanly.
  • Power-mode reruns: at nvpmodel mode 1 (25W) and mode 0 (15W), the same recipe passes — confirming clock/timing-sensitivity. Fault occurs at MAXN_SUPER only.
  • Cold-start vs warm-state characterization: 78% reboot at cold-start (N=9 across multiple binaries: cluster-venv 0.46.1, fresh-built 4bca844, fresh-built 0.49.2), 0% at warm-state (N=29+). Build-environment is not a fault axis.
  • Warmup mitigation: a single 256×256 NF4 forward executed as the first GPU op after boot closes the cold-start race window (N=3, all pass at full MAXN_SUPER recipe with no other change).
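The warmup mitigation in the last bullet amounts to a tiny guarded helper along these lines (sizes from the bullet; the API usage is illustrative and the call is a no-op off-device):

```python
def nf4_warmup():
    """Run one small NF4 forward as the first GPU op after boot, which closed
    the cold-start race window in the N=3 mitigation runs described above."""
    try:
        import torch
        from bitsandbytes.nn import Linear4bit
    except ImportError:
        return "skipped"
    if not torch.cuda.is_available():
        return "skipped"
    layer = Linear4bit(256, 256, quant_type="nf4",
                       compute_dtype=torch.bfloat16).cuda()
    x = torch.randn(1, 256, device="cuda", dtype=torch.bfloat16)
    layer(x)
    torch.cuda.synchronize()
    return "warmed"
```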

Adds sm_87 (NVIDIA Jetson Orin: Nano / NX / AGX) to the aarch64
build_capability list in .github/scripts/build-cuda.sh and documents
the addition in installation.mdx.

Why an explicit cubin is needed: the CMake arch logic at
CMakeLists.txt:226-230 only emits PTX for the latest capability.
sm87 hardware can't JIT from sm90+ PTX (forward-compat is upward-
only), so aarch64 wheels currently targeting sm75/sm80/sm90 ship
PTX only for sm90 and Jetson Orin users fall back to slow or
unsupported paths. This rebuts the "sm80 should cover sm87"
reasoning that closed bitsandbytes-foundation#1781.

Wheel size impact (measured on Linux aarch64, CUDA 12.6.68, source
HEAD a57d8e2):
  baseline (sm75;80;90):         5,710,520 bytes (5.45 MiB)
  with sm87 (sm75;80;87;90):     7,353,064 bytes (7.01 MiB)
  delta:                       +1,642,544 bytes (+1.57 MiB, +28.76%)

Adds tests/test_linear4bit_sm87_multishape_regression.py — pytest
reproducer for the multi-shape Linear4bit cold-start fault on
sm_87 (bitsandbytes-foundation#1936). The test runs the historical failing recipe (NF4 +
bf16 quant_storage + bf16 compute + double_quant + ABC shape order
+ no hygiene + batch=1) at sm_87 cold-state. The fault is
cold-start-specific; the test docstring documents the warm/cold
distinction so CI runners can configure accordingly.

Closes bitsandbytes-foundation#1930
Closes bitsandbytes-foundation#1218

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@matthewdouglas
Member

Hi @neil-the-nowledgeable, thanks for the PR. A few things to clarify here:

For CUDA binary compatibility, a cubin built for e.g. sm80 would be expected to run on an sm86/sm87/sm89 device, as it is within the same major family. It is documented here, among other places. This isn't really JIT from PTX but binary compatibility.
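(For illustration, the family binary-compatibility rule can be sketched as a predicate — a simplification of the documented rule: same major compute capability, device minor greater than or equal to the cubin minor; arch-specific variants like sm_90a are out of scope here:)

```python
def cubin_runs_on(cubin_cc, device_cc):
    """Simplified CUDA SASS binary compatibility: a cubin runs natively on a
    device with the same major compute capability and an equal-or-higher minor."""
    return cubin_cc[0] == device_cc[0] and device_cc[1] >= cubin_cc[1]

print(cubin_runs_on((8, 0), (8, 7)))   # True  — sm_80 SASS runs on sm_87 (Orin)
print(cubin_runs_on((9, 0), (8, 7)))   # False — sm_90 code cannot serve sm_87
```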

If not including the sm87 target in our build was the issue, your report in #1930 would have an error message along the lines of "no kernel image for device." Instead you got a missing symbol error, which points at something different being the problem. I think you can spot check me on this by building bitsandbytes for the sm80 target only on your device, and see if it still runs, produces the same error, or a different one. I expect it will behave largely the same as if you built for sm87.

We build and distribute wheels for aarch64 using an aarch64-sbsa runner. On the Jetson Orin platform, it needs JetPack, with a different CUDA toolkit, different library dependencies, and a different ABI. This difference would explain the symbol failure error message. If we built sm87 on our CI as per this change, I am not confident at all that it would work on your L4T Jetson Orin. Have you tried building on an aarch64-sbsa machine and then running it on device? I would expect the outcome to be very similar to what it is today.

I'm not planning to build and distribute an entirely separate wheel for Orin at this time. So instead, we'd have to consider bundling into the existing aarch64 wheel, and that might require cross-compiling for it with the right L4T toolkit and including it as an extra copy of each binary, e.g. a separate libbitsandbytes_cuda126_l4t.so. Then we would have to know to select this version at runtime. That's a lot of non-trivial CI work and it's hard to do this with confidence without the real hardware. Unfortunately it's really not high on the priority list.

Aside from just Orin, this same issue would be present for other L4T arches like Xavier/sm72. Thor is a little interesting: for CUDA 12, where it is sm101 and still has the L4T stack, it has the same problem. It should be fine, however, with the aarch64-sbsa package in CUDA 13.

For now, we can't really do much here other than recommend either building from source or using a third-party package like this one.

As for the tests, it's not super useful IMHO. We don't have access to the hardware whether CI or not, so it's not going to provide any signal to us. It's more useful in the context of the issue report about it.

@matthewdouglas matthewdouglas added Build CI/CD CUDA Issues and PRs related to the CUDA backend, excluding installation/support help. labels May 7, 2026
@neil-the-nowledgeable
Contributor Author

neil-the-nowledgeable commented May 7, 2026

Hi @matthewdouglas, thanks for the thoughtful pushback.

Binary compat — I was wrong. I conflated PTX forward-compat with binary-compat within the Ampere/Ada major family. Per the docs you linked, sm_80 SASS runs directly on sm_86 / sm_87 / sm_89 hardware — no JIT needed. The PR's premise that "Jetson Orin needs an explicit sm_87 cubin because sm_80 can't run on sm_87" was incorrect.

The error type is diagnostic. The `Error named symbol not found at line 233 in /src/csrc/ops.cu` message from #1930 is a CUDA-runtime symbol-resolution error, not a kernel-image-for-device error. That points squarely at toolchain/library linkage, not arch. I should have caught this in the original report, sorry about that.

Empirical confirmation of your hypothesis. I ran the spot-check you suggested. Source-built bnb on the Jetson with -DCOMPUTE_CAPABILITY=80 only (no sm_87, no other arches), verified the resulting cubin via cuobjdump --list-elf:

```
ELF file 1: libbitsandbytes_cuda126.1.sm_80.cubin
PTX entries: (none)
```

sm_80 SASS only, zero PTX, no sm_87 fallback path of any kind. Overlaid this cubin into the venv (renamed to match the cu128 torch loader's expected filename) and ran the original #1930 reproducer plus a couple of broader checks on sm_87 hardware:

| Test | Result |
|---|---|
| `import bitsandbytes` | PASS — recognized device cap (8, 7) |
| `quantize_4bit(torch.randn(16, 16, device='cuda', dtype=torch.bfloat16))` (the #1930 repro) | PASS — output shape [128, 1] uint8, valid QuantState |
| `Linear4bit(256, 256, nf4, bf16/bf16).forward()` | PASS — finite output, 2.3 ms post-warmup |
| `quantize_4bit` → `dequantize_4bit` roundtrip on [64, 64] | PASS — max\|err\| 0.45, mean\|err\| 0.07 (typical NF4) |

No "named symbol not found" error. No "no kernel image for device" error. No Tegra reboot under this single-shape test recipe (the #1936 multi-shape cold-start fault is a separate matter, scoped to its own thread). sm_80 cubin runs binary-compat-clean on sm_87 hardware as you described.

This empirically rules out the "missing sm_87 in the build matrix" framing of #1930. The actual fault has to be the aarch64-sbsa CI build linking against CUDA libraries / ABI that don't match the L4T runtime on Orin — exactly your hypothesis. Adding sm_87 to your CI build target wouldn't help because the resulting cubin would still link against aarch64-sbsa libraries.

On the test: agreed it's not useful in your CI without the hardware. I'll pull it; the multi-shape Linear4bit cold-start reproducer lives more naturally in the #1936 thread where it has context.

Pivot proposal. Given the above, I'll close this PR. The right follow-up is documentation-only — an installation.mdx section that:

  1. Calls out the L4T-vs-aarch64-sbsa toolchain distinction explicitly (so the next person hitting Error named symbol not found doesn't chase the arch list).
  2. Documents the source-build recipe that works on-device (cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=87 . && make).
  3. Points at the third-party JAL package you linked as the prebuilt option.

I'll file that as a separate small docs-only PR, referencing this thread for context. Closing this one.

Apologies for the noise, and thank you very much for your patient teaching and feedback. Your understanding of the failure mode was correct from the first comment.

@neil-the-nowledgeable
Contributor Author

Following up with a few more details, supporting your insight. (PR staying closed)

I wrote a 20-line LD_PRELOAD shim hooking cudaLaunchKernel (cuda-gdb and compute-sanitizer don't work on Jetson Orin — libcudadebugger.so.1 segfaults / reports "GPU debugging features are disabled"). It logs every kernel launch with a dladdr-resolved name and return code. I ran the original #1930 reproducer (quantize_4bit(torch.randn(16, 16, dtype=torch.bfloat16))) under the unsloth-jetson venv (cu126 torch 2.5) with the aarch64-sbsa PyPI wheel bitsandbytes-0.46.1 PYTHONPATH-shadowed.

Exact failing kernel:

```
[hook 1] cudaLaunchKernel(... 'at::native::distribution_elementwise_grid_stride_kernel<...>') = 0  // torch.randn
[hook 2] cudaLaunchKernel(0xffff013f2430
  '_Z18kQuantizeBlockwiseI13__nv_bfloat16Li64ELi2ELi0ELi1EEvPfPT_S1_PhS1_ii',
  grid=(4,1,1), block=(32,1,1)) = 500 (named symbol not found)
```

Demangled: kQuantizeBlockwise<__nv_bfloat16, 64, 2, 0, 1>(float*, __nv_bfloat16*, float*, unsigned char*, float*, int, int) — the FP4 (DATA_TYPE=1) blocksize=64 specialization for bfloat16 input.

The kicker: this exact mangled symbol IS present in the aarch64-sbsa wheel's sm_80 cubin AND in our locally-built L4T sm_87 cubin. Verified with cuobjdump --dump-elf-symbols. Both cubins have identical kQuantizeBlockwise template-instantiation inventories — 66 variants each, exact same set across bfloat16 / __half / float × all (blocksize, items, *, DATA_TYPE) combinations. Same kernels. Same arch class (sm_80 → sm_87 binary-compat per the docs you linked, confirmed earlier in this thread). Yet the SBSA cubin fails on sm_87 hardware while our L4T sm_87 cubin runs fine.

The actual divergence is at the host↔cubin module-binding metadata layer:

| Probe | aarch64-sbsa wheel | L4T source-build |
|---|---|---|
| `__cudaRegisterFunction` calls in the .so | 4275 | 1605 (~3× = multi-arch effect) |
| Mangled-name strings in the .so | 614 | 596 |
| Kernel discriminators (e.g., `_Z12kgetRowStats...$N`) | $301, $310, $307, $316, ... | $306, $312, ... |
| Module ID strings | `def_module_id_str_f8c82de5_6_ops_cu_de3d8167_709` | `def_module_id_str_06014870_6_ops_cu_f7931056_14894` |
| CUB internal namespace tag | `cub::CUB_200500_750_800_900_NS::*` | `cub::CUB_200500_870_NS::*` |

These are nvcc-emitted, build-time bindings that tie the host-side stub pointer to a specific cubin module. The L4T runtime can't reconcile them when both the host code and the cubin were emitted by the aarch64-sbsa nvcc rather than the L4T nvcc — even at matching CUDA-toolkit versions (12.6 on both sides).

Direct implication for this PR's premise. The PR proposed adding sm_87 to the aarch64-sbsa CI build matrix. Given the above, even if you had merged it, it wouldn't have fixed #1930 / #1218. The wheel would gain an sm_87 cubin, but the host code would still be aarch64-sbsa-nvcc-emitted with aarch64-sbsa-nvcc-tagged module IDs. The L4T runtime would still fail to bind. We'd still get cudaLaunchKernel = 500 on the very first bnb kernel launch.

The remedy space narrows to exactly what you (matthewdouglas) described: cross-compile all of bnb (host + cubin) with L4T's nvcc, bundle it as a separate libbitsandbytes_cuda126_l4t.so, and runtime-select. CMake-only changes (e.g., normalizing CUB namespace tags) wouldn't fix the $N discriminators or def_module_id_str_* tags — those are nvcc-internal, not user-controllable. So your "really not high on the priority list" assessment is well supported (and exceedingly polite ;-).
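The runtime-select step could look something like this (a minimal sketch — the filenames and the L4T-detection heuristic are assumptions, and bitsandbytes' actual loader logic differs):

```python
import os
import platform

def pick_binary(cuda_version: str = "126") -> str:
    """Pick the hypothetical L4T variant of the native library on Jetson (Tegra)
    hosts, and the default aarch64-sbsa build everywhere else. The
    /etc/nv_tegra_release check is one common L4T-detection heuristic."""
    is_l4t = (platform.machine() == "aarch64"
              and os.path.exists("/etc/nv_tegra_release"))
    suffix = "_l4t" if is_l4t else ""
    return f"libbitsandbytes_cuda{cuda_version}{suffix}.so"
```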

Thanks again for the thoughtful feedback.

matthewdouglas pushed a commit that referenced this pull request May 7, 2026
…#1941)

Adds a WARNING callout after the Linux aarch64 row of the PyPI build-targets
table, explaining that:

1. Wheels are built on aarch64-sbsa runners (standard CUDA Toolkit), not the
   L4T / JetPack runtime that Jetson Orin / Xavier / Thor (on CUDA 12) use.
2. The mismatch surfaces as 'Error named symbol not found in /src/csrc/ops.cu'
   on the first CUDA op — a symbol-resolution error, NOT a kernel-image-for-
   device error. The cubins ARE binary-compatible with the device per
   Ampere-family binary compat (sm_80 SASS runs on sm_87 hardware natively).
3. Working options on Jetson: on-device source build, or third-party prebuilt
   from Jetson AI Lab.

References #1218 and #1930 for the original error reports, and #1939 for the
empirical confirmation that the fault is the toolchain delta, not the arch
list (sm_80-only cubin built on-device runs cleanly on sm_87 hardware).

Co-authored-by: neil-the-nowledgable <254185769+neil-the-nowledgable@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
