2121from tests .test_utils import SkipIfBeforePyTorchVersion , dict_product , skip_if_quick , test_script_save
2222
2323TEST_CASE_Vit = [
24- ([
25- {
26- ** {k : v for k , v in params .items () if k not in ["nd" ]},
27- ** ({"spatial_dims" : 2 } if params ["nd" ] == 2 else {}),
28- ** ({"post_activation" : False } if params ["nd" ] == 2 and params ["classification" ] else {}),
29- },
30- (2 , params ["in_channels" ], * ([params ["img_size" ]] * params ["nd" ])),
31- (
32- (2 , params ["num_classes" ])
33- if params ["classification" ]
34- else (2 , (params ["img_size" ] // params ["patch_size" ]) ** params ["nd" ], params ["hidden_size" ])
35- ),
36- ])
24+ (
25+ [
26+ {
27+ ** {k : v for k , v in params .items () if k not in ["nd" ]},
28+ ** ({"spatial_dims" : 2 } if params ["nd" ] == 2 else {}),
29+ ** ({"post_activation" : False } if params ["nd" ] == 2 and params ["classification" ] else {}),
30+ },
31+ (2 , params ["in_channels" ], * ([params ["img_size" ]] * params ["nd" ])),
32+ (
33+ (2 , params ["num_classes" ])
34+ if params ["classification" ]
35+ else (2 , (params ["img_size" ] // params ["patch_size" ]) ** params ["nd" ], params ["hidden_size" ])
36+ ),
37+ ]
38+ )
3739 for params in dict_product (
3840 dropout_rate = [0.6 ],
3941 in_channels = [4 ],
@@ -60,13 +62,15 @@ def test_shape(self, input_param, input_shape, expected_shape):
6062 result , _ = net (torch .randn (input_shape ))
6163 self .assertEqual (result .shape , expected_shape )
6264
63- @parameterized .expand ([
64- (1 , (128 , 128 , 128 ), (16 , 16 , 16 ), 128 , 3072 , 12 , 12 , "conv" , False , 5.0 ),
65- (1 , (32 , 32 , 32 ), (64 , 64 , 64 ), 512 , 3072 , 12 , 8 , "perceptron" , False , 0.3 ),
66- (1 , (96 , 96 , 96 ), (8 , 8 , 8 ), 512 , 3072 , 12 , 14 , "conv" , False , 0.3 ),
67- (1 , (97 , 97 , 97 ), (4 , 4 , 4 ), 768 , 3072 , 12 , 8 , "perceptron" , True , 0.3 ),
68- (4 , (96 , 96 , 96 ), (16 , 16 , 16 ), 768 , 3072 , 12 , 12 , "perc" , False , 0.3 ),
69- ])
65+ @parameterized .expand (
66+ [
67+ (1 , (128 , 128 , 128 ), (16 , 16 , 16 ), 128 , 3072 , 12 , 12 , "conv" , False , 5.0 ),
68+ (1 , (32 , 32 , 32 ), (64 , 64 , 64 ), 512 , 3072 , 12 , 8 , "perceptron" , False , 0.3 ),
69+ (1 , (96 , 96 , 96 ), (8 , 8 , 8 ), 512 , 3072 , 12 , 14 , "conv" , False , 0.3 ),
70+ (1 , (97 , 97 , 97 ), (4 , 4 , 4 ), 768 , 3072 , 12 , 8 , "perceptron" , True , 0.3 ),
71+ (4 , (96 , 96 , 96 ), (16 , 16 , 16 ), 768 , 3072 , 12 , 12 , "perc" , False , 0.3 ),
72+ ]
73+ )
7074 def test_ill_arg (
7175 self ,
7276 in_channels ,
0 commit comments