Skip to content

Commit b39a993

Browse files
committed
fix: solve incorrect union alignment and enum backing type assumptions
1 parent 2a849b6 commit b39a993

7 files changed

Lines changed: 137 additions & 7 deletions

File tree

libdestruct/common/enum/enum_field_inflater.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _subscripted_enum_handler(
4242
return None
4343
python_enum = args[0]
4444
backing_type = args[1] if len(args) > 1 else c_int
45-
field = IntEnumField(python_enum, size=backing_type.size)
45+
field = IntEnumField(python_enum, backing_type=backing_type)
4646
return field.inflate
4747

4848

libdestruct/common/enum/int_enum_field.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,28 @@
2121
class IntEnumField(EnumField):
2222
"""A generator for an enum of integers."""
2323

24-
def __init__(self: IntEnumField, enum: type[IntEnum], lenient: bool = True, size: int = 4) -> None:
24+
def __init__(
25+
self: IntEnumField,
26+
enum: type[IntEnum],
27+
lenient: bool = True,
28+
size: int = 4,
29+
backing_type: type | None = None,
30+
) -> None:
2531
"""Initialize the field.
2632
2733
Args:
2834
enum: The enum class.
2935
lenient: Whether the conversion is lenient or not.
30-
size: The size of the field in bytes.
36+
size: The size of the field in bytes (used when backing_type is not provided).
37+
backing_type: The explicit backing type to use. If provided, overrides size.
3138
"""
3239
self.enum = enum
3340
self.lenient = lenient
3441

42+
if backing_type is not None:
43+
self.backing_type = backing_type
44+
return
45+
3546
if not 0 < size <= 8:
3647
raise ValueError("The size of the field must be between 1 and 8 bytes.")
3748

libdestruct/common/struct/struct_impl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def _inflate_struct_attributes(
9393
)
9494

9595
if explicit_offset is not None:
96+
current_offset += bf_tracker.flush()
9697
if explicit_offset < current_offset:
9798
raise ValueError("Offset must be greater than the current size.")
9899
current_offset = explicit_offset
@@ -191,6 +192,7 @@ def compute_own_size(cls: type[struct_impl], reference_type: type) -> None:
191192

192193
has_explicit_offset = explicit_offset is not None
193194
if has_explicit_offset:
195+
size += bf_tracker.flush()
194196
if explicit_offset < size:
195197
raise ValueError("Offset must be greater than the current size.")
196198
size = explicit_offset

libdestruct/common/union/tagged_union_field.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from libdestruct.common.field import Field
1212
from libdestruct.common.union.union import union
13-
from libdestruct.common.utils import size_of
13+
from libdestruct.common.utils import alignment_of, size_of
1414

1515
if TYPE_CHECKING: # pragma: no cover
1616
from libdestruct.backing.resolver import Resolver
@@ -43,3 +43,7 @@ def inflate(self: TaggedUnionField, resolver: Resolver | None) -> union:
4343
def get_size(self: TaggedUnionField) -> int:
4444
"""Return the size of the union (max of all variant sizes)."""
4545
return max(size_of(variant) for variant in self.variants.values())
46+
47+
def get_alignment(self: TaggedUnionField) -> int:
48+
"""Return the alignment of the union (max of all variant alignments)."""
49+
return max(alignment_of(variant) for variant in self.variants.values())

libdestruct/common/union/union_field.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from libdestruct.common.field import Field
1212
from libdestruct.common.union.union import union
13-
from libdestruct.common.utils import size_of
13+
from libdestruct.common.utils import alignment_of, size_of
1414

1515
if TYPE_CHECKING: # pragma: no cover
1616
from libdestruct.backing.resolver import Resolver
@@ -41,3 +41,7 @@ def inflate(self: UnionField, resolver: Resolver | None) -> union:
4141
def get_size(self: UnionField) -> int:
4242
"""Return the size of the union (max of all variant sizes)."""
4343
return max(size_of(variant) for variant in self.variants.values())
44+
45+
def get_alignment(self: UnionField) -> int:
46+
"""Return the alignment of the union (max of all variant alignments)."""
47+
return max(alignment_of(variant) for variant in self.variants.values())

libdestruct/common/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,17 @@ def alignment_of(item: obj | type[obj]) -> int:
8181
if isinstance(item, type) and "alignment" in item.__dict__ and isinstance(item.__dict__["alignment"], int):
8282
return item.__dict__["alignment"]
8383

84-
# Field descriptors — for array fields, alignment comes from the element type
84+
# Field descriptors — use get_alignment if available, else derive from element type or size
8585
if isinstance(item, Field):
86+
if hasattr(item, "get_alignment"):
87+
return item.get_alignment()
8688
if hasattr(item, "item"):
8789
return alignment_of(item.item)
8890
return _alignment_from_size(item.get_size())
8991
if is_field_bound_method(item):
9092
field = item.__self__
93+
if hasattr(field, "get_alignment"):
94+
return field.get_alignment()
9195
if hasattr(field, "item"):
9296
return alignment_of(field.item)
9397
return _alignment_from_size(field.get_size())

test/scripts/struct_unit_test.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
from typing import Annotated
1111

12-
from libdestruct import array, c_int, c_long, c_short, c_uint, inflater, offset, struct, ptr, ptr_to_self, array_of, enum, enum_of
12+
from libdestruct import array, c_int, c_long, c_short, c_uint, c_ushort, inflater, offset, struct, ptr, ptr_to_self, array_of, enum, enum_of, size_of, bitfield_of
13+
from libdestruct.common.union import union, union_of, tagged_union
1314

1415

1516
class StructMemberCollisionTest(unittest.TestCase):
@@ -419,5 +420,109 @@ class s_t(struct):
419420
self.assertIs(s.__eq__(42), NotImplemented)
420421

421422

423+
class BitfieldExplicitOffsetTest(unittest.TestCase):
424+
"""Explicit offset after bitfields must flush the pending bitfield group first."""
425+
426+
def test_offset_after_bitfield_size(self):
427+
"""Struct size must be correct when offset() follows bitfield fields."""
428+
class s_t(struct):
429+
a: c_uint = bitfield_of(c_uint, 1)
430+
b: c_int = offset(8)
431+
432+
# a is a 1-bit bitfield in a 4-byte c_uint group at offset 0.
433+
# b is at explicit offset 8 with size 4.
434+
# Total size: 8 + 4 = 12
435+
self.assertEqual(size_of(s_t), 12)
436+
437+
def test_offset_after_bitfield_read(self):
438+
"""Values must be read correctly when offset() follows bitfield fields."""
439+
import struct as pystruct
440+
441+
class s_t(struct):
442+
a: c_uint = bitfield_of(c_uint, 1)
443+
b: c_int = offset(8)
444+
445+
memory = bytearray(12)
446+
memory[0:4] = pystruct.pack("<I", 1) # a = 1
447+
memory[8:12] = pystruct.pack("<i", 42) # b = 42
448+
449+
s = s_t.from_bytes(memory)
450+
self.assertEqual(s.a.value, 1)
451+
self.assertEqual(s.b.value, 42)
452+
453+
454+
class UnionAlignmentTest(unittest.TestCase):
455+
"""Union fields in aligned structs must use member-derived alignment."""
456+
457+
def test_plain_union_alignment_non_power_of_two_size(self):
458+
"""Union of a 12-byte packed struct and c_long: size=12 but alignment must be 8 (from c_long)."""
459+
class triple_t(struct):
460+
a: c_int
461+
b: c_int
462+
c: c_int
463+
464+
# triple_t is a packed 12-byte struct with alignment 1
465+
# c_long is 8 bytes with alignment 8
466+
# union size = 12, but alignment should be 8 (max member alignment)
467+
class s_t(struct):
468+
_aligned_ = True
469+
tag: c_short # 2 bytes, align 2
470+
data: union = union_of({"t": triple_t, "l": c_long})
471+
472+
# tag at offset 0 (2 bytes)
473+
# data alignment = 8 → data at offset 8
474+
# data size = 12
475+
# struct max alignment = 8 → total = _align_offset(20, 8) = 24
476+
self.assertEqual(size_of(s_t), 24)
477+
478+
def test_tagged_union_alignment_non_power_of_two_size(self):
479+
"""Tagged union of a 12-byte packed struct and c_long: alignment must be 8."""
480+
class triple_t(struct):
481+
a: c_int
482+
b: c_int
483+
c: c_int
484+
485+
class s_t(struct):
486+
_aligned_ = True
487+
tag: c_int # 4 bytes, align 4
488+
data: union = tagged_union("tag", {0: triple_t, 1: c_long})
489+
490+
# tag at offset 0 (4 bytes)
491+
# data alignment = 8 → data at offset 8
492+
# data size = 12
493+
# struct max alignment = 8 → total = _align_offset(20, 8) = 24
494+
self.assertEqual(size_of(s_t), 24)
495+
496+
497+
class SubscriptedEnumUnsignedTest(unittest.TestCase):
498+
"""enum[E, unsigned_backing] must preserve signedness."""
499+
500+
def test_enum_unsigned_backing(self):
501+
"""enum[E, c_ushort] should correctly decode values exceeding signed range."""
502+
from enum import IntEnum
503+
504+
class E(IntEnum):
505+
MAX_VAL = 0xFFFF
506+
507+
class s_t(struct):
508+
val: enum[E, c_ushort]
509+
510+
memory = (0xFFFF).to_bytes(2, "little")
511+
s = s_t.from_bytes(memory)
512+
self.assertEqual(s.val.value, E.MAX_VAL)
513+
514+
def test_enum_unsigned_size(self):
515+
"""enum[E, c_ushort] struct should be 2 bytes."""
516+
from enum import IntEnum
517+
518+
class E(IntEnum):
519+
A = 0
520+
521+
class s_t(struct):
522+
val: enum[E, c_ushort]
523+
524+
self.assertEqual(size_of(s_t), 2)
525+
526+
422527
if __name__ == "__main__":
423528
unittest.main()

0 commit comments

Comments
 (0)