Skip to content

Commit abc47ae

Browse files
committed
Move function type construction to get_local_defns.
1 parent 9f27f3f commit abc47ae

2 files changed

Lines changed: 53 additions & 34 deletions

File tree

typemap/type_eval/_apply_generic.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
if typing.TYPE_CHECKING:
1717
from typing import Any, Mapping
18+
from typemap.typing import GenericCallable, Overloaded
1819

1920

2021
@dataclasses.dataclass(frozen=True)
@@ -324,10 +325,18 @@ def get_local_defns(
324325
) -> tuple[
325326
dict[str, Any],
326327
dict[
327-
str, types.FunctionType | classmethod | staticmethod | WrappedOverloads
328+
str,
329+
type[
330+
typing.Callable
331+
| classmethod
332+
| staticmethod
333+
| GenericCallable
334+
| Overloaded
335+
],
328336
],
329337
]:
330-
from typemap.typing import GenericCallable
338+
from typemap.typing import GenericCallable, Overloaded
339+
from ._eval_operators import _function_type
331340

332341
annos: dict[str, Any] = {}
333342
dct: dict[str, Any] = {}
@@ -339,6 +348,9 @@ def get_local_defns(
339348
if name in EXCLUDED_ATTRIBUTES:
340349
continue
341350

351+
if orig is typing._no_init_or_replace_init: # type: ignore[attr-defined]
352+
continue
353+
342354
stuff = inspect.unwrap(orig)
343355

344356
if isinstance(stuff, types.FunctionType):
@@ -405,7 +417,36 @@ def lam(*vs):
405417
elif orig.__class__ is staticmethod:
406418
local_fn = staticmethod(local_fn)
407419

408-
dct[name] = local_fn
420+
if isinstance(
421+
local_fn,
422+
(
423+
types.FunctionType,
424+
types.MethodType,
425+
staticmethod,
426+
classmethod,
427+
),
428+
):
429+
dct[name] = _function_type(
430+
local_fn, receiver_type=boxed.alias_type()
431+
)
432+
433+
elif isinstance(local_fn, WrappedOverloads):
434+
overload_types: typing.Sequence[
435+
type[
436+
typing.Callable
437+
| classmethod
438+
| staticmethod
439+
| GenericCallable
440+
]
441+
] = [
442+
_function_type(
443+
_eval_typing.eval_typing(of),
444+
receiver_type=boxed.alias_type(),
445+
)
446+
for of in local_fn.functions
447+
]
448+
449+
dct[name] = Overloaded[*overload_types] # type: ignore[valid-type]
409450

410451
return annos, dct
411452

typemap/type_eval/_eval_operators.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
Member,
4242
Members,
4343
NewProtocol,
44-
Overloaded,
4544
Param,
4645
RaiseError,
4746
Slice,
@@ -154,35 +153,12 @@ def get_annotated_method_hints(cls, *, ctx):
154153

155154
_, dct = _apply_generic.get_local_defns(abox)
156155
for name, attr in dct.items():
157-
if isinstance(
156+
hints[name] = (
158157
attr,
159-
(
160-
types.FunctionType,
161-
types.MethodType,
162-
staticmethod,
163-
classmethod,
164-
),
165-
):
166-
if attr is typing._no_init_or_replace_init:
167-
continue
168-
169-
hints[name] = (
170-
_function_type(attr, receiver_type=acls),
171-
("ClassVar",),
172-
object,
173-
acls,
174-
)
175-
elif isinstance(attr, _apply_generic.WrappedOverloads):
176-
overloads = [
177-
_function_type(_eval_types(of, ctx), receiver_type=acls)
178-
for of in attr.functions
179-
]
180-
hints[name] = (
181-
Overloaded[*overloads],
182-
("ClassVar",),
183-
object,
184-
acls,
185-
)
158+
("ClassVar",),
159+
object,
160+
acls,
161+
)
186162

187163
return hints
188164

@@ -763,7 +739,9 @@ def _ann(x):
763739
return f
764740

765741

766-
def _function_type(func, *, receiver_type):
742+
def _function_type(
743+
func, *, receiver_type
744+
) -> type[typing.Callable | classmethod | staticmethod | GenericCallable]:
767745
root = inspect.unwrap(func)
768746
sig = inspect.signature(root)
769747
f = _function_type_from_sig(sig, func, receiver_type=receiver_type)
@@ -772,7 +750,7 @@ def _function_type(func, *, receiver_type):
772750
# Must store a lambda that performs type variable substitution
773751
type_params = root.__type_params__
774752
callable_lambda = _create_generic_callable_lambda(f, type_params)
775-
f = GenericCallable[tuple[*type_params], callable_lambda]
753+
f = GenericCallable[tuple[*type_params], callable_lambda] # type: ignore[misc,valid-type]
776754
return f
777755

778756

0 commit comments

Comments
 (0)