@@ -359,15 +359,7 @@ def _test_dequantize(self, config_kwargs):
             if isinstance(module, torch.nn.Linear):
                 assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
 
-        # Get model dtype from first parameter
-        model_dtype = next(model.parameters()).dtype
-
         inputs = self.get_dummy_inputs()
-        # Cast inputs to model dtype
-        inputs = {
-            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
-            for k, v in inputs.items()
-        }
         output = model(**inputs, return_dict=False)[0]
         assert output is not None, "Model output is None after dequantization"
         assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
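The `_is_module_quantized` helper that the dequantize check relies on is not part of this hunk. As a rough sketch of what such a helper could look like for the bitsandbytes backend (assuming the real `bnb.nn.Linear4bit` and `bnb.nn.Linear8bitLt` layer classes; the suite's actual implementation may differ):

```python
import torch
import bitsandbytes as bnb


def _is_module_quantized(module: torch.nn.Module) -> bool:
    # A layer counts as quantized while it is still one of the bitsandbytes
    # Linear replacements; after dequantize() it should be a plain nn.Linear again.
    return isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
```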
@@ -575,33 +567,28 @@ def test_bnb_original_dtype(self):
 
     @torch.no_grad()
     def test_bnb_keep_modules_in_fp32(self):
-        if not hasattr(self.model_class, "_keep_in_fp32_modules"):
-            pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
+        fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
+        if not fp32_modules:
+            pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")
 
         config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]
 
-        original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
-        self.model_class._keep_in_fp32_modules = ["proj_out"]
-
-        try:
-            model = self._create_quantized_model(config_kwargs)
+        model = self._create_quantized_model(config_kwargs)
+        model.to(torch_device)
 
-            for name, module in model.named_modules():
-                if isinstance(module, torch.nn.Linear):
-                    if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
-                        assert module.weight.dtype == torch.float32, (
-                            f"Module {name} should be FP32 but is {module.weight.dtype}"
-                        )
-                    else:
-                        assert module.weight.dtype == torch.uint8, (
-                            f"Module {name} should be uint8 but is {module.weight.dtype}"
-                        )
+        for name, module in model.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                if any(fp32_name in name for fp32_name in fp32_modules):
+                    assert module.weight.dtype == torch.float32, (
+                        f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    )
+                else:
+                    assert module.weight.dtype == torch.uint8, (
+                        f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    )
 
-            inputs = self.get_dummy_inputs()
-            _ = model(**inputs)
-        finally:
-            if original_fp32_modules is not None:
-                self.model_class._keep_in_fp32_modules = original_fp32_modules
+        inputs = self.get_dummy_inputs()
+        _ = model(**inputs)
 
     def test_bnb_modules_to_not_convert(self):
         """Test that modules_to_not_convert parameter works correctly."""