2121from tests .test_utils import SkipIfBeforePyTorchVersion , dict_product , skip_if_quick , test_script_save
2222
2323TEST_CASE_Vit = [
24- (
25- [
26- {
27- "in_channels" : params ["in_channels" ],
28- "img_size" : (params ["img_size" ],) * params ["nd" ],
29- "patch_size" : (params ["patch_size" ],) * params ["nd" ],
30- "hidden_size" : params ["hidden_size" ],
31- "mlp_dim" : params ["mlp_dim" ],
32- "num_layers" : params ["num_layers" ],
33- "num_heads" : params ["num_heads" ],
34- "proj_type" : params ["proj_type" ],
35- "classification" : params ["classification" ],
36- "num_classes" : params ["num_classes" ],
37- "dropout_rate" : params ["dropout_rate" ],
38- ** ({"spatial_dims" : 2 } if params ["nd" ] == 2 else {}),
39- ** ({"post_activation" : False } if params ["nd" ] == 2 and params ["classification" ] else {}),
40- },
41- (2 , params ["in_channels" ], * ([params ["img_size" ]] * params ["nd" ])),
42- (
43- (2 , params ["num_classes" ])
44- if params ["classification" ]
45- else (2 , (params ["img_size" ] // params ["patch_size" ]) ** params ["nd" ], params ["hidden_size" ])
46- ),
47- ]
48- )
24+ ([
25+ {
26+ ** {k : v for k , v in params .items () if k not in ["nd" ]},
27+ ** ({"spatial_dims" : 2 } if params ["nd" ] == 2 else {}),
28+ ** ({"post_activation" : False } if params ["nd" ] == 2 and params ["classification" ] else {}),
29+ },
30+ (2 , params ["in_channels" ], * ([params ["img_size" ]] * params ["nd" ])),
31+ (
32+ (2 , params ["num_classes" ])
33+ if params ["classification" ]
34+ else (2 , (params ["img_size" ] // params ["patch_size" ]) ** params ["nd" ], params ["hidden_size" ])
35+ ),
36+ ])
4937 for params in dict_product (
5038 dropout_rate = [0.6 ],
5139 in_channels = [4 ],
@@ -72,15 +60,13 @@ def test_shape(self, input_param, input_shape, expected_shape):
7260 result , _ = net (torch .randn (input_shape ))
7361 self .assertEqual (result .shape , expected_shape )
7462
75- @parameterized .expand (
76- [
77- (1 , (128 , 128 , 128 ), (16 , 16 , 16 ), 128 , 3072 , 12 , 12 , "conv" , False , 5.0 ),
78- (1 , (32 , 32 , 32 ), (64 , 64 , 64 ), 512 , 3072 , 12 , 8 , "perceptron" , False , 0.3 ),
79- (1 , (96 , 96 , 96 ), (8 , 8 , 8 ), 512 , 3072 , 12 , 14 , "conv" , False , 0.3 ),
80- (1 , (97 , 97 , 97 ), (4 , 4 , 4 ), 768 , 3072 , 12 , 8 , "perceptron" , True , 0.3 ),
81- (4 , (96 , 96 , 96 ), (16 , 16 , 16 ), 768 , 3072 , 12 , 12 , "perc" , False , 0.3 ),
82- ]
83- )
63+ @parameterized .expand ([
64+ (1 , (128 , 128 , 128 ), (16 , 16 , 16 ), 128 , 3072 , 12 , 12 , "conv" , False , 5.0 ),
65+ (1 , (32 , 32 , 32 ), (64 , 64 , 64 ), 512 , 3072 , 12 , 8 , "perceptron" , False , 0.3 ),
66+ (1 , (96 , 96 , 96 ), (8 , 8 , 8 ), 512 , 3072 , 12 , 14 , "conv" , False , 0.3 ),
67+ (1 , (97 , 97 , 97 ), (4 , 4 , 4 ), 768 , 3072 , 12 , 8 , "perceptron" , True , 0.3 ),
68+ (4 , (96 , 96 , 96 ), (16 , 16 , 16 ), 768 , 3072 , 12 , 12 , "perc" , False , 0.3 ),
69+ ])
8470 def test_ill_arg (
8571 self ,
8672 in_channels ,
0 commit comments