Skip to content

Commit 25d7e5b

Browse files
committed
test: add checks for tagged and simple unions
1 parent 628bd37 commit 25d7e5b

1 file changed

Lines changed: 178 additions & 0 deletions

File tree

test/scripts/tagged_union_test.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
import struct as pystruct
8+
import unittest
9+
10+
from libdestruct import c_float, c_int, c_long, inflater, size_of, struct
11+
from libdestruct.common.union import tagged_union, union, union_of
12+
13+
14+
class TaggedUnionTest(unittest.TestCase):
15+
def test_basic_variant_selection(self):
16+
"""Union selects the correct variant based on discriminator value."""
17+
class msg_t(struct):
18+
type: c_int
19+
payload: union = tagged_union("type", {0: c_int, 1: c_float})
20+
21+
memory = pystruct.pack("<i", 0) + pystruct.pack("<i", 42)
22+
msg = msg_t.from_bytes(memory)
23+
self.assertEqual(msg.payload.value, 42)
24+
25+
def test_different_discriminator_value(self):
26+
"""Different discriminator value selects different variant."""
27+
class msg_t(struct):
28+
type: c_int
29+
payload: union = tagged_union("type", {0: c_int, 1: c_float})
30+
31+
memory = pystruct.pack("<i", 1) + pystruct.pack("<f", 3.14)
32+
msg = msg_t.from_bytes(memory)
33+
self.assertAlmostEqual(msg.payload.value, 3.14, places=2)
34+
35+
def test_union_size_is_max_variant(self):
36+
"""Union size equals the max of all variant sizes."""
37+
class msg_t(struct):
38+
type: c_int
39+
payload: union = tagged_union("type", {0: c_int, 1: c_long})
40+
41+
# c_int(4) + max(c_int(4), c_long(8)) = 12
42+
self.assertEqual(size_of(msg_t), 12)
43+
44+
def test_struct_variant_field_access(self):
45+
"""Struct variant fields are accessible through the union."""
46+
class point_t(struct):
47+
x: c_int
48+
y: c_int
49+
50+
class msg_t(struct):
51+
type: c_int
52+
payload: union = tagged_union("type", {0: c_int, 1: point_t})
53+
54+
# max(c_int=4, point_t=8) = 8; total = 4 + 8 = 12
55+
memory = pystruct.pack("<i", 1) + pystruct.pack("<ii", 10, 20)
56+
msg = msg_t.from_bytes(memory)
57+
self.assertEqual(msg.payload.x.value, 10)
58+
self.assertEqual(msg.payload.y.value, 20)
59+
60+
def test_unknown_discriminator_raises(self):
61+
"""Unknown discriminator value raises ValueError."""
62+
class msg_t(struct):
63+
type: c_int
64+
payload: union = tagged_union("type", {0: c_int})
65+
66+
memory = pystruct.pack("<i", 99) + b"\x00" * 4
67+
with self.assertRaises(ValueError):
68+
msg_t.from_bytes(memory)
69+
70+
def test_union_to_bytes_full_size(self):
71+
"""to_bytes returns the full union-sized region, not just the active variant."""
72+
class msg_t(struct):
73+
type: c_int
74+
payload: union = tagged_union("type", {0: c_int, 1: c_long})
75+
76+
# 4 bytes type + 4 bytes c_int + 4 bytes padding = 12 bytes total
77+
data = pystruct.pack("<i", 0) + pystruct.pack("<i", 42) + b"\xaa\xbb\xcc\xdd"
78+
msg = msg_t.from_bytes(data)
79+
self.assertEqual(len(msg.payload.to_bytes()), 8)
80+
81+
def test_union_write(self):
82+
"""Writing to the active variant updates memory."""
83+
class msg_t(struct):
84+
type: c_int
85+
payload: union = tagged_union("type", {0: c_int, 1: c_float})
86+
87+
memory = bytearray(8)
88+
lib = inflater(memory)
89+
msg = lib.inflate(msg_t, 0)
90+
msg.payload.value = 100
91+
self.assertEqual(msg.payload.value, 100)
92+
93+
def test_variant_property(self):
94+
"""variant property returns the active variant object."""
95+
class msg_t(struct):
96+
type: c_int
97+
payload: union = tagged_union("type", {0: c_int, 1: c_float})
98+
99+
memory = pystruct.pack("<i", 0) + pystruct.pack("<i", 42)
100+
msg = msg_t.from_bytes(memory)
101+
self.assertIsNotNone(msg.payload.variant)
102+
103+
def test_struct_total_size_with_union(self):
104+
"""Struct containing a union has correct total size."""
105+
class msg_t(struct):
106+
type: c_int
107+
payload: union = tagged_union("type", {0: c_int, 1: c_float})
108+
trailer: c_int
109+
110+
# c_int(4) + max(c_int(4), c_float(4)) + c_int(4) = 12
111+
self.assertEqual(size_of(msg_t), 12)
112+
113+
114+
class PlainUnionTest(unittest.TestCase):
115+
def test_plain_union_read_all_variants(self):
116+
"""Plain union inflates all variants at the same offset."""
117+
class packet_t(struct):
118+
data: union = union_of({"i": c_int, "f": c_float})
119+
120+
memory = pystruct.pack("<f", 3.14)
121+
pkt = packet_t.from_bytes(memory)
122+
self.assertAlmostEqual(pkt.data.f.value, 3.14, places=2)
123+
self.assertIsInstance(pkt.data.i.value, int)
124+
125+
def test_plain_union_size(self):
126+
"""Plain union size is max of all variant sizes."""
127+
class packet_t(struct):
128+
data: union = union_of({"i": c_int, "l": c_long})
129+
130+
self.assertEqual(size_of(packet_t), 8)
131+
132+
def test_plain_union_write(self):
133+
"""Writing to a variant of a plain union updates shared memory."""
134+
class packet_t(struct):
135+
data: union = union_of({"i": c_int, "f": c_float})
136+
137+
memory = bytearray(4)
138+
lib = inflater(memory)
139+
pkt = lib.inflate(packet_t, 0)
140+
pkt.data.i.value = 42
141+
self.assertEqual(pkt.data.i.value, 42)
142+
143+
def test_plain_union_struct_variant(self):
144+
"""Plain union can contain struct variants."""
145+
class point_t(struct):
146+
x: c_int
147+
y: c_int
148+
149+
class packet_t(struct):
150+
data: union = union_of({"raw": c_long, "point": point_t})
151+
152+
memory = pystruct.pack("<ii", 10, 20)
153+
pkt = packet_t.from_bytes(memory)
154+
self.assertEqual(pkt.data.point.x.value, 10)
155+
self.assertEqual(pkt.data.point.y.value, 20)
156+
157+
def test_plain_union_to_bytes(self):
158+
"""Plain union to_bytes returns max-size region."""
159+
class packet_t(struct):
160+
data: union = union_of({"i": c_int, "l": c_long})
161+
162+
memory = b"\x01\x02\x03\x04\x05\x06\x07\x08"
163+
pkt = packet_t.from_bytes(memory)
164+
self.assertEqual(len(pkt.data.to_bytes()), 8)
165+
self.assertEqual(pkt.data.to_bytes(), memory)
166+
167+
def test_plain_union_shared_memory(self):
168+
"""All variants of a plain union share the same memory."""
169+
class packet_t(struct):
170+
data: union = union_of({"i": c_int, "f": c_float})
171+
172+
memory = bytearray(4)
173+
lib = inflater(memory)
174+
pkt = lib.inflate(packet_t, 0)
175+
pkt.data.i.value = 42
176+
# Reading as float should reinterpret the same bytes
177+
expected_float = pystruct.unpack("<f", pystruct.pack("<i", 42))[0]
178+
self.assertAlmostEqual(pkt.data.f.value, expected_float)

0 commit comments

Comments
 (0)