Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit cc1d60c

Browse files
Add tests to vector quantizer layer (#97)
* num_channels and num_residual_channels are not Sequence[int] for added flexibility. * Add tests Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com> * Add test to check codebook update (#97) Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com> Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com> Co-authored-by: Petru-Daniel Tudosiu <petru.daniel@tudosiu.com>
1 parent 10acbd0 commit cc1d60c

1 file changed

Lines changed: 100 additions & 0 deletions

File tree

tests/test_vector_quantizer.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
16+
from generative.networks.layers import EMAQuantizer, VectorQuantizer
17+
18+
19+
class TestEMA(unittest.TestCase):
20+
def test_ema_shape(self):
21+
layer = EMAQuantizer(
22+
spatial_dims=2,
23+
num_embeddings=16,
24+
embedding_dim=8,
25+
)
26+
input_shape = (1, 8, 8, 8)
27+
x = torch.randn(input_shape)
28+
layer = layer.train()
29+
outputs = layer(x)
30+
self.assertEqual(outputs[0].shape, input_shape)
31+
self.assertEqual(outputs[2].shape, (1, 8, 8))
32+
33+
layer = layer.eval()
34+
outputs = layer(x)
35+
self.assertEqual(outputs[0].shape, input_shape)
36+
self.assertEqual(outputs[2].shape, (1, 8, 8))
37+
38+
def test_ema_quantize(self):
39+
layer = EMAQuantizer(
40+
spatial_dims=2,
41+
num_embeddings=16,
42+
embedding_dim=8,
43+
)
44+
input_shape = (1, 8, 8, 8)
45+
x = torch.randn(input_shape)
46+
outputs = layer.quantize(x)
47+
self.assertEqual(outputs[0].shape, (64, 8)) # (HxW, C)
48+
self.assertEqual(outputs[1].shape, (64, 16)) # (HxW, E)
49+
self.assertEqual(outputs[2].shape, (1, 8, 8)) # (1, H, W)
50+
51+
def test_ema(self):
52+
layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0)
53+
original_weight_0 = layer.embedding.weight[0].clone()
54+
original_weight_1 = layer.embedding.weight[1].clone()
55+
x_0 = original_weight_0
56+
x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
57+
x_0 = x_0.repeat(1, 1, 1, 2) + 0.001
58+
59+
x_1 = original_weight_1
60+
x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
61+
x_1 = x_1.repeat(1, 1, 1, 2)
62+
63+
x = torch.cat([x_0, x_1], dim=0)
64+
layer = layer.train()
65+
_ = layer(x)
66+
67+
self.assertTrue(all(layer.embedding.weight[0] != original_weight_0))
68+
self.assertTrue(all(layer.embedding.weight[1] == original_weight_1))
69+
70+
71+
class TestVectorQuantizer(unittest.TestCase):
72+
def test_vector_quantizer_shape(self):
73+
layer = VectorQuantizer(
74+
EMAQuantizer(
75+
spatial_dims=2,
76+
num_embeddings=16,
77+
embedding_dim=8,
78+
)
79+
)
80+
input_shape = (1, 8, 8, 8)
81+
x = torch.randn(input_shape)
82+
outputs = layer(x)
83+
self.assertEqual(outputs[1].shape, input_shape)
84+
85+
def test_vector_quantizer_quantize(self):
86+
layer = VectorQuantizer(
87+
EMAQuantizer(
88+
spatial_dims=2,
89+
num_embeddings=16,
90+
embedding_dim=8,
91+
)
92+
)
93+
input_shape = (1, 8, 8, 8)
94+
x = torch.randn(input_shape)
95+
outputs = layer.quantize(x)
96+
self.assertEqual(outputs.shape, (1, 8, 8))
97+
98+
99+
if __name__ == "__main__":
100+
unittest.main()

0 commit comments

Comments
 (0)