Skip to content

Commit 6be0bfd

Browse files
committed
Use resolved function signature instead of making a new function.
1 parent a8ce02d commit 6be0bfd

1 file changed

Lines changed: 53 additions & 60 deletions

File tree

typemap/type_eval/_apply_generic.py

Lines changed: 53 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,10 @@ def get_annotations(
271271
return rr
272272

273273

274-
def _resolved_function_signature(func, args):
275-
"""Get the signature of a function with type hints resolved to arg values"""
274+
def _resolved_function_signature(
275+
func, args, definition_cls: type | None = None
276+
):
277+
"""Get the signature of a function with hints resolved to arg values."""
276278

277279
import typemap.typing as nt
278280

@@ -282,7 +284,7 @@ def _resolved_function_signature(func, args):
282284
finally:
283285
nt.special_form_evaluator.reset(token)
284286

285-
if hints := get_annotations(func, args):
287+
if hints := get_annotations(func, args, cls=definition_cls):
286288
params = []
287289
for name, param in sig.parameters.items():
288290
annotation = hints.get(name, param.annotation)
@@ -312,7 +314,7 @@ def get_local_defns(
312314
],
313315
]:
314316
from typemap.typing import GenericCallable, Overloaded
315-
from ._eval_operators import _function_type
317+
from ._eval_operators import _function_type, _function_type_from_sig
316318

317319
annos: dict[str, Any] = {}
318320
dct: dict[str, Any] = {}
@@ -330,8 +332,6 @@ def get_local_defns(
330332
stuff = inspect.unwrap(orig)
331333

332334
if isinstance(stuff, types.FunctionType):
333-
local_fn: Any = None
334-
335335
# TODO: This annos_ok thing is a hack because processing
336336
# __annotations__ on methods broke stuff and I didn't want
337337
# to chase it down yet.
@@ -344,24 +344,28 @@ def get_local_defns(
344344
stuck = True
345345
rr = None
346346

347+
resolved_sig = None
347348
if rr is not None:
348-
local_fn = make_func(orig, rr)
349+
resolved_sig = _resolved_function_signature(
350+
stuff, boxed.str_args, definition_cls=boxed.cls
351+
)
349352
elif not stuck and getattr(stuff, "__annotations__", None):
350353
# XXX: This is totally wrong; we still need to do
351354
# substitute in class vars
352-
local_fn = stuff
353-
elif overloads := typing.get_overloads(stuff):
354-
local_fn = WrappedOverloads(tuple(overloads))
355-
356-
# If we got stuck, we build a GenericCallable that
357-
# computes the type once it has been given type
358-
# variables!
359-
if stuck and stuff.__type_params__:
355+
resolved_sig = _resolved_function_signature(
356+
stuff, boxed.str_args, definition_cls=boxed.cls
357+
)
358+
overloads = typing.get_overloads(stuff)
359+
360+
# If the method has type params, we build a GenericCallable
361+
# (in annos only) so that [Z] etc. are preserved in output.
362+
if stuff.__type_params__:
360363
type_params = stuff.__type_params__
361364
str_args = boxed.str_args
362-
canonical_cls = boxed.canonical_cls
365+
receiver_cls = boxed.alias_type()
366+
definition_cls = boxed.canonical_cls
363367

364-
def _make_lambda(fn, o, sa, tp, cls):
368+
def _make_lambda(fn, o, sa, tp, recv_cls, def_cls):
365369
from ._eval_operators import _function_type_from_sig
366370

367371
def lam(*vs):
@@ -373,65 +377,54 @@ def lam(*vs):
373377
strict=True,
374378
)
375379
)
376-
sig = _resolved_function_signature(fn, args)
380+
sig = _resolved_function_signature(
381+
fn, args, definition_cls=def_cls
382+
)
377383
return _function_type_from_sig(
378-
sig, type(o), receiver_type=cls
384+
sig, type(o), receiver_type=recv_cls
379385
)
380386

381387
return lam
382388

383389
gc = GenericCallable[ # type: ignore[valid-type,misc]
384390
tuple[*type_params], # type: ignore[valid-type]
385391
_make_lambda(
386-
stuff, orig, str_args, type_params, canonical_cls
392+
stuff,
393+
orig,
394+
str_args,
395+
type_params,
396+
receiver_cls,
397+
definition_cls,
387398
),
388399
]
389-
annos[name] = typing.ClassVar[gc]
390-
elif local_fn is not None:
391-
if orig.__class__ is classmethod:
392-
local_fn = classmethod(local_fn)
393-
elif orig.__class__ is staticmethod:
394-
local_fn = staticmethod(local_fn)
395-
396-
if isinstance(
397-
local_fn,
398-
(
399-
types.FunctionType,
400-
types.MethodType,
401-
staticmethod,
402-
classmethod,
403-
),
404-
):
405-
dct[name] = _function_type(
406-
local_fn, receiver_type=boxed.alias_type()
407-
)
408-
409-
elif isinstance(local_fn, WrappedOverloads):
410-
overload_types: typing.Sequence[
411-
type[
412-
typing.Callable
413-
| classmethod
414-
| staticmethod
415-
| GenericCallable
416-
]
417-
] = [
418-
_function_type(
419-
_eval_typing.eval_typing(of),
420-
receiver_type=boxed.alias_type(),
421-
)
422-
for of in local_fn.functions
400+
dct[name] = gc
401+
elif resolved_sig is not None:
402+
dct[name] = _function_type_from_sig(
403+
resolved_sig,
404+
type(orig),
405+
receiver_type=boxed.alias_type(),
406+
)
407+
elif overloads:
408+
overload_types: typing.Sequence[
409+
type[
410+
typing.Callable
411+
| classmethod
412+
| staticmethod
413+
| GenericCallable
423414
]
415+
] = [
416+
_function_type(
417+
_eval_typing.eval_typing(of),
418+
receiver_type=boxed.alias_type(),
419+
)
420+
for of in overloads
421+
]
424422

425-
dct[name] = Overloaded[*overload_types] # type: ignore[valid-type]
423+
dct[name] = Overloaded[*overload_types] # type: ignore[valid-type]
426424

427425
return annos, dct
428426

429427

430-
@dataclasses.dataclass(frozen=True)
431-
class WrappedOverloads:
432-
functions: tuple[typing.Callable[..., Any], ...]
433-
434-
435428
def flatten_class_new_proto(cls: type) -> type:
436429
# This is a hacky version of flatten_class that works by using
437430
# NewProtocol on Members!

0 commit comments

Comments
 (0)