diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index 532fe7afa..87e7c0485 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -183,7 +183,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     return torch.empty(shape, dtype=dtype, device=A.device)
 
 
@@ -203,7 +203,7 @@ def _(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
@@ -219,7 +219,7 @@ def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
 
     n = A.numel()
     blocks = -(n // -blocksize)
 
@@ -236,7 +236,7 @@ def _(
 
 @register_fake("bitsandbytes::dequantize_blockwise")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     return torch.empty_like(A, dtype=dtype)
 
@@ -251,7 +251,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
 def _(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ):
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
@@ -263,7 +263,7 @@
 
 @register_fake("bitsandbytes::quantize_blockwise")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     n = A.numel()
     blocks = -(n // -blocksize)
     absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
@@ -281,7 +281,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 def _(
     A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
     torch._check(
         A.dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -311,7 +311,7 @@ def _(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
     torch._check(
         A.dtype in [torch.float16, torch.bfloat16, torch.float32],
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index 6c93a3cd0..6b82c2421 100755
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -35,7 +35,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
 
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
 
     n = A.numel()
     blocks = -(n // -blocksize)
@@ -94,7 +94,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 def _(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
 
     out = torch.empty_like(A, dtype=dtype)
@@ -146,7 +146,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index 4e9d1222a..409e0252d 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -221,7 +221,7 @@ def _get_col_absmax(
 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     A = A.contiguous()
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
 
 
@@ -464,7 +464,7 @@ def _gemv_4bit_impl(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
 
     # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
     # torch._check(
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index 74eeb5502..6f5eecdf2 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -175,7 +175,7 @@ def _(A: torch.Tensor, threshold=0.0):
 
 @register_kernel("bitsandbytes::quantize_blockwise", "default")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
 
     n = A.numel()
     rem = n % blocksize
@@ -201,7 +201,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 
 @register_kernel("bitsandbytes::dequantize_blockwise", "default")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
 
     out = code[A.reshape(-1).int()]
@@ -220,7 +220,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -317,7 +317,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
index 9ecd63e0b..2844df731 100644
--- a/bitsandbytes/backends/hpu/ops.py
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -25,7 +25,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.uint8],
diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py
index 686d42929..6b1a2904b 100644
--- a/bitsandbytes/backends/triton/ops.py
+++ b/bitsandbytes/backends/triton/ops.py
@@ -15,7 +15,7 @@
 
 
 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
     with torch_accelerator_module.device(A.device):
         out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A.contiguous(), code, blocksize)
@@ -25,7 +25,7 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
 def dequantize_blockwise(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
     with torch_accelerator_module.device(A.device):
@@ -47,7 +47,7 @@ def dequantize_blockwise_inplace(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
@@ -67,7 +67,7 @@ def dequantize_blockwise_inplace(
 def quantize_4bit(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -109,7 +109,7 @@ def dequantize_4bit(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
+    torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
     # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],