
aarch64 prebuilt wheels do not include sm_87 — Jetson Orin devices fail at first CUDA kernel launch #1930

@neil-the-nowledgeable

Description


Environment

  • Device: NVIDIA Jetson Orin Nano Super
  • OS: NVIDIA JetPack 6.2 (Linux for Tegra, aarch64)
  • CUDA: 12.6 (JetPack-bundled driver)
  • Compute capability: sm_87
  • Python: 3.10.12
  • torch: 2.5.0a0+872d972e41.nv24.08 (NVIDIA JetPack wheel)
  • bitsandbytes: 0.46.1 (PyPI aarch64 wheel)

What happened

Installing the prebuilt aarch64 wheel from PyPI succeeds:

pip install bitsandbytes==0.46.1
# Successfully installed bitsandbytes-0.46.1

import bitsandbytes also succeeds. The failure occurs at the first CUDA kernel launch — for example, quantize_4bit:

import torch
from bitsandbytes.functional import quantize_4bit
x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
quantize_4bit(x)
# RuntimeError: Error named symbol not found at line 233 in file /src/csrc/ops.cu

What I expected

Either (a) the wheel works on Jetson Orin, or (b) the wheel cleanly refuses to load with a message pointing at the arch-support matrix.

Root cause

The aarch64 prebuilt wheels target sm75, sm80, and sm90 for CUDA 11.8-12.6, and sm75, sm80, sm90, sm100, sm110, sm120, and sm121 for CUDA 12.8-13.0, per docs/source/installation.mdx. Jetson Orin devices (Orin Nano, Orin NX, AGX Orin) report sm_87, which is in neither set. The fatbinary shipped in the wheel therefore contains no code compiled for sm_87 (and apparently no forward-compatible PTX that could be JIT-compiled), so the first kernel lookup fails at runtime with "named symbol not found". Import succeeds because nothing is launched until the first functional call.
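This mismatch is detectable before any kernel launch, which is what resolution option (b) above would require. A minimal sketch of such a load-time guard, assuming the arch sets from the installation docs (the function name and set layout are hypothetical, not bitsandbytes API; on a real device the capability tuple would come from torch.cuda.get_device_capability()):

```python
# Hypothetical guard: compare the device's compute capability against the
# arch lists the prebuilt aarch64 wheels are compiled for (per installation.mdx).
PREBUILT_AARCH64_ARCHS = {
    "cuda-11.8-12.6": {(7, 5), (8, 0), (9, 0)},
    "cuda-12.8-13.0": {(7, 5), (8, 0), (9, 0), (10, 0), (11, 0), (12, 0), (12, 1)},
}

def wheel_supports(capability, toolkit="cuda-11.8-12.6"):
    """Return True if the prebuilt wheel contains kernels for this device."""
    return capability in PREBUILT_AARCH64_ARCHS[toolkit]

# Jetson Orin reports sm_87 -> not covered; the library could refuse to load
# here with a clear message instead of failing at the first kernel launch.
print(wheel_supports((8, 7)))  # False
print(wheel_supports((8, 0)))  # True
```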

Workaround — source build

Building from source with -DCOMPUTE_CAPABILITY=87 produces a working wheel:

git clone --depth 1 --branch 0.46.1 \
    https://github.com/bitsandbytes-foundation/bitsandbytes.git
cd bitsandbytes
PATH=/usr/local/cuda-12.6/bin:$PATH \
    cmake -B build . -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=87
PATH=/usr/local/cuda-12.6/bin:$PATH \
    cmake --build build -j4
pip install .

Build takes ~6 minutes on an Orin Nano Super. After installation, the 4-bit quantize/dequantize roundtrip runs cleanly, and QLoRA training with transformers/peft/trl works end-to-end.
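To confirm a rebuilt library actually embeds sm_87 code, the fatbinary can be inspected with `cuobjdump --list-elf` on the built shared object (the exact .so filename under bitsandbytes/ depends on the CUDA version; cuobjdump names embedded cubins with an `sm_NN` suffix). A small parser for that listing, as a sketch:

```python
import re

def archs_in_listing(cuobjdump_output: str) -> set:
    """Extract the set of SM architectures named in cuobjdump --list-elf output."""
    return set(re.findall(r"sm_(\d+)", cuobjdump_output))

# Example output shape for a library built with -DCOMPUTE_CAPABILITY=87
# (filename is illustrative):
sample = "ELF file    1: libbitsandbytes_cuda126.1.sm_87.cubin"
print("87" in archs_in_listing(sample))  # True -> sm_87 code is present
```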

The source-built kernel path has been validated for output quality, not merely for absence of crashes. On a 16-problem held-out logic-reasoning benchmark (Carroll-16):

  • TinyLlama 1.1B at 4-bit NF4 scored within 1 problem of the same model's Ollama Q4_K_M reference;
  • Qwen2.5-3B-Instruct at 4-bit NF4 scored 93.75% keyword / 0.418 judge composite (highest non-reasoning-model result recorded on the benchmark);
  • A same-stack 4-bit-vs-bf16 training A/B produced training losses within 0.4% and downstream adapter scores within one-problem noise.

The source-built wheel produces numerically correct outputs at both the 1B and 3B scales across two model families — not just "doesn't crash."

Suggested resolution options

In increasing order of effort for maintainers:

  1. Docs update — A paired PR (draft here) adds a "NVIDIA Jetson (sm_87) — source build required" section to docs/source/installation.mdx so users Googling this error find a canonical answer. Recommended either way.
  2. Arch-support matrix clarification — Explicitly name sm_87 in the matrix as "source build required" so the omission is intentional and visible.
  3. Add sm_87 to aarch64 CI wheel matrix — If CI capacity permits, building the aarch64 wheels for sm75;sm80;sm87;sm90 (and the 13.0 equivalent) would let Jetson Orin users pip install bitsandbytes directly. Jetson is widely deployed for edge-ML applications where on-device QLoRA fine-tuning is a real capability.
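Option 3 could be as small as extending the capability list passed to the existing aarch64 build step. A hypothetical sketch of the changed build invocation (the actual CI workflow layout and variable names in the bitsandbytes repo may differ):

```shell
# Hypothetical CI build step for the CUDA 11.8-12.6 aarch64 wheels:
# add 87 to the semicolon-separated capability list.
cmake -B build . -DCOMPUTE_BACKEND=cuda \
    -DCOMPUTE_CAPABILITY="75;80;87;90"
cmake --build build -j"$(nproc)"
```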

Reproduction

Run the commands above; the error surfaces at any call into bitsandbytes.functional that launches a CUDA kernel, not at import.

Additional data points

If you want, I can test and report:

  • Other bitsandbytes versions on the same environment (0.45.x if wheels exist for aarch64)
  • Performance of the source-built sm_87 wheel on a standard QLoRA benchmark (TinyLlama, Llama-3.2-3B, similar) so the "what do you get after source-build" answer is concrete

How to submit (operator instructions)

  1. Open https://github.com/bitsandbytes-foundation/bitsandbytes/issues/new?template=bug_report.md
  2. Paste the issue body above
  3. Apply the bug + jetson labels if available
  4. If the paired docs PR is already open, cross-link


Labels

CUDA, aarch64
