Skip to content

Commit 5c6171d

Browse files
committed
Improve other tests
1 parent 69f5133 commit 5c6171d

7 files changed

Lines changed: 70 additions & 129 deletions

File tree

tests/networks/blocks/test_segresnet_block.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,7 @@
2222

2323
TEST_CASE_RESBLOCK = [
2424
[
25-
{
26-
"spatial_dims": params["spatial_dims"],
27-
"in_channels": params["in_channels"],
28-
"kernel_size": params["kernel_size"],
29-
"norm": params["norm"],
30-
},
25+
params,
3126
(2, params["in_channels"], *([16] * params["spatial_dims"])),
3227
(2, params["in_channels"], *([16] * params["spatial_dims"])),
3328
]
@@ -41,7 +36,6 @@
4136

4237

4338
class TestResBlock(unittest.TestCase):
44-
4539
@parameterized.expand(TEST_CASE_RESBLOCK)
4640
def test_shape(self, input_param, input_shape, expected_shape):
4741
net = ResBlock(**input_param)

tests/networks/blocks/test_transformerblock.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,7 @@
2626
einops, has_einops = optional_import("einops")
2727
TEST_CASE_TRANSFORMERBLOCK = [
2828
[
29-
{
30-
"hidden_size": params["hidden_size"],
31-
"num_heads": params["num_heads"],
32-
"mlp_dim": params["mlp_dim"],
33-
"dropout_rate": params["dropout_rate"],
34-
"with_cross_attention": params["with_cross_attention"],
35-
},
29+
params,
3630
(2, 512, params["hidden_size"]),
3731
(2, 512, params["hidden_size"]),
3832
]

tests/networks/blocks/test_unetr_block.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,15 @@ def _get_out_size(params):
3232
return int((in_size + 2 * padding - kernel_size) / stride) + 1
3333

3434

35-
TEST_CASE_UNETR_BASIC_BLOCK = [
36-
[
37-
{
38-
"spatial_dims": params["spatial_dims"],
39-
"in_channels": 16,
40-
"out_channels": 16,
41-
"kernel_size": params["kernel_size"],
42-
"norm_name": params["norm_name"],
43-
"stride": params["stride"],
44-
},
45-
(1, 16, *([params["in_size"]] * params["spatial_dims"])),
46-
(1, 16, *([_get_out_size(params)] * params["spatial_dims"])),
47-
]
48-
for params in dict_product(
49-
spatial_dims=range(1, 4),
50-
kernel_size=[1, 3],
51-
stride=[2],
52-
norm_name=[("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"],
53-
in_size=[15, 16],
54-
)
55-
]
35+
norm_names = [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]
36+
param_dicts = dict_product(spatial_dims=range(1, 4), kernel_size=[1, 3], stride=[2], norm_name=norm_names, in_size=[15, 16])
37+
TEST_CASE_UNETR_BASIC_BLOCK = []
38+
for params in param_dicts:
39+
input_param = {**{k: v for k, v in params.items() if k != "in_size"}, "in_channels": 16, "out_channels": 16}
40+
input_shape = (1, 16, *([params["in_size"]] * params["spatial_dims"]))
41+
expected_shape = (1, 16, *([_get_out_size(params)] * params["spatial_dims"]))
42+
TEST_CASE_UNETR_BASIC_BLOCK.append([input_param, input_shape, expected_shape])
43+
5644

5745
TEST_UP_BLOCK = [
5846
[

tests/networks/nets/test_segresnet.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@
4848
TEST_CASE_SEGRESNET_2 = [
4949
[
5050
{
51-
"spatial_dims": params["spatial_dims"],
52-
"init_filters": params["init_filters"],
53-
"out_channels": params["out_channels"],
54-
"upsample_mode": params["upsample_mode"],
51+
**params,
5552
},
5653
(2, 1, *([16] * params["spatial_dims"])),
5754
(2, params["out_channels"], *([16] * params["spatial_dims"])),
@@ -64,13 +61,9 @@
6461
TEST_CASE_SEGRESNET_VAE = [
6562
[
6663
{
67-
"spatial_dims": params["spatial_dims"],
68-
"init_filters": params["init_filters"],
69-
"out_channels": params["out_channels"],
70-
"upsample_mode": params["upsample_mode"],
64+
**params,
7165
"act": ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
7266
"input_image_size": ([16] * params["spatial_dims"]),
73-
"vae_estimate_std": params["vae_estimate_std"],
7467
},
7568
(2, 1, *([16] * params["spatial_dims"])),
7669
(2, params["out_channels"], *([16] * params["spatial_dims"])),

tests/networks/nets/test_vit.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,31 +21,19 @@
2121
from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_script_save
2222

2323
TEST_CASE_Vit = [
24-
(
25-
[
26-
{
27-
"in_channels": params["in_channels"],
28-
"img_size": (params["img_size"],) * params["nd"],
29-
"patch_size": (params["patch_size"],) * params["nd"],
30-
"hidden_size": params["hidden_size"],
31-
"mlp_dim": params["mlp_dim"],
32-
"num_layers": params["num_layers"],
33-
"num_heads": params["num_heads"],
34-
"proj_type": params["proj_type"],
35-
"classification": params["classification"],
36-
"num_classes": params["num_classes"],
37-
"dropout_rate": params["dropout_rate"],
38-
**({"spatial_dims": 2} if params["nd"] == 2 else {}),
39-
**({"post_activation": False} if params["nd"] == 2 and params["classification"] else {}),
40-
},
41-
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
42-
(
43-
(2, params["num_classes"])
44-
if params["classification"]
45-
else (2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"])
46-
),
47-
]
48-
)
24+
([
25+
{
26+
**{k: v for k, v in params.items() if k not in ["nd"]},
27+
**({"spatial_dims": 2} if params["nd"] == 2 else {}),
28+
**({"post_activation": False} if params["nd"] == 2 and params["classification"] else {}),
29+
},
30+
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
31+
(
32+
(2, params["num_classes"])
33+
if params["classification"]
34+
else (2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"])
35+
),
36+
])
4937
for params in dict_product(
5038
dropout_rate=[0.6],
5139
in_channels=[4],
@@ -72,15 +60,13 @@ def test_shape(self, input_param, input_shape, expected_shape):
7260
result, _ = net(torch.randn(input_shape))
7361
self.assertEqual(result.shape, expected_shape)
7462

75-
@parameterized.expand(
76-
[
77-
(1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0),
78-
(1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3),
79-
(1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3),
80-
(1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3),
81-
(4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3),
82-
]
83-
)
63+
@parameterized.expand([
64+
(1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0),
65+
(1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3),
66+
(1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3),
67+
(1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3),
68+
(4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3),
69+
])
8470
def test_ill_arg(
8571
self,
8672
in_channels,

tests/transforms/spatial/test_spatial_resampled.py

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,26 @@
4242
]
4343

4444
for dst, expct in zip(destinations_3d, expected_3d):
45-
TESTS.extend(
45+
TESTS.extend([
4646
[
47-
[
48-
np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data
49-
*params["device"],
50-
dst,
51-
{
52-
"dst_keys": "dst_affine",
53-
"dtype": params["dtype"],
54-
"align_corners": params["align"],
55-
"mode": params["interp_mode"],
56-
"padding_mode": params["padding_mode"],
57-
},
58-
expct,
59-
]
60-
for params in dict_product(
61-
device=TEST_DEVICES,
62-
align=[True, False],
63-
dtype=[torch.float32, torch.float64],
64-
interp_mode=["nearest", "bilinear"],
65-
padding_mode=["zeros", "border", "reflection"],
66-
)
47+
np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data
48+
*params["device"],
49+
dst,
50+
{
51+
**{k: v for k, v in params.items() if k not in ["device", "interp_mode"]},
52+
"dst_keys": "dst_affine",
53+
"padding_mode": "zeros",
54+
},
55+
expct,
6756
]
68-
)
57+
for params in dict_product(
58+
device=TEST_DEVICES,
59+
align_corners=[True, False],
60+
dtype=[torch.float32, torch.float64],
61+
interp_mode=["nearest", "bilinear"],
62+
padding_mode=["zeros", "border", "reflection"],
63+
)
64+
])
6965

7066
destinations_2d = [
7167
torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second
@@ -75,29 +71,25 @@
7571
expected_2d = [torch.tensor([[[2.0, 1.0], [4.0, 3.0]]]), torch.tensor([[[3.0, 4.0], [1.0, 2.0]]])]
7672

7773
for dst, expct in zip(destinations_2d, expected_2d):
78-
TESTS.extend(
74+
TESTS += [
7975
[
80-
[
81-
np.arange(4).reshape((1, 2, 2)) + 1.0, # data
82-
*params["device"],
83-
dst,
84-
{
85-
"dst_keys": "dst_affine",
86-
"dtype": params["dtype"],
87-
"align_corners": params["align"],
88-
"mode": params["interp_mode"],
89-
"padding_mode": "zeros",
90-
},
91-
expct,
92-
]
93-
for params in dict_product(
94-
device=TEST_DEVICES,
95-
align=[False, True],
96-
dtype=[torch.float32, torch.float64],
97-
interp_mode=["nearest", "bilinear"],
98-
)
76+
np.arange(4).reshape((1, 2, 2)) + 1.0, # data
77+
*params.pop("device"),
78+
dst,
79+
{
80+
**{k: v for k, v in params.items() if k not in ["align", "interp_mode"]},
81+
"dst_keys": "dst_affine",
82+
"padding_mode": "zeros",
83+
},
84+
expct,
9985
]
100-
)
86+
for params in dict_product(
87+
device=TEST_DEVICES,
88+
align=[False, True],
89+
dtype=[torch.float32, torch.float64],
90+
interp_mode=["nearest", "bilinear"],
91+
)
92+
]
10193

10294

10395
class TestSpatialResample(unittest.TestCase):

tests/transforms/utility/test_splitdimd.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,7 @@
2323
from monai.transforms.utility.dictionary import SplitDimd
2424
from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product, make_nifti_image, make_rand_affine
2525

26-
TESTS = [
27-
(params["keepdim"], params["p"], params["update_meta"], params["list_output"])
28-
for params in dict_product(
29-
p=TEST_NDARRAYS, keepdim=[True, False], update_meta=[True, False], list_output=[True, False]
30-
)
31-
]
26+
TESTS = list(dict_product(keepdim=[True, False], p=TEST_NDARRAYS, update_meta=[True, False], list_output=[True, False]))
3227

3328

3429
class TestSplitDimd(unittest.TestCase):
@@ -44,9 +39,8 @@ def setUpClass(cls) -> None:
4439
cls.data = loader(data)
4540

4641
@parameterized.expand(TESTS)
47-
def test_correct(self, keepdim, im_type, update_meta, list_output):
42+
def test_correct(self, keepdim, _, update_meta, list_output):
4843
data = deepcopy(self.data)
49-
data["i"] = im_type(data["i"])
5044
arr = data["i"]
5145
for dim in range(arr.ndim):
5246
out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data)

0 commit comments

Comments
 (0)