Skip to content

Commit c699d0d

Browse files
committed
Use resolved function signature instead of making a new function.
1 parent 17e6bc6 commit c699d0d

1 file changed

Lines changed: 53 additions & 60 deletions

File tree

typemap/type_eval/_apply_generic.py

Lines changed: 53 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,10 @@ def get_annotations(
286286
return rr
287287

288288

289-
def _resolved_function_signature(func, args):
290-
"""Get the signature of a function with type hints resolved to arg values"""
289+
def _resolved_function_signature(
290+
func, args, definition_cls: type | None = None
291+
):
292+
"""Get the signature of a function with hints resolved to arg values."""
291293

292294
import typemap.typing as nt
293295

@@ -306,7 +308,7 @@ def _resolved_function_signature(func, args):
306308
finally:
307309
nt.special_form_evaluator.reset(token)
308310

309-
if hints := get_annotations(func, args):
311+
if hints := get_annotations(func, args, cls=definition_cls):
310312
params = []
311313
for name, param in sig.parameters.items():
312314
annotation = hints.get(name, param.annotation)
@@ -336,7 +338,7 @@ def get_local_defns(
336338
],
337339
]:
338340
from typemap.typing import GenericCallable, Overloaded
339-
from ._eval_operators import _function_type
341+
from ._eval_operators import _function_type, _function_type_from_sig
340342

341343
annos: dict[str, Any] = {}
342344
dct: dict[str, Any] = {}
@@ -354,8 +356,6 @@ def get_local_defns(
354356
stuff = inspect.unwrap(orig)
355357

356358
if isinstance(stuff, types.FunctionType):
357-
local_fn: Any = None
358-
359359
# TODO: This annos_ok thing is a hack because processing
360360
# __annotations__ on methods broke stuff and I didn't want
361361
# to chase it down yet.
@@ -368,24 +368,28 @@ def get_local_defns(
368368
stuck = True
369369
rr = None
370370

371+
resolved_sig = None
371372
if rr is not None:
372-
local_fn = make_func(orig, rr)
373+
resolved_sig = _resolved_function_signature(
374+
stuff, boxed.str_args, definition_cls=boxed.cls
375+
)
373376
elif not stuck and getattr(stuff, "__annotations__", None):
374377
# XXX: This is totally wrong; we still need to do
375378
# substitute in class vars
376-
local_fn = stuff
377-
elif overloads := typing.get_overloads(stuff):
378-
local_fn = WrappedOverloads(tuple(overloads))
379-
380-
# If we got stuck, we build a GenericCallable that
381-
# computes the type once it has been given type
382-
# variables!
383-
if stuck and stuff.__type_params__:
379+
resolved_sig = _resolved_function_signature(
380+
stuff, boxed.str_args, definition_cls=boxed.cls
381+
)
382+
overloads = typing.get_overloads(stuff)
383+
384+
# If the method has type params, we build a GenericCallable
385+
# (in annos only) so that [Z] etc. are preserved in output.
386+
if stuff.__type_params__:
384387
type_params = stuff.__type_params__
385388
str_args = boxed.str_args
386-
canonical_cls = boxed.canonical_cls
389+
receiver_cls = boxed.alias_type()
390+
definition_cls = boxed.canonical_cls
387391

388-
def _make_lambda(fn, o, sa, tp, cls):
392+
def _make_lambda(fn, o, sa, tp, recv_cls, def_cls):
389393
from ._eval_operators import _function_type_from_sig
390394

391395
def lam(*vs):
@@ -397,65 +401,54 @@ def lam(*vs):
397401
strict=True,
398402
)
399403
)
400-
sig = _resolved_function_signature(fn, args)
404+
sig = _resolved_function_signature(
405+
fn, args, definition_cls=def_cls
406+
)
401407
return _function_type_from_sig(
402-
sig, type(o), receiver_type=cls
408+
sig, type(o), receiver_type=recv_cls
403409
)
404410

405411
return lam
406412

407413
gc = GenericCallable[ # type: ignore[valid-type,misc]
408414
tuple[*type_params], # type: ignore[valid-type]
409415
_make_lambda(
410-
stuff, orig, str_args, type_params, canonical_cls
416+
stuff,
417+
orig,
418+
str_args,
419+
type_params,
420+
receiver_cls,
421+
definition_cls,
411422
),
412423
]
413-
annos[name] = typing.ClassVar[gc]
414-
elif local_fn is not None:
415-
if orig.__class__ is classmethod:
416-
local_fn = classmethod(local_fn)
417-
elif orig.__class__ is staticmethod:
418-
local_fn = staticmethod(local_fn)
419-
420-
if isinstance(
421-
local_fn,
422-
(
423-
types.FunctionType,
424-
types.MethodType,
425-
staticmethod,
426-
classmethod,
427-
),
428-
):
429-
dct[name] = _function_type(
430-
local_fn, receiver_type=boxed.alias_type()
431-
)
432-
433-
elif isinstance(local_fn, WrappedOverloads):
434-
overload_types: typing.Sequence[
435-
type[
436-
typing.Callable
437-
| classmethod
438-
| staticmethod
439-
| GenericCallable
440-
]
441-
] = [
442-
_function_type(
443-
_eval_typing.eval_typing(of),
444-
receiver_type=boxed.alias_type(),
445-
)
446-
for of in local_fn.functions
424+
dct[name] = gc
425+
elif resolved_sig is not None:
426+
dct[name] = _function_type_from_sig(
427+
resolved_sig,
428+
type(orig),
429+
receiver_type=boxed.alias_type(),
430+
)
431+
elif overloads:
432+
overload_types: typing.Sequence[
433+
type[
434+
typing.Callable
435+
| classmethod
436+
| staticmethod
437+
| GenericCallable
447438
]
439+
] = [
440+
_function_type(
441+
_eval_typing.eval_typing(of),
442+
receiver_type=boxed.alias_type(),
443+
)
444+
for of in overloads
445+
]
448446

449-
dct[name] = Overloaded[*overload_types] # type: ignore[valid-type]
447+
dct[name] = Overloaded[*overload_types] # type: ignore[valid-type]
450448

451449
return annos, dct
452450

453451

454-
@dataclasses.dataclass(frozen=True)
455-
class WrappedOverloads:
456-
functions: tuple[typing.Callable[..., Any], ...]
457-
458-
459452
def flatten_class_new_proto(cls: type) -> type:
460453
# This is a hacky version of flatten_class that works by using
461454
# NewProtocol on Members!

0 commit comments

Comments
 (0)