Skip to content

Commit 6fec515

Browse files
committed
Move function type construction to get_local_defns.
1 parent aaf48cb commit 6fec515

2 files changed

Lines changed: 53 additions & 34 deletions

File tree

typemap/type_eval/_apply_generic.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
if typing.TYPE_CHECKING:
1616
from typing import Any, Mapping
17+
from typemap.typing import GenericCallable, Overloaded
1718

1819

1920
@dataclasses.dataclass(frozen=True)
@@ -300,10 +301,18 @@ def get_local_defns(
300301
) -> tuple[
301302
dict[str, Any],
302303
dict[
303-
str, types.FunctionType | classmethod | staticmethod | WrappedOverloads
304+
str,
305+
type[
306+
typing.Callable
307+
| classmethod
308+
| staticmethod
309+
| GenericCallable
310+
| Overloaded
311+
],
304312
],
305313
]:
306-
from typemap.typing import GenericCallable
314+
from typemap.typing import GenericCallable, Overloaded
315+
from ._eval_operators import _function_type
307316

308317
annos: dict[str, Any] = {}
309318
dct: dict[str, Any] = {}
@@ -315,6 +324,9 @@ def get_local_defns(
315324
if name in EXCLUDED_ATTRIBUTES:
316325
continue
317326

327+
if orig is typing._no_init_or_replace_init: # type: ignore[attr-defined]
328+
continue
329+
318330
stuff = inspect.unwrap(orig)
319331

320332
if isinstance(stuff, types.FunctionType):
@@ -381,7 +393,36 @@ def lam(*vs):
381393
elif orig.__class__ is staticmethod:
382394
local_fn = staticmethod(local_fn)
383395

384-
dct[name] = local_fn
396+
if isinstance(
397+
local_fn,
398+
(
399+
types.FunctionType,
400+
types.MethodType,
401+
staticmethod,
402+
classmethod,
403+
),
404+
):
405+
dct[name] = _function_type(
406+
local_fn, receiver_type=boxed.alias_type()
407+
)
408+
409+
elif isinstance(local_fn, WrappedOverloads):
410+
overload_types: typing.Sequence[
411+
type[
412+
typing.Callable
413+
| classmethod
414+
| staticmethod
415+
| GenericCallable
416+
]
417+
] = [
418+
_function_type(
419+
_eval_typing.eval_typing(of),
420+
receiver_type=boxed.alias_type(),
421+
)
422+
for of in local_fn.functions
423+
]
424+
425+
dct[name] = Overloaded[*overload_types] # type: ignore[valid-type]
385426

386427
return annos, dct
387428

typemap/type_eval/_eval_operators.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
Member,
4242
Members,
4343
NewProtocol,
44-
Overloaded,
4544
Param,
4645
RaiseError,
4746
Slice,
@@ -154,35 +153,12 @@ def get_annotated_method_hints(cls, *, ctx):
154153

155154
_, dct = _apply_generic.get_local_defns(abox)
156155
for name, attr in dct.items():
157-
if isinstance(
156+
hints[name] = (
158157
attr,
159-
(
160-
types.FunctionType,
161-
types.MethodType,
162-
staticmethod,
163-
classmethod,
164-
),
165-
):
166-
if attr is typing._no_init_or_replace_init:
167-
continue
168-
169-
hints[name] = (
170-
_function_type(attr, receiver_type=acls),
171-
("ClassVar",),
172-
object,
173-
acls,
174-
)
175-
elif isinstance(attr, _apply_generic.WrappedOverloads):
176-
overloads = [
177-
_function_type(_eval_types(of, ctx), receiver_type=acls)
178-
for of in attr.functions
179-
]
180-
hints[name] = (
181-
Overloaded[*overloads],
182-
("ClassVar",),
183-
object,
184-
acls,
185-
)
158+
("ClassVar",),
159+
object,
160+
acls,
161+
)
186162

187163
return hints
188164

@@ -763,7 +739,9 @@ def _ann(x):
763739
return f
764740

765741

766-
def _function_type(func, *, receiver_type):
742+
def _function_type(
743+
func, *, receiver_type
744+
) -> type[typing.Callable | classmethod | staticmethod | GenericCallable]:
767745
root = inspect.unwrap(func)
768746
sig = inspect.signature(root)
769747
f = _function_type_from_sig(sig, func, receiver_type=receiver_type)
@@ -772,7 +750,7 @@ def _function_type(func, *, receiver_type):
772750
# Must store a lambda that performs type variable substitution
773751
type_params = root.__type_params__
774752
callable_lambda = _create_generic_callable_lambda(f, type_params)
775-
f = GenericCallable[tuple[*type_params], callable_lambda]
753+
f = GenericCallable[tuple[*type_params], callable_lambda] # type: ignore[misc,valid-type]
776754
return f
777755

778756

0 commit comments

Comments
 (0)