Skip to content

Commit 628bd37

Browse files
committed
feat: add support for simple and tagged unions
1 parent 911550e commit 628bd37

9 files changed

Lines changed: 367 additions & 0 deletions

File tree

libdestruct/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from libdestruct.common.enum import enum, enum_of
2020
from libdestruct.common.ptr.ptr import ptr
2121
from libdestruct.common.struct import ptr_to, ptr_to_self, struct
22+
from libdestruct.common.union import tagged_union, union, union_of
2223
from libdestruct.common.utils import size_of
2324
from libdestruct.libdestruct import inflate, inflater
2425

@@ -48,4 +49,7 @@
4849
"ptr_to_self",
4950
"size_of",
5051
"struct",
52+
"tagged_union",
53+
"union",
54+
"union_of",
5155
]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from libdestruct.common.union.tagged_union_of import tagged_union
8+
from libdestruct.common.union.union import union
9+
from libdestruct.common.union.union_of import union_of
10+
11+
__all__ = ["tagged_union", "union", "union_of"]
12+
13+
import libdestruct.common.union.tagged_union_field_inflater
14+
import libdestruct.common.union.union_field_inflater # noqa: F401
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING
10+
11+
from libdestruct.common.field import Field
12+
from libdestruct.common.union.union import union
13+
from libdestruct.common.utils import size_of
14+
15+
if TYPE_CHECKING: # pragma: no cover
16+
from libdestruct.backing.resolver import Resolver
17+
from libdestruct.common.obj import obj
18+
19+
20+
class TaggedUnionField(Field):
21+
"""A field descriptor for a tagged union in a struct."""
22+
23+
base_type: type[obj] = union
24+
25+
def __init__(self: TaggedUnionField, discriminator: str, variants: dict[object, type]) -> None:
26+
"""Initialize the tagged union field.
27+
28+
Args:
29+
discriminator: The name of the struct field used as the discriminator.
30+
variants: A mapping from discriminator values to variant types.
31+
"""
32+
self.discriminator = discriminator
33+
self.variants = variants
34+
35+
def inflate(self: TaggedUnionField, resolver: Resolver | None) -> union:
36+
"""Inflate the field (used during size computation with resolver=None).
37+
38+
Args:
39+
resolver: The backing resolver (None during size computation).
40+
"""
41+
return union(resolver, None, self.get_size())
42+
43+
def get_size(self: TaggedUnionField) -> int:
44+
"""Return the size of the union (max of all variant sizes)."""
45+
return max(size_of(variant) for variant in self.variants.values())
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING
10+
11+
from libdestruct.common.type_registry import TypeRegistry
12+
from libdestruct.common.union.tagged_union_field import TaggedUnionField
13+
from libdestruct.common.union.union import union
14+
15+
if TYPE_CHECKING: # pragma: no cover
16+
from collections.abc import Callable
17+
18+
from libdestruct.backing.resolver import Resolver
19+
from libdestruct.common.obj import obj
20+
21+
registry = TypeRegistry()
22+
23+
24+
def tagged_union_field_inflater(
25+
field: TaggedUnionField,
26+
_: type[obj],
27+
owner: tuple[obj, type[obj]] | None,
28+
) -> Callable[[Resolver], obj]:
29+
"""Return the inflater for a tagged union field.
30+
31+
During size computation (owner[0] is None), returns field.inflate which
32+
creates a stub with the correct max size.
33+
34+
During actual inflation, returns a closure that reads the discriminator
35+
from the struct instance and inflates the matching variant.
36+
"""
37+
if owner is None or owner[0] is None:
38+
return field.inflate
39+
40+
struct_instance = owner[0]
41+
42+
def inflate_with_discriminator(resolver: Resolver) -> union:
43+
members = object.__getattribute__(struct_instance, "_members")
44+
disc_value = members[field.discriminator].value
45+
46+
if disc_value not in field.variants:
47+
raise ValueError(
48+
f"Unknown discriminator value {disc_value!r} for field '{field.discriminator}'. "
49+
f"Valid values: {list(field.variants.keys())}"
50+
)
51+
52+
variant_type = field.variants[disc_value]
53+
variant_inflater = registry.inflater_for(variant_type)
54+
variant = variant_inflater(resolver)
55+
56+
return union(resolver, variant, field.get_size())
57+
58+
return inflate_with_discriminator
59+
60+
61+
registry.register_instance_handler(TaggedUnionField, tagged_union_field_inflater)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from __future__ import annotations
8+
9+
from libdestruct.common.union.tagged_union_field import TaggedUnionField
10+
11+
12+
def tagged_union(discriminator: str, variants: dict[object, type]) -> TaggedUnionField:
13+
"""Create a tagged union field descriptor.
14+
15+
Args:
16+
discriminator: The name of the struct field used to select the active variant.
17+
variants: A mapping from discriminator values to variant types.
18+
19+
Returns:
20+
A TaggedUnionField for use as a struct field default value.
21+
"""
22+
return TaggedUnionField(discriminator, variants)

libdestruct/common/union/union.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING
10+
11+
from libdestruct.common.obj import obj
12+
13+
if TYPE_CHECKING: # pragma: no cover
14+
from libdestruct.backing.resolver import Resolver
15+
16+
17+
class union(obj):
18+
"""A union value, supporting both tagged (single active variant) and plain (all variants overlaid) modes."""
19+
20+
_variant: obj | None
21+
"""The single active variant (tagged union mode)."""
22+
23+
_variants: dict[str, obj]
24+
"""Named variants (plain union mode)."""
25+
26+
_frozen_bytes: bytes | None
27+
"""The frozen bytes of the full union region."""
28+
29+
def __init__(
30+
self: union,
31+
resolver: Resolver | None,
32+
variant: obj | None,
33+
max_size: int,
34+
variants: dict[str, obj] | None = None,
35+
) -> None:
36+
"""Initialize the union.
37+
38+
Args:
39+
resolver: The backing resolver.
40+
variant: The single active variant (tagged union mode, None for plain unions).
41+
max_size: The size of the union (max of all variant sizes).
42+
variants: Named variants dict (plain union mode, None for tagged unions).
43+
"""
44+
super().__init__(resolver)
45+
self._variant = variant
46+
self._variants = variants or {}
47+
self.size = max_size
48+
self._frozen_bytes = None
49+
50+
@property
51+
def variant(self: union) -> obj | None:
52+
"""Return the active variant object (tagged union mode)."""
53+
return self._variant
54+
55+
def get(self: union) -> object:
56+
"""Return the value of the active variant."""
57+
if self._variant is not None:
58+
return self._variant.get()
59+
if self._variants:
60+
return {name: v.get() for name, v in self._variants.items()}
61+
return None
62+
63+
def _set(self: union, value: object) -> None:
64+
"""Set the value of the active variant."""
65+
if self._variant is None:
66+
raise RuntimeError("Cannot set the value of a union without an active variant.")
67+
self._variant._set(value)
68+
69+
def to_bytes(self: union) -> bytes:
70+
"""Return the full union-sized region as bytes."""
71+
if self._frozen_bytes is not None:
72+
return self._frozen_bytes
73+
if self.resolver is None:
74+
return b"\x00" * self.size
75+
return self.resolver.resolve(self.size, 0)
76+
77+
def freeze(self: union) -> None:
78+
"""Freeze the union and all its variants."""
79+
if self.resolver is not None:
80+
self._frozen_bytes = self.resolver.resolve(self.size, 0)
81+
else:
82+
self._frozen_bytes = b"\x00" * self.size
83+
if self._variant is not None:
84+
self._variant.freeze()
85+
for v in self._variants.values():
86+
v.freeze()
87+
super().freeze()
88+
89+
def to_str(self: union, indent: int = 0) -> str:
90+
"""Return a string representation of the union."""
91+
if self._variant is not None:
92+
return self._variant.to_str(indent)
93+
if self._variants:
94+
members = ", ".join(self._variants)
95+
return f"union({members})"
96+
return "union(empty)"
97+
98+
def __getattr__(self: union, name: str) -> object:
99+
"""Delegate attribute access to named variants or the active variant."""
100+
variants = object.__getattribute__(self, "_variants")
101+
if name in variants:
102+
return variants[name]
103+
variant = object.__getattribute__(self, "_variant")
104+
if variant is not None:
105+
return getattr(variant, name)
106+
raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING
10+
11+
from libdestruct.common.field import Field
12+
from libdestruct.common.union.union import union
13+
from libdestruct.common.utils import size_of
14+
15+
if TYPE_CHECKING: # pragma: no cover
16+
from libdestruct.backing.resolver import Resolver
17+
from libdestruct.common.obj import obj
18+
19+
20+
class UnionField(Field):
21+
"""A field descriptor for a plain (non-discriminated) union in a struct."""
22+
23+
base_type: type[obj] = union
24+
25+
def __init__(self: UnionField, variants: dict[str, type]) -> None:
26+
"""Initialize the union field.
27+
28+
Args:
29+
variants: A mapping from variant names to their types.
30+
"""
31+
self.variants = variants
32+
33+
def inflate(self: UnionField, resolver: Resolver | None) -> union:
34+
"""Inflate the field (used during size computation with resolver=None).
35+
36+
Args:
37+
resolver: The backing resolver (None during size computation).
38+
"""
39+
return union(resolver, None, self.get_size())
40+
41+
def get_size(self: UnionField) -> int:
42+
"""Return the size of the union (max of all variant sizes)."""
43+
return max(size_of(variant) for variant in self.variants.values())
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING
10+
11+
from libdestruct.common.type_registry import TypeRegistry
12+
from libdestruct.common.union.union import union
13+
from libdestruct.common.union.union_field import UnionField
14+
15+
if TYPE_CHECKING: # pragma: no cover
16+
from collections.abc import Callable
17+
18+
from libdestruct.backing.resolver import Resolver
19+
from libdestruct.common.obj import obj
20+
21+
registry = TypeRegistry()
22+
23+
24+
def union_field_inflater(
25+
field: UnionField,
26+
_: type[obj],
27+
owner: tuple[obj, type[obj]] | None,
28+
) -> Callable[[Resolver], obj]:
29+
"""Return the inflater for a plain union field.
30+
31+
During size computation (owner[0] is None), returns field.inflate which
32+
creates a stub with the correct max size.
33+
34+
During actual inflation, returns a closure that inflates all variants
35+
at the same memory location.
36+
"""
37+
if owner is None or owner[0] is None:
38+
return field.inflate
39+
40+
def inflate_all_variants(resolver: Resolver) -> union:
41+
variants = {}
42+
for name, variant_type in field.variants.items():
43+
variant_inflater = registry.inflater_for(variant_type)
44+
variants[name] = variant_inflater(resolver)
45+
46+
return union(resolver, None, field.get_size(), variants=variants)
47+
48+
return inflate_all_variants
49+
50+
51+
registry.register_instance_handler(UnionField, union_field_inflater)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#
2+
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
3+
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
4+
# Licensed under the MIT license. See LICENSE file in the project root for details.
5+
#
6+
7+
from __future__ import annotations
8+
9+
from libdestruct.common.union.union_field import UnionField
10+
11+
12+
def union_of(variants: dict[str, type]) -> UnionField:
13+
"""Create a plain union field descriptor.
14+
15+
Args:
16+
variants: A mapping from variant names to their types.
17+
18+
Returns:
19+
A UnionField for use as a struct field default value.
20+
"""
21+
return UnionField(variants)

0 commit comments

Comments
 (0)