Skip to content

Commit 8d46a85

Browse files
committed
Use Concatenate
1 parent b054f06 commit 8d46a85

1 file changed

Lines changed: 43 additions & 13 deletions

File tree

stdlib/functools.pyi

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from _typeshed import SupportsAllComparisons, SupportsItems
44
from collections.abc import Callable, Hashable, Iterable, Sized
55
from types import GenericAlias
66
from typing import Any, Final, Generic, Literal, NamedTuple, TypedDict, TypeVar, final, overload, type_check_only
7-
from typing_extensions import ParamSpec, Self, TypeAlias, disjoint_base
7+
from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, disjoint_base
88

99
__all__ = [
1010
"update_wrapper",
@@ -197,37 +197,67 @@ else:
197197

198198
@type_check_only
199199
class _SingleDispatchCallable(Generic[_P, _T]):
200-
registry: types.MappingProxyType[Any, Callable[_P, _T]]
201-
def dispatch(self, cls: Any) -> Callable[_P, _T]: ...
200+
# First argument pf the callables in the registry is the type to dispatch on.
201+
registry: types.MappingProxyType[Any, Callable[Concatenate[Any, _P], _T]]
202+
def dispatch(self, cls: type[_S]) -> Callable[Concatenate[_S, _P], _T]: ...
203+
if sys.version_info >= (3, 11):
204+
# @fun.register(complex | str)
205+
# def _(arg, verbose=False): ...
206+
@overload
207+
def register(
208+
self, cls: types.UnionType, func: None = None
209+
) -> Callable[[Callable[Concatenate[_S, _P], _T]], Callable[Concatenate[_S, _P], _T]]: ...
202210
# @fun.register(complex)
203211
# def _(arg, verbose=False): ...
204212
@overload
205-
def register(self, cls: _RegType, func: None = None) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
213+
def register(
214+
self, cls: type[_S], func: None = None
215+
) -> Callable[[Callable[Concatenate[_S, _P], _T]], Callable[Concatenate[_S, _P], _T]]: ...
206216
# @fun.register
207217
# def _(arg: int, verbose=False):
208218
@overload
209-
def register(self, cls: Callable[_P, _T], func: None = None) -> Callable[_P, _T]: ...
219+
def register(self, cls: Callable[Concatenate[_S, _P], _T], func: None = None) -> Callable[Concatenate[_S, _P], _T]: ...
220+
if sys.version_info >= (3, 11):
221+
# fun.register(int, lambda x: x)
222+
@overload
223+
def register(
224+
self, cls: types.UnionType, func: Callable[Concatenate[_S, _P], _T]
225+
) -> Callable[Concatenate[_S, _P], _T]: ...
210226
# fun.register(int, lambda x: x)
211227
@overload
212-
def register(self, cls: _RegType, func: Callable[_P, _T]) -> Callable[_P, _T]: ...
228+
def register(self, cls: _RegType, func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: ...
213229
def _clear_cache(self) -> None: ...
214-
def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
230+
def __call__(self, arg: object, /, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
215231

216-
def singledispatch(func: Callable[_P, _T]) -> _SingleDispatchCallable[_P, _T]: ...
232+
def singledispatch(func: Callable[Concatenate[object, _P], _T]) -> _SingleDispatchCallable[_P, _T]: ...
217233

218234
class singledispatchmethod(Generic[_P, _T]):
219235
dispatcher: _SingleDispatchCallable[_P, _T]
220236
func: Callable[_P, _T]
221-
def __init__(self, func: Callable[_P, _T]) -> None: ...
237+
def __init__(self, func: Callable[Concatenate[object, _P], _T]) -> None: ...
222238
@property
223239
def __isabstractmethod__(self) -> bool: ...
240+
if sys.version_info >= (3, 11):
241+
@overload
242+
def register(
243+
self, cls: types.UnionType, method: None = None
244+
) -> Callable[[Callable[Concatenate[_S, _P], _T]], Callable[Concatenate[_S, _P], _T]]: ...
245+
224246
@overload
225-
def register(self, cls: _RegType, method: None = None) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
247+
def register(
248+
self, cls: type[_S], method: None = None
249+
) -> Callable[[Callable[Concatenate[_S, _P], _T]], Callable[Concatenate[_S, _P], _T]]: ...
226250
@overload
227-
def register(self, cls: Callable[_P, _T], method: None = None) -> Callable[_P, _T]: ...
251+
def register(self, cls: Callable[Concatenate[_S, _P], _T], method: None = None) -> Callable[Concatenate[_S, _P], _T]: ...
252+
if sys.version_info >= (3, 11):
253+
@overload
254+
def register(
255+
self, cls: types.UnionType, method: Callable[Concatenate[_S, _P], _T]
256+
) -> Callable[Concatenate[_S, _P], _T]: ...
257+
228258
@overload
229-
def register(self, cls: _RegType, method: Callable[_P, _T]) -> Callable[_P, _T]: ...
230-
def __get__(self, obj: _S, cls: type[_S] | None = None) -> Callable[_P, _T]: ...
259+
def register(self, cls: type[_S], method: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: ...
260+
def __get__(self, obj: _S, cls: type[_S] | None = None) -> Callable[Concatenate[_S, _P], _T]: ...
231261

232262
class cached_property(Generic[_T_co]):
233263
func: Callable[[Any], _T_co]

0 commit comments

Comments
 (0)