16 changes: 8 additions & 8 deletions bitsandbytes/_ops.py
@@ -183,7 +183,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     return torch.empty(shape, dtype=dtype, device=A.device)


@@ -203,7 +203,7 @@ def _(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
@@ -219,7 +219,7 @@ def _(
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")

     n = A.numel()
     blocks = -(n // -blocksize)
@@ -236,7 +236,7 @@ def _(

 @register_fake("bitsandbytes::dequantize_blockwise")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     return torch.empty_like(A, dtype=dtype)

@@ -251,7 +251,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
 def _(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ):
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
@@ -263,7 +263,7 @@ def _(

 @register_fake("bitsandbytes::quantize_blockwise")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     n = A.numel()
     blocks = -(n // -blocksize)
     absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
@@ -281,7 +281,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 def _(
     A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
     torch._check(
         A.dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -311,7 +311,7 @@ def _(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
     torch._check(
         A.dtype in [torch.float16, torch.bfloat16, torch.float32],
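The change is mechanical throughout: every `torch._check_is_size(blocksize)` becomes an explicit non-negativity check with a lazy error message. `torch._check` raises only when the condition is false, and passing a lambda defers formatting the message until that happens. A minimal sketch of the pattern; the `validate_blocksize` helper is hypothetical, not part of the PR:

import torch

def validate_blocksize(blocksize: int) -> None:
    # Raises when blocksize < 0; the lambda avoids building the
    # f-string on the happy path.
    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")

validate_blocksize(256)    # passes silently
# validate_blocksize(-1)   # raises: "Blocksize must be non-negative, got -1"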
6 changes: 3 additions & 3 deletions bitsandbytes/backends/cpu/ops.py
@@ -35,7 +35,7 @@ def _(A: torch.Tensor, B: torch.Tensor):

 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")

     n = A.numel()
     blocks = -(n // -blocksize)
@@ -94,7 +94,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 def _(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

     out = torch.empty_like(A, dtype=dtype)
@@ -146,7 +146,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
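Several of the touched kernels derive the block count with `blocks = -(n // -blocksize)`, which is ceiling division using only integer operations. A small sketch with made-up sizes:

n, blocksize = 1000, 256
# Floor division of the negated operand rounds up: -(1000 // -256) == -(-4) == 4,
# i.e. four blocks are needed to cover 1000 elements at 256 per block.
blocks = -(n // -blocksize)
assert blocks == 4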
4 changes: 2 additions & 2 deletions bitsandbytes/backends/cuda/ops.py
@@ -221,7 +221,7 @@ def _get_col_absmax(
 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     A = A.contiguous()
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")

     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

@@ -464,7 +464,7 @@ def _gemv_4bit_impl(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")

     # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
     # torch._check(
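On the CUDA path the relaxed `>= 0` check is immediately followed by a membership check against the supported block sizes, so the effective contract there is unchanged; invalid values still fail, just one check later. A hedged sketch of the combined validation (`_validate_cuda_blocksize` is an illustrative name, not from the PR):

import torch

_SUPPORTED = [4096, 2048, 1024, 512, 256, 128, 64, 32]

def _validate_cuda_blocksize(blocksize: int) -> None:
    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
    # Power-of-two sizes accepted by the CUDA kernels, per the check above.
    torch._check(blocksize in _SUPPORTED)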
8 changes: 4 additions & 4 deletions bitsandbytes/backends/default/ops.py
@@ -175,7 +175,7 @@ def _(A: torch.Tensor, threshold=0.0):

 @register_kernel("bitsandbytes::quantize_blockwise", "default")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")

     n = A.numel()
     rem = n % blocksize
@@ -201,7 +201,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor

 @register_kernel("bitsandbytes::dequantize_blockwise", "default")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

     out = code[A.reshape(-1).int()]
@@ -220,7 +220,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -317,7 +317,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
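For orientation, here is how one of the touched ops is reached through the PyTorch dispatcher: a sketch assuming `bitsandbytes` is installed and registers its custom ops on import; the codebook below is a placeholder, not the library's real 8-bit code table:

import torch
import bitsandbytes  # noqa: F401  (importing registers the bitsandbytes:: ops)

A = torch.randn(1024)
code = torch.linspace(-1.0, 1.0, 256)  # placeholder codebook for illustration
# Returns the quantized uint8 tensor and the per-block absmax scales.
out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, 256)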
2 changes: 1 addition & 1 deletion bitsandbytes/backends/hpu/ops.py
@@ -25,7 +25,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.uint8],
10 changes: 5 additions & 5 deletions bitsandbytes/backends/triton/ops.py
@@ -15,7 +15,7 @@


 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
     with torch_accelerator_module.device(A.device):
         out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A.contiguous(), code, blocksize)
@@ -25,7 +25,7 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
 def dequantize_blockwise(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
     with torch_accelerator_module.device(A.device):
@@ -47,7 +47,7 @@ def dequantize_blockwise_inplace(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
@@ -67,7 +67,7 @@ def dequantize_blockwise_inplace(
 def quantize_4bit(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -109,7 +109,7 @@ def dequantize_4bit(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],