Skip to content

Commit b7699ce

Browse files
authored
Support for recursive unions (#13)
things like `type IntTree = int | list[IntTree]`
1 parent 8621197 commit b7699ce

4 files changed

Lines changed: 358 additions & 37 deletions

File tree

tests/test_type_eval.py

Lines changed: 196 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
1+
import pytest
2+
3+
import collections
14
import textwrap
25
import unittest
3-
from typing import Any, Callable, Generic, List, Literal, Never, Tuple, TypeVar
6+
from typing import (
7+
Any,
8+
Callable,
9+
Generic,
10+
List,
11+
Literal,
12+
Never,
13+
Tuple,
14+
TypeVar,
15+
Union,
16+
)
417

518
from typemap.type_eval import eval_typing
619
from typemap.typing import (
@@ -101,6 +114,15 @@ class F[bool]:
101114
""")
102115

103116

117+
type UnlabeledTree = list[UnlabeledTree]
118+
type IntTree = int | list[IntTree]
119+
type GenericTree[T] = T | list[GenericTree[T]]
120+
type XNode[X, Y] = X | list[YNode[X, Y]]
121+
type YNode[X, Y] = Y | list[XNode[X, Y]]
122+
type XYTree[X, Y] = XNode[X, Y] | YNode[X, Y]
123+
type NestedTree = str | list[NestedTree] | list[IntTree]
124+
125+
104126
class TA:
105127
x: int
106128
y: list[float]
@@ -167,11 +189,127 @@ def test_type_strings_6():
167189
assert d == Literal["bcd"]
168190

169191

170-
def test_type_asdf():
192+
def _is_generic_permutation(t1, t2):
193+
return t1.__origin__ == t2.__origin__ and collections.Counter(
194+
t1.__args__
195+
) == collections.Counter(t2.__args__)
196+
197+
198+
def test_type_from_union_01():
171199
d = eval_typing(FromUnion[int | bool])
172200
arg = FromUnion[int | str]
173201
d = eval_typing(arg)
174-
assert d == tuple[int, str] or d == tuple[str, int]
202+
assert _is_generic_permutation(d, tuple[int, str])
203+
204+
205+
def test_type_from_union_02():
206+
d = eval_typing(FromUnion[UnlabeledTree])
207+
assert _is_generic_permutation(d, tuple[list[UnlabeledTree]])
208+
209+
d = eval_typing(GetArg[d, tuple, 0])
210+
assert d == list[UnlabeledTree]
211+
d = eval_typing(GetArg[d, list, 0])
212+
assert d == list[UnlabeledTree]
213+
d = eval_typing(FromUnion[d])
214+
assert _is_generic_permutation(d, tuple[list[UnlabeledTree]])
215+
216+
217+
def test_type_from_union_03():
218+
d = eval_typing(FromUnion[IntTree])
219+
assert _is_generic_permutation(d, tuple[int, list[IntTree]])
220+
221+
d = eval_typing(GetArg[d, tuple, 1])
222+
assert d == list[IntTree]
223+
d = eval_typing(GetArg[d, list, 0])
224+
assert d == int | list[IntTree]
225+
d = eval_typing(FromUnion[d])
226+
assert _is_generic_permutation(d, tuple[int, list[IntTree]])
227+
228+
229+
def test_type_from_union_04():
230+
d = eval_typing(FromUnion[GenericTree[str]])
231+
assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]])
232+
233+
d = eval_typing(GetArg[d, tuple, 1])
234+
assert d == list[GenericTree[str]]
235+
d = eval_typing(GetArg[d, list, 0])
236+
assert d == str | list[GenericTree[str]]
237+
d = eval_typing(FromUnion[d])
238+
assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]])
239+
240+
241+
def test_type_from_union_05():
242+
d = eval_typing(FromUnion[XYTree[int, str]])
243+
assert _is_generic_permutation(
244+
d,
245+
tuple[XNode[int, str], YNode[int, str]],
246+
)
247+
248+
x = eval_typing(GetArg[d, tuple, 0])
249+
assert x == int | list[str | list[XNode[int, str]]]
250+
251+
x = eval_typing(FromUnion[x])
252+
assert _is_generic_permutation(
253+
x, tuple[int, list[str | list[XNode[int, str]]]]
254+
)
255+
x = eval_typing(GetArg[x, tuple, 1])
256+
assert x == list[str | list[XNode[int, str]]]
257+
x = eval_typing(GetArg[x, list, 0])
258+
assert x == str | list[XNode[int, str]]
259+
x = eval_typing(FromUnion[x])
260+
assert _is_generic_permutation(x, tuple[str, list[XNode[int, str]]])
261+
x = eval_typing(GetArg[x, tuple, 1])
262+
assert x == list[XNode[int, str]]
263+
x = eval_typing(GetArg[x, list, 0])
264+
assert x == int | list[str | list[XNode[int, str]]]
265+
266+
y = eval_typing(GetArg[d, tuple, 1])
267+
assert y == str | list[int | list[YNode[int, str]]]
268+
269+
y = eval_typing(FromUnion[y])
270+
assert _is_generic_permutation(
271+
y, tuple[str, list[int | list[YNode[int, str]]]]
272+
)
273+
y = eval_typing(GetArg[y, tuple, 1])
274+
assert y == list[int | list[YNode[int, str]]]
275+
y = eval_typing(GetArg[y, list, 0])
276+
assert y == int | list[YNode[int, str]]
277+
y = eval_typing(FromUnion[y])
278+
assert _is_generic_permutation(y, tuple[int, list[YNode[int, str]]])
279+
y = eval_typing(GetArg[y, tuple, 1])
280+
assert y == list[YNode[int, str]]
281+
y = eval_typing(GetArg[y, list, 0])
282+
assert y == str | list[int | list[YNode[int, str]]]
283+
284+
285+
def test_type_from_union_06():
286+
d = eval_typing(FromUnion[NestedTree])
287+
assert _is_generic_permutation(
288+
d,
289+
tuple[str, list[NestedTree], list[IntTree]],
290+
)
291+
292+
n = eval_typing(GetArg[d, tuple, 1])
293+
assert n == list[NestedTree]
294+
n = eval_typing(GetArg[n, list, 0])
295+
assert n == str | list[NestedTree] | list[IntTree]
296+
n = eval_typing(FromUnion[n])
297+
assert _is_generic_permutation(
298+
n, tuple[str, list[NestedTree], list[IntTree]]
299+
)
300+
301+
n = eval_typing(FromUnion[GetArg[GetArg[n, tuple, 1], list, 0]])
302+
assert _is_generic_permutation(
303+
n, tuple[str, list[NestedTree], list[IntTree]]
304+
)
305+
306+
i = eval_typing(GetArg[d, tuple, 2])
307+
assert i == list[IntTree]
308+
i = eval_typing(GetArg[i, list, 0])
309+
assert i == int | list[IntTree]
310+
311+
n = eval_typing(FromUnion[GetArg[GetArg[d, tuple, 2], list, 0]])
312+
assert _is_generic_permutation(n, tuple[int, list[IntTree]])
175313

176314

177315
def test_getarg_never():
@@ -330,6 +468,18 @@ def test_eval_getarg_list():
330468
assert arg == Never
331469

332470

471+
@pytest.mark.xfail(reason="Should this work?")
472+
def test_eval_getarg_union_01():
473+
arg = eval_typing(GetArg[int | str, Union, 0])
474+
assert arg is int
475+
476+
477+
@pytest.mark.xfail(reason="Should this work?")
478+
def test_eval_getarg_union_02():
479+
arg = eval_typing(GetArg[GenericTree[int], GenericTree, 0])
480+
assert arg is int
481+
482+
333483
def test_eval_getarg_custom_01():
334484
class A[T]:
335485
pass
@@ -394,6 +544,49 @@ class A(Generic[T]):
394544
assert eval_typing(GetArg[t, A, 1]) == Never
395545

396546

547+
@pytest.mark.xfail(reason="Should this work?")
548+
def test_eval_getarg_custom_05():
549+
A = TypeVar("A")
550+
551+
class ATree(Generic[A]):
552+
val: A | list[ATree[A]]
553+
554+
t = ATree[int]
555+
assert eval_typing(GetArg[t, ATree, 0]) is int
556+
assert eval_typing(GetArg[t, ATree, -1]) is int
557+
assert eval_typing(GetArg[t, ATree, 1]) == Never
558+
559+
t = ATree
560+
assert eval_typing(GetArg[t, ATree, 0]) is Any
561+
assert eval_typing(GetArg[t, ATree, -1]) is Any
562+
assert eval_typing(GetArg[t, ATree, 1]) == Never
563+
564+
565+
@pytest.mark.xfail(reason="Should this work?")
566+
def test_eval_getarg_custom_06():
567+
A = TypeVar("A")
568+
B = TypeVar("B")
569+
570+
class ANode(Generic[A, B]):
571+
val: A | list[BNode[A, B]]
572+
573+
class BNode(Generic[A, B]):
574+
val: B | list[ANode[A, B]]
575+
576+
class ABTree(Generic[A, B]):
577+
root: ANode[A, B] | BNode[A, B]
578+
579+
t = ABTree[int, str]
580+
assert eval_typing(GetArg[t, ABTree, 0]) is int
581+
assert eval_typing(GetArg[t, ABTree, 1]) is str
582+
assert eval_typing(GetArg[t, ABTree, 2]) == Never
583+
584+
t = ABTree
585+
assert eval_typing(GetArg[t, ABTree, 0]) is Any
586+
assert eval_typing(GetArg[t, ABTree, 1]) is Any
587+
assert eval_typing(GetArg[t, ABTree, 2]) == Never
588+
589+
397590
def test_uppercase_never():
398591
d = eval_typing(Uppercase[Never])
399592
assert d is Never

typemap/type_eval/_eval_call.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def _eval_call_with_type_vars(
112112
af.__code__, af.__globals__, af.__name__, None, af_args
113113
)
114114

115-
old_obj = ctx.current_alias
116-
ctx.current_alias = func
115+
old_obj = ctx.current_generic_alias
116+
ctx.current_generic_alias = func
117117
try:
118118
rr = ff(annotationlib.Format.VALUE)
119119
return _eval_typing.eval_typing(rr["return"])
120120
finally:
121-
ctx.current_alias = old_obj
121+
ctx.current_generic_alias = old_obj

typemap/type_eval/_eval_operators.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,10 @@ def _eval_Members(tp, *, ctx):
251251

252252
@type_eval.register_evaluator(FromUnion)
253253
def _eval_FromUnion(tp, *, ctx):
254-
return tuple[*_union_elems(tp, ctx)]
254+
if tp in ctx.known_recursive_types:
255+
return tuple[*_union_elems(ctx.known_recursive_types[tp], ctx)]
256+
else:
257+
return tuple[*_union_elems(tp, ctx)]
255258

256259

257260
##################################################################
@@ -485,12 +488,12 @@ def _eval_NewProtocol(*etyps: Member, ctx):
485488

486489
# If the type evaluation context
487490
ctx = type_eval._get_current_context()
488-
if ctx.current_alias:
489-
if isinstance(ctx.current_alias, types.GenericAlias):
490-
name = str(ctx.current_alias)
491+
if ctx.current_generic_alias:
492+
if isinstance(ctx.current_generic_alias, types.GenericAlias):
493+
name = str(ctx.current_generic_alias)
491494
else:
492-
name = f"{ctx.current_alias.__name__}[...]"
493-
module_name = ctx.current_alias.__module__
495+
name = f"{ctx.current_generic_alias.__name__}[...]"
496+
module_name = ctx.current_generic_alias.__module__
494497

495498
dct["__module__"] = module_name
496499

0 commit comments

Comments
 (0)