Skip to content

Commit 1eaa857

Browse files
authored
Fix resolution of generic types not capturing locally defined types (#21)
When you declare a local generic type, it isn't captured by the Boxed type. This causes an error when trying to lookup the arguments for annotate.
1 parent 3e87bcc commit 1eaa857

2 files changed

Lines changed: 79 additions & 6 deletions

File tree

tests/test_type_eval.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,27 @@ class A(Generic[T]):
544544
assert eval_typing(GetArg[t, A, 1]) == Never
545545

546546

547-
@pytest.mark.xfail(reason="Should this work?")
547+
TestTypeVar = TypeVar("TestTypeVar")
548+
549+
548550
def test_eval_getarg_custom_05():
551+
# TypeVar declared outside of scope of class
552+
class ATree(Generic[TestTypeVar]):
553+
val: list[ATree[TestTypeVar]]
554+
555+
t = ATree[int]
556+
assert eval_typing(GetArg[t, ATree, 0]) is int
557+
assert eval_typing(GetArg[t, ATree, -1]) is int
558+
assert eval_typing(GetArg[t, ATree, 1]) == Never
559+
560+
t = ATree
561+
assert eval_typing(GetArg[t, ATree, 0]) is Any
562+
assert eval_typing(GetArg[t, ATree, -1]) is Any
563+
assert eval_typing(GetArg[t, ATree, 1]) == Never
564+
565+
566+
def test_eval_getarg_custom_06():
567+
# TypeVar declared inside scope of class
549568
A = TypeVar("A")
550569

551570
class ATree(Generic[A]):
@@ -562,8 +581,8 @@ class ATree(Generic[A]):
562581
assert eval_typing(GetArg[t, ATree, 1]) == Never
563582

564583

565-
@pytest.mark.xfail(reason="Should this work?")
566-
def test_eval_getarg_custom_06():
584+
def test_eval_getarg_custom_07():
585+
# Doubly recursive generic types
567586
A = TypeVar("A")
568587
B = TypeVar("B")
569588

@@ -587,6 +606,30 @@ class ABTree(Generic[A, B]):
587606
assert eval_typing(GetArg[t, ABTree, 2]) == Never
588607

589608

609+
def test_eval_getarg_custom_08():
610+
# Generic class with generic methods
611+
T = TypeVar("T")
612+
613+
class Container(Generic[T]):
614+
data: list[T]
615+
616+
def get[T](self, index: int, default: T) -> int | T: ...
617+
def map[U](self, func: Callable[[int], U]) -> list[U]: ...
618+
def convert[T](self, func: Callable[[int], T]) -> Container2[T]: ...
619+
620+
class Container2[T]: ...
621+
622+
t = Container[int]
623+
assert eval_typing(GetArg[t, Container, 0]) is int
624+
assert eval_typing(GetArg[t, Container, -1]) is int
625+
assert eval_typing(GetArg[t, Container, 1]) == Never
626+
627+
t = Container
628+
assert eval_typing(GetArg[t, Container, 0]) is Any
629+
assert eval_typing(GetArg[t, Container, -1]) is Any
630+
assert eval_typing(GetArg[t, Container, 1]) == Never
631+
632+
590633
def test_uppercase_never():
591634
d = eval_typing(Uppercase[Never])
592635
assert d is Never

typemap/type_eval/_apply_generic.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ def __post_init__(self):
2727
object.__setattr__(
2828
self,
2929
"str_args",
30-
{str(k): v for k, v in self.args.items()},
30+
{
31+
# Use __name__ when available instead of str()
32+
# str(TypeVar('A')) returns '~A'
33+
(k.__name__ if hasattr(k, '__name__') else str(k)): v
34+
for k, v in self.args.items()
35+
},
3136
)
3237
object.__setattr__(
3338
self,
@@ -178,18 +183,38 @@ def make_func(
178183
return new_func
179184

180185

186+
def _get_closure_types(af: types.FunctionType) -> dict[str, type]:
187+
# Generate a fallback mapping of closure classes.
188+
# This is needed for locally defined generic types which reference
189+
# themselves in their type annotations.
190+
if not af.__closure__:
191+
return {}
192+
return {
193+
name: variable.cell_contents
194+
for name, variable in zip(
195+
af.__code__.co_freevars, af.__closure__, strict=True
196+
)
197+
if isinstance(variable.cell_contents, type)
198+
}
199+
200+
181201
def _get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
182202
annos: dict[str, Any] = {}
183203
dct: dict[str, Any] = {}
184204

185-
if af := getattr(boxed.cls, "__annotate__", None):
205+
if af := typing.cast(
206+
types.FunctionType, getattr(boxed.cls, "__annotate__", None)
207+
):
186208
# Class has annotations, let's resolve generic arguments
187209

210+
closure_types = _get_closure_types(af)
188211
args = tuple(
189212
types.CellType(
190213
boxed.cls.__dict__
191214
if name == "__classdict__"
192215
else boxed.str_args[name]
216+
if name in boxed.str_args
217+
else closure_types[name]
193218
)
194219
for name in af.__code__.co_freevars
195220
)
@@ -221,7 +246,9 @@ def _get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
221246
stuff = inspect.unwrap(orig)
222247

223248
if isinstance(stuff, types.FunctionType):
224-
if af := getattr(stuff, "__annotate__", None):
249+
if af := typing.cast(
250+
types.FunctionType, getattr(stuff, "__annotate__", None)
251+
):
225252
params = dict(
226253
zip(
227254
map(str, stuff.__type_params__),
@@ -230,13 +257,16 @@ def _get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
230257
)
231258
)
232259

260+
closure_types = _get_closure_types(af)
233261
args = tuple(
234262
types.CellType(
235263
boxed.cls.__dict__
236264
if name == "__classdict__"
237265
else params[name]
238266
if name in params
239267
else boxed.str_args[name]
268+
if name in boxed.str_args
269+
else closure_types[name]
240270
)
241271
for name in af.__code__.co_freevars
242272
)

0 commit comments

Comments
 (0)