
Linear4bit._save_to_state_dict writes QuantState keys but no _load_from_state_dict consumes them (asymmetric serialization) #1946

@neil-the-nowledgeable


Summary

bitsandbytes.nn.Linear4bit overrides _save_to_state_dict (bitsandbytes/nn/modules.py:593) to write QuantState components alongside the packed weight:

def _save_to_state_dict(self, destination, prefix, keep_vars):
    super()._save_to_state_dict(destination, prefix, keep_vars)  # weight + bias
    if getattr(self.weight, "quant_state", None) is not None:
        for k, v in self.weight.quant_state.as_dict(packed=True).items():
            destination[prefix + "weight." + k] = v if keep_vars else v.detach()

Resulting state_dict keys for a quantized Linear4bit with bias=False and compress_statistics=True:

  • weight (packed 4-bit tensor)
  • weight.absmax
  • weight.quant_map
  • weight.nested_absmax
  • weight.nested_quant_map
  • weight.quant_state.bitsandbytes__nf4 (or __fp4)
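The last key is the odd one out. It holds no tensor statistic; instead it carries the QuantState's non-tensor fields (quant type, block size, dtype, shape, and similar) packed into a single entry. To our understanding, bitsandbytes JSON-encodes these fields into a uint8 tensor. The stdlib sketch below models that with a plain list of byte values. The field names and values are illustrative assumptions, and pack_fields / unpack_fields are hypothetical helpers, not bitsandbytes API.

```python
import json

# Illustrative non-tensor QuantState fields for a 64x64 NF4 layer. The exact
# field set and values are assumptions, not read from a real QuantState.
non_tensor_fields = {
    "quant_type": "nf4",
    "blocksize": 64,
    "dtype": "float32",
    "shape": [64, 64],
    "nested_blocksize": 256,
    "nested_offset": 0.0,
}


def pack_fields(fields: dict) -> list:
    """JSON-encode the fields into byte values (stand-in for a uint8 tensor)."""
    return list(json.dumps(fields).encode("utf-8"))


def unpack_fields(byte_values: list) -> dict:
    """Inverse of pack_fields: bytes -> JSON text -> dict."""
    return json.loads(bytes(byte_values).decode("utf-8"))


# The state_dict key is derived from the quant type, e.g. ...bitsandbytes__nf4.
packed_key = "weight.quant_state.bitsandbytes__" + non_tensor_fields["quant_type"]
packed_value = pack_fields(non_tensor_fields)
```

A loader therefore has to recognize this entry by its bitsandbytes__nf4 / bitsandbytes__fp4 suffix and decode it before it can rebuild the QuantState.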

Linear4bit does not define a corresponding _load_from_state_dict, so it uses the default nn.Module._load_from_state_dict (nn.Linear does not override it), which only consumes the registered weight and bias. The QuantState keys therefore land in unexpected_keys during load, and with strict=True (the model.load_state_dict() default) this raises a RuntimeError.
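The default loader rejects these keys because of how it reads dotted names: weight.absmax is interpreted as an entry absmax inside a child module named weight, and a Linear has no such child. Below is a simplified, stdlib-only model of that check (our paraphrase of torch.nn.Module._load_from_state_dict that ignores _extra_state and hooks; find_unexpected is a hypothetical name). As far as we can tell, load_state_dict() always runs this check and only turns a non-empty result into an error when strict=True.

```python
def find_unexpected(keys, prefix, local_state_names, child_module_names=()):
    """Simplified model of the unexpected-key check in
    torch.nn.Module._load_from_state_dict."""
    unexpected = []
    for key in keys:
        if not key.startswith(prefix):
            continue  # belongs to another module
        head, *rest = key[len(prefix):].split(".", 1)
        if rest:
            # A dotted name is only legal if its head is a child module.
            if head not in child_module_names:
                unexpected.append(key)
        elif head not in local_state_names:
            unexpected.append(key)
    return unexpected


linear4bit_keys = [
    "weight",
    "weight.absmax",
    "weight.quant_map",
    "weight.nested_absmax",
    "weight.nested_quant_map",
    "weight.quant_state.bitsandbytes__nf4",
]
# Linear4bit(bias=False) has one local parameter and no child modules,
# so every dotted QuantState key is flagged.
unexpected = find_unexpected(linear4bit_keys, "", {"weight"})
```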

This is asymmetric relative to Linear8bitLt, which defines both _save_to_state_dict (line 1095) and _load_from_state_dict (line 1119), with the latter explicitly walking unexpected_keys to consume the SCB tensor and remove it from the unexpected list.

Reproducer

import torch
import bitsandbytes as bnb

src = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4',
                       compute_dtype=torch.bfloat16,
                       compress_statistics=True)
src = src.to('cuda')  # triggers quantize, populates quant_state

sd = src.state_dict()
print("state_dict keys:", list(sd.keys()))
# ['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax',
#  'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__nf4']

dst = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4',
                       compute_dtype=torch.bfloat16,
                       compress_statistics=True)
dst = dst.to('cuda')

# Fails with strict=True (the default):
dst.load_state_dict(sd)
# RuntimeError: Error(s) in loading state_dict for Linear4bit:
#     Unexpected key(s) in state_dict: "weight.absmax", "weight.quant_map",
#     "weight.nested_absmax", "weight.nested_quant_map",
#     "weight.quant_state.bitsandbytes__nf4"

With strict=False, the QuantState components are ignored (they appear only in the returned unexpected_keys): dst.weight.data receives the packed bytes from src, but dst.weight.quant_state remains the one populated during dst.to('cuda'), not the one stored in sd. Whenever src and dst were quantized from different weights (the usual case, e.g. loading a checkpoint into a freshly initialized module), dst is left silently corrupt: packed bytes from one model, QuantState statistics from another.
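Why mismatched statistics corrupt the weights instead of failing loudly: dequantization multiplies each stored code by its block's absmax, so codes from one model paired with absmax from another rescale every value. The sketch below swaps in a simple symmetric 4-bit absmax scheme (integer codes in -7..7) for NF4, which instead maps each code through a lookup table (quant_map) before multiplying by absmax; quantize and dequantize are hypothetical helpers.

```python
LEVELS = 7  # symmetric signed 4-bit codes in [-7, 7]


def quantize(values):
    """Absmax-quantize one block: integer codes plus a single float scale."""
    absmax = max(abs(v) for v in values)
    return [round(v / absmax * LEVELS) for v in values], absmax


def dequantize(codes, absmax):
    """Map codes back to floats using the block's absmax."""
    return [c / LEVELS * absmax for c in codes]


src_weights = [0.5, -1.0, 0.25, 0.75]
dst_weights = [2.0, -0.5, 1.0, -4.0]  # e.g. a freshly initialized module

src_codes, src_absmax = quantize(src_weights)
_, dst_absmax = quantize(dst_weights)

# Codes and absmax from the same model: close to src_weights.
restored = dequantize(src_codes, src_absmax)
# Codes from src, absmax from dst: every value is scaled by
# dst_absmax / src_absmax (4x here), with no error raised.
corrupted = dequantize(src_codes, dst_absmax)
```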

Why this hasn't been reported as a user-facing bug

Standard workflows bypass model.load_state_dict() for bnb-quantized checkpoints:

  1. HF Transformers from_pretrained for pre-quantized bnb checkpoints: transformers/quantizers/quantizer_bnb_4bit.py::create_quantized_param calls bnb.nn.Params4bit.from_prequantized(data, quantized_stats, device=...) directly, building the quantized_stats dict by walking state_dict keys with prefix param_name + ".". This bypasses _load_from_state_dict entirely.
  2. PEFT save/load: get_peft_model_state_dict filters to LoRA-only keys ("lora_" in k); QuantState keys never enter the PEFT round-trip.
  3. Accelerate's fsdp2_load_full_state_dict: has its own broadcast + assign=True path; PR #3982 added key-based matching that filters non-Params keys.
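The bypass in item 1 amounts to collecting every checkpoint entry nested under a parameter's name and handing that dict to Params4bit.from_prequantized, never touching _load_from_state_dict. A stdlib approximation of the collection step follows. collect_quantized_stats and strip_to_leaf_names are hypothetical names; the startswith match approximates the key walk described in item 1, and the leaf-name stripping reflects our reading of QuantState.from_dict, which appears to keep only the last dotted component of each key.

```python
def collect_quantized_stats(state_dict, param_name):
    """Gather every entry nested under `param_name.` (an approximation of the
    Transformers key walk, not its source)."""
    needle = param_name + "."
    return {k: v for k, v in state_dict.items() if k.startswith(needle)}


def strip_to_leaf_names(stats):
    """Keep only the last dotted component of each key."""
    return {k.rsplit(".", 1)[-1]: v for k, v in stats.items()}


# String placeholders stand in for tensors.
checkpoint = {
    "model.layers.0.q_proj.weight": "packed",
    "model.layers.0.q_proj.weight.absmax": "absmax",
    "model.layers.0.q_proj.weight.quant_map": "quant_map",
    "model.layers.0.q_proj.bias": "bias",
    "model.layers.0.k_proj.weight.absmax": "other layer",
}
stats = collect_quantized_stats(checkpoint, "model.layers.0.q_proj.weight")
leaf = strip_to_leaf_names(stats)
```

If that stripping holds, the key convention a loader uses for quantized_stats (full keys, or weight.<subkey> as in the suggested fix below) should not matter to from_prequantized.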

Paths that do hit the asymmetry:

  • torch.distributed.checkpoint (DCP) when the loaded state_dict contains QuantState keys
  • Custom training loops that save via model.state_dict() and resume via model.load_state_dict()
  • The docstring example at nn/modules.py:522 (quantized_model.load_state_dict(fp16_model.state_dict())) works only because the source is fp16 and has no QuantState keys to be unexpected; this masks the bnb→bnb round-trip case, which does hit the asymmetry

Discovery context

Surfaced while characterizing FSDP2 + Params4bit forward correctness on Jetson Orin Nano Super (sm_87), the same investigation that led to #1945. That work needed to round-trip QuantState through state_dict() for a pre-shard-broadcast pattern; the save side worked and the load side failed, which led to this finding. The two issues are mechanically unrelated (this one is a missing override; #1945 is a PyTorch FSDP2 NaN-canonicalization bug), but they share discovery provenance.

Suggested fix

Implement _load_from_state_dict paralleling Linear8bitLt's pattern. Sketch:

def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict,
    missing_keys, unexpected_keys, error_msgs,
):
    # Collect QuantState components in state_dict for this prefix
    qs_keys_to_consume = []
    quantized_stats = {}
    weight_dot_prefix = prefix + "weight."
    for k in list(state_dict.keys()):
        if k.startswith(weight_dot_prefix):
            qs_keys_to_consume.append(k)
            # store as `weight.<subkey>` for from_prequantized
            quantized_stats[k[len(prefix):]] = state_dict[k]

    # Standard nn.Linear path consumes 'weight' and 'bias'
    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs,
    )

    # If QuantState components are present, reconstruct via from_prequantized
    if quantized_stats:
        for k in qs_keys_to_consume:
            if k in unexpected_keys:
                unexpected_keys.remove(k)

        weight_data = self.weight.data  # already loaded by super()
        self.weight = Params4bit.from_prequantized(
            data=weight_data,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=weight_data.device,
            module=self,
        )

Edge cases:

  • If QuantState keys are absent (e.g., the existing fp16→bnb docstring example), quantized_stats is empty and the function falls through to standard nn.Module behavior, so there is no regression.
  • The weight_dot_prefix match cleanly delimits per-Linear keys; the state_dict passed to _load_from_state_dict is already filtered to the current module's prefix by nn.Module.load_state_dict()'s recursion.
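Both edge cases can be checked on key names alone. The stdlib sketch below isolates the key-selection rule from the suggested fix (split_quant_state_keys is a hypothetical name). A module whose keys carry no QuantState entries consumes nothing, and a prefix such as layers.1. does not capture keys belonging to layers.10., because the match includes the trailing dot.

```python
def split_quant_state_keys(keys, prefix):
    """Apply the sketch's rule: keys under `<prefix>weight.` are QuantState parts."""
    weight_dot_prefix = prefix + "weight."
    to_consume = [k for k in keys if k.startswith(weight_dot_prefix)]
    # Stored relative to the module prefix, i.e. as 'weight.<subkey>'.
    quantized_stats_keys = [k[len(prefix):] for k in to_consume]
    return to_consume, quantized_stats_keys


keys = [
    "layers.1.weight",
    "layers.1.weight.absmax",
    "layers.1.weight.quant_state.bitsandbytes__nf4",
    "layers.10.weight",
    "layers.10.weight.absmax",  # sibling module: must not be captured
]
consumed, stats_keys = split_quant_state_keys(keys, "layers.1.")

# fp16 -> bnb case: only weight/bias present, nothing to consume.
fp16_consumed, _ = split_quant_state_keys(["layers.1.weight", "layers.1.bias"], "layers.1.")
```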

Test plan (for the fix)

  • bnb→bnb round trip: m2.load_state_dict(m1.state_dict(), strict=True) for Linear4bit on CPU and then on CUDA; assert bnb.matmul_4bit produces identical output for the same synthetic input.
  • Existing fp16→bnb docstring example continues to pass (no regression).
  • Variant matrix: NF4 / FP4, compress_statistics on/off, quant_storage ∈ {uint8, bf16, fp16}, with/without bias.
  • strict=False behavior with partial QuantState keys (fall back to existing _quantize flow on .to(device) if applicable).
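The variant matrix expands to 2 × 2 × 3 × 2 = 24 cases. The stdlib sketch below enumerates them together with an expected-key oracle per case, which a real suite could feed to pytest.mark.parametrize. The expected keys follow the key list earlier in this issue; omitting the nested_* keys when compress_statistics=False and adding bias when bias=True are our assumptions about the saved layout. Dtype names are plain strings here and would need mapping to torch dtypes in a real test.

```python
from itertools import product

QUANT_TYPES = ("nf4", "fp4")
COMPRESS_STATISTICS = (True, False)
QUANT_STORAGE = ("uint8", "bfloat16", "float16")  # names only, not torch dtypes
BIAS = (True, False)


def expected_state_dict_keys(quant_type, compress_statistics, bias):
    """Keys a quantized Linear4bit state_dict should hold (assumed layout)."""
    keys = {"weight", "weight.absmax", "weight.quant_map",
            f"weight.quant_state.bitsandbytes__{quant_type}"}
    if compress_statistics:
        keys |= {"weight.nested_absmax", "weight.nested_quant_map"}
    if bias:
        keys.add("bias")
    return keys


def variant_matrix():
    """(test id, constructor kwargs, expected keys) for every combination."""
    cases = []
    for qt, cs, storage, bias in product(QUANT_TYPES, COMPRESS_STATISTICS, QUANT_STORAGE, BIAS):
        case_id = f"{qt}-{'nested' if cs else 'flat'}-{storage}-{'bias' if bias else 'nobias'}"
        kwargs = {"quant_type": qt, "compress_statistics": cs,
                  "quant_storage": storage, "bias": bias}
        cases.append((case_id, kwargs, expected_state_dict_keys(qt, cs, bias)))
    return cases


cases = variant_matrix()
```

Checking set(m1.state_dict()) against the oracle before the round trip separates "saved the wrong keys" from "failed to consume the right keys".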

Severity

Latent and currently masked. It does not surface in standard HF + PEFT workflows because they bypass load_state_dict, but it is a contract violation between _save_to_state_dict and the standard nn.Module load mechanism, and it surfaces for:

  • torch.distributed.checkpoint (DCP) usage
  • Custom training-loop checkpoint code
  • Any tooling that round-trips through model.state_dict() + model.load_state_dict()

Adjacent PRs (for context, do not address this)

Neither addresses the missing _load_from_state_dict.

Happy to file a PR with the fix sketch + tests if useful.
