Skip to content

Commit 2762cd5

Browse files
committed
Remove repetitive 'var': param['var'] patterns
1 parent 5c6171d commit 2762cd5

11 files changed

Lines changed: 79 additions & 148 deletions

tests/networks/blocks/test_crossattention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
TEST_CASE_CABLOCK = [
3030
[
3131
{
32-
**{param: value for param, value in params.items() if param != "rel_pos_embedding_val"},
32+
**{k: v for k, v in params.items() if k not in ["rel_pos_embedding_val"]},
3333
"rel_pos_embedding": params["rel_pos_embedding_val"] if not params["use_flash_attention"] else None,
3434
},
3535
(2, 512, params["hidden_size"]),

tests/networks/blocks/test_transformerblock.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,7 @@
2525

2626
einops, has_einops = optional_import("einops")
2727
TEST_CASE_TRANSFORMERBLOCK = [
28-
[
29-
params,
30-
(2, 512, params["hidden_size"]),
31-
(2, 512, params["hidden_size"]),
32-
]
28+
[params, (2, 512, params["hidden_size"]), (2, 512, params["hidden_size"])]
3329
for params in dict_product(
3430
dropout_rate=np.linspace(0, 1, 4),
3531
hidden_size=[360, 480, 600, 768],

tests/networks/blocks/test_unetr_block.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ def _get_out_size(params):
3333

3434

3535
norm_names = [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]
36-
param_dicts = dict_product(spatial_dims=range(1, 4), kernel_size=[1, 3], stride=[2], norm_name=norm_names, in_size=[15, 16])
36+
param_dicts = dict_product(
37+
spatial_dims=range(1, 4), kernel_size=[1, 3], stride=[2], norm_name=norm_names, in_size=[15, 16]
38+
)
3739
TEST_CASE_UNETR_BASIC_BLOCK = []
3840
for params in param_dicts:
3941
input_param = {**{k: v for k, v in params.items() if k != "in_size"}, "in_channels": 16, "out_channels": 16}
@@ -45,12 +47,7 @@ def _get_out_size(params):
4547
TEST_UP_BLOCK = [
4648
[
4749
{
48-
"spatial_dims": params["spatial_dims"],
49-
"in_channels": params["in_channels"],
50-
"out_channels": params["out_channels"],
51-
"kernel_size": params["kernel_size"],
52-
"norm_name": params["norm_name"],
53-
"res_block": params["res_block"],
50+
**{k: v for k, v in params.items() if k not in ["in_size", "stride", "upsample_kernel_size"]},
5451
"upsample_kernel_size": params["stride"],
5552
},
5653
(1, params["in_channels"], *([params["in_size"]] * params["spatial_dims"])),
@@ -82,31 +79,20 @@ def _get_out_size(params):
8279
in_size_scalar=[15, 16],
8380
num_layer=[0, 2],
8481
):
85-
spatial_dims_val = params["spatial_dims"]
86-
in_size_val = params["in_size_scalar"]
87-
upsample_kernel_size_val = params["upsample_kernel_size"]
88-
num_layer_val = params["num_layer"]
89-
90-
in_size_tmp = in_size_val
82+
in_size_tmp = params["in_size_scalar"]
9183
out_size = 0 # Initialize out_size
92-
for _ in range(num_layer_val + 1):
93-
out_size = in_size_tmp * upsample_kernel_size_val
84+
for _ in range(params["num_layer"] + 1):
85+
out_size = in_size_tmp * params["upsample_kernel_size"]
9486
in_size_tmp = out_size
9587

9688
test_case = [
9789
{
98-
"spatial_dims": spatial_dims_val,
90+
**{k: v for k, v in params.items() if k != "in_size_scalar"},
9991
"in_channels": in_channels,
10092
"out_channels": out_channels,
101-
"num_layer": num_layer_val,
102-
"kernel_size": params["kernel_size"],
103-
"norm_name": params["norm_name"],
104-
"stride": params["stride"],
105-
"res_block": params["res_block"],
106-
"upsample_kernel_size": upsample_kernel_size_val,
10793
},
108-
(1, in_channels, *([in_size_val] * spatial_dims_val)),
109-
(1, out_channels, *([out_size] * spatial_dims_val)),
94+
(1, in_channels, *([params["in_size_scalar"]] * params["spatial_dims"])),
95+
(1, out_channels, *([out_size] * params["spatial_dims"])),
11096
]
11197
TEST_PRUP_BLOCK.append(test_case)
11298

tests/networks/nets/test_mednext.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,7 @@
2424

2525
TEST_CASE_MEDNEXT = [
2626
[
27-
{
28-
"spatial_dims": params["spatial_dims"],
29-
"init_filters": params["init_filters"],
30-
"deep_supervision": params["deep_supervision"],
31-
"use_residual_connection": params["do_res"],
32-
},
27+
{**{k: v for k, v in params.items() if k != "do_res"}, "use_residual_connection": params["do_res"]},
3328
(2, 1, *([16] * params["spatial_dims"])),
3429
(2, 2, *([16] * params["spatial_dims"])),
3530
]
@@ -38,16 +33,7 @@
3833
)
3934
]
4035
TEST_CASE_MEDNEXT_2 = [
41-
[
42-
{
43-
"spatial_dims": params["spatial_dims"],
44-
"init_filters": params["init_filters"],
45-
"out_channels": params["out_channels"],
46-
"deep_supervision": params["deep_supervision"],
47-
},
48-
(2, 1, *([16] * params["spatial_dims"])),
49-
(2, params["out_channels"], *([16] * params["spatial_dims"])),
50-
]
36+
[params, (2, 1, *([16] * params["spatial_dims"])), (2, params["out_channels"], *([16] * params["spatial_dims"]))]
5137
for params in dict_product(
5238
spatial_dims=range(2, 4), out_channels=[1, 2], deep_supervision=[False, True], init_filters=[8]
5339
)

tests/networks/nets/test_segresnet.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,7 @@
2525

2626
TEST_CASE_SEGRESNET = [
2727
[
28-
{
29-
"spatial_dims": params["spatial_dims"],
30-
"init_filters": params["init_filters"],
31-
"dropout_prob": params["dropout_prob"],
32-
"norm": params["norm"],
33-
"upsample_mode": params["upsample_mode"],
34-
"use_conv_final": False,
35-
},
28+
{**params, "use_conv_final": False},
3629
(2, 1, *([16] * params["spatial_dims"])),
3730
(2, params["init_filters"], *([16] * params["spatial_dims"])),
3831
]
@@ -47,9 +40,7 @@
4740

4841
TEST_CASE_SEGRESNET_2 = [
4942
[
50-
{
51-
**params,
52-
},
43+
{**params},
5344
(2, 1, *([16] * params["spatial_dims"])),
5445
(2, params["out_channels"], *([16] * params["spatial_dims"])),
5546
]

tests/networks/nets/test_segresnet_ds.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,7 @@
2323
device = "cuda" if torch.cuda.is_available() else "cpu"
2424

2525
TEST_CASE_SEGRESNET_DS = [
26-
[
27-
{
28-
"spatial_dims": params["spatial_dims"],
29-
"init_filters": params["init_filters"],
30-
"act": params["act"],
31-
"norm": params["norm"],
32-
"upsample_mode": params["upsample_mode"],
33-
},
34-
(2, 1, *([16] * params["spatial_dims"])),
35-
(2, 2, *([16] * params["spatial_dims"])),
36-
]
26+
[params, (2, 1, *([16] * params["spatial_dims"])), (2, 2, *([16] * params["spatial_dims"]))]
3727
for params in dict_product(
3828
spatial_dims=range(2, 4),
3929
init_filters=[8, 16],
@@ -45,12 +35,7 @@
4535

4636
TEST_CASE_SEGRESNET_DS2 = [
4737
[
48-
{
49-
"spatial_dims": params["spatial_dims"],
50-
"init_filters": 8,
51-
"out_channels": params["out_channels"],
52-
"dsdepth": params["dsdepth"],
53-
},
38+
{**params, "init_filters": 8},
5439
(2, 1, *([16] * params["spatial_dims"])),
5540
(2, params["out_channels"], *([16] * params["spatial_dims"])),
5641
]

tests/networks/nets/test_swin_unetr.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,28 +41,22 @@
4141
TEST_CASE_SWIN_UNETR = [
4242
[
4343
{
44+
**{k: v for k, v in params.items() if k != "img_size"},
4445
"spatial_dims": len(params["img_size"]),
45-
"in_channels": params["in_channels"],
46-
"out_channels": params["out_channels"],
47-
"feature_size": params["feature_size"],
48-
"depths": params["depth"],
49-
"norm_name": params["norm_name"],
50-
"attn_drop_rate": params["attn_drop_rate"],
5146
"downsample": test_merging_mode[i % len(test_merging_mode)],
52-
"use_checkpoint": params["use_checkpoint"],
5347
},
5448
(2, params["in_channels"], *params["img_size"]),
5549
(2, params["out_channels"], *params["img_size"]),
5650
]
5751
for i, params in enumerate(
5852
dict_product(
5953
attn_drop_rate=[0.4],
60-
in_channels=[1],
61-
depth=[[2, 1, 1, 1], [1, 2, 1, 1]],
62-
out_channels=[2],
63-
img_size=((64, 32, 192), (96, 32)),
54+
depths=[[2, 1, 1, 1], [1, 2, 1, 1]],
6455
feature_size=[12],
56+
img_size=((64, 32, 192), (96, 32)),
57+
in_channels=[1],
6558
norm_name=["instance"],
59+
out_channels=[2],
6660
use_checkpoint=checkpoint_vals,
6761
)
6862
)

tests/networks/nets/test_transchex.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,21 @@
2323
TEST_CASE_TRANSCHEX = [
2424
[
2525
{
26-
"in_channels": params["in_channels"],
26+
**{k: v for k, v in params.items() if k != "img_size"},
2727
"img_size": (params["img_size"],) * 2,
2828
"patch_size": (params["patch_size"],) * 2,
29-
"num_vision_layers": params["num_vision_layers"],
30-
"num_mixed_layers": params["num_mixed_layers"],
31-
"num_language_layers": params["num_language_layers"],
32-
"num_classes": params["num_classes"],
33-
"drop_out": params["drop_out"],
3429
},
3530
(2, params["num_classes"]),
3631
]
3732
for params in dict_product(
3833
drop_out=[0.4],
39-
in_channels=[3],
4034
img_size=[224],
41-
patch_size=[16, 32],
35+
in_channels=[3],
36+
num_classes=[8],
4237
num_language_layers=[2],
43-
num_vision_layers=[4],
4438
num_mixed_layers=[3],
45-
num_classes=[8],
39+
num_vision_layers=[4],
40+
patch_size=[16, 32],
4641
)
4742
]
4843

tests/networks/nets/test_unetr.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,35 +23,27 @@
2323
TEST_CASE_UNETR = [
2424
[
2525
{
26-
"in_channels": params["in_channels"],
27-
"out_channels": params["out_channels"],
28-
"img_size": (params["img_size"],) * params["nd"],
29-
"hidden_size": params["hidden_size"],
30-
"feature_size": params["feature_size"],
31-
"norm_name": params["norm_name"],
32-
"mlp_dim": params["mlp_dim"],
33-
"num_heads": params["num_heads"],
34-
"proj_type": params["proj_type"],
35-
"dropout_rate": params["dropout_rate"],
26+
**{k: v for k, v in params.items() if k not in ["img_size", "nd"]},
3627
"conv_block": True,
3728
"res_block": False,
29+
"img_size": (params["img_size"],) * params["nd"],
3830
**({"spatial_dims": 2} if params["nd"] == 2 else {}),
3931
},
4032
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
4133
(2, params["out_channels"], *([params["img_size"]] * params["nd"])),
4234
]
4335
for params in dict_product(
4436
dropout_rate=[0.4],
45-
in_channels=[1],
46-
out_channels=[2],
37+
feature_size=[16],
4738
hidden_size=[768],
4839
img_size=[96, 128],
49-
feature_size=[16],
50-
num_heads=[8],
40+
in_channels=[1],
5141
mlp_dim=[3072],
42+
nd=[2, 3],
5243
norm_name=["instance"],
44+
num_heads=[8],
45+
out_channels=[2],
5346
proj_type=["perceptron"],
54-
nd=[2, 3],
5547
)
5648
]
5749

tests/networks/nets/test_vit.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,21 @@
2121
from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_script_save
2222

2323
TEST_CASE_Vit = [
24-
([
25-
{
26-
**{k: v for k, v in params.items() if k not in ["nd"]},
27-
**({"spatial_dims": 2} if params["nd"] == 2 else {}),
28-
**({"post_activation": False} if params["nd"] == 2 and params["classification"] else {}),
29-
},
30-
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
31-
(
32-
(2, params["num_classes"])
33-
if params["classification"]
34-
else (2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"])
35-
),
36-
])
24+
(
25+
[
26+
{
27+
**{k: v for k, v in params.items() if k not in ["nd"]},
28+
**({"spatial_dims": 2} if params["nd"] == 2 else {}),
29+
**({"post_activation": False} if params["nd"] == 2 and params["classification"] else {}),
30+
},
31+
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
32+
(
33+
(2, params["num_classes"])
34+
if params["classification"]
35+
else (2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"])
36+
),
37+
]
38+
)
3739
for params in dict_product(
3840
dropout_rate=[0.6],
3941
in_channels=[4],
@@ -60,13 +62,15 @@ def test_shape(self, input_param, input_shape, expected_shape):
6062
result, _ = net(torch.randn(input_shape))
6163
self.assertEqual(result.shape, expected_shape)
6264

63-
@parameterized.expand([
64-
(1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0),
65-
(1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3),
66-
(1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3),
67-
(1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3),
68-
(4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3),
69-
])
65+
@parameterized.expand(
66+
[
67+
(1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0),
68+
(1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3),
69+
(1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3),
70+
(1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3),
71+
(4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3),
72+
]
73+
)
7074
def test_ill_arg(
7175
self,
7276
in_channels,

0 commit comments

Comments
 (0)