@@ -4,7 +4,7 @@ from _typeshed import SupportsAllComparisons, SupportsItems
44from collections .abc import Callable , Hashable , Iterable , Sized
55from types import GenericAlias
66from typing import Any , Final , Generic , Literal , NamedTuple , TypedDict , TypeVar , final , overload , type_check_only
7- from typing_extensions import ParamSpec , Self , TypeAlias , disjoint_base
7+ from typing_extensions import Concatenate , ParamSpec , Self , TypeAlias , disjoint_base
88
99__all__ = [
1010 "update_wrapper" ,
@@ -197,37 +197,67 @@ else:
197197
198198@type_check_only
199199class _SingleDispatchCallable (Generic [_P , _T ]):
200- registry : types .MappingProxyType [Any , Callable [_P , _T ]]
201- def dispatch (self , cls : Any ) -> Callable [_P , _T ]: ...
200+ # First argument pf the callables in the registry is the type to dispatch on.
201+ registry : types .MappingProxyType [Any , Callable [Concatenate [Any , _P ], _T ]]
202+ def dispatch (self , cls : type [_S ]) -> Callable [Concatenate [_S , _P ], _T ]: ...
203+ if sys .version_info >= (3 , 11 ):
204+ # @fun.register(complex | str)
205+ # def _(arg, verbose=False): ...
206+ @overload
207+ def register (
208+ self , cls : types .UnionType , func : None = None
209+ ) -> Callable [[Callable [Concatenate [_S , _P ], _T ]], Callable [Concatenate [_S , _P ], _T ]]: ...
202210 # @fun.register(complex)
203211 # def _(arg, verbose=False): ...
204212 @overload
205- def register (self , cls : _RegType , func : None = None ) -> Callable [[Callable [_P , _T ]], Callable [_P , _T ]]: ...
213+ def register (
214+ self , cls : type [_S ], func : None = None
215+ ) -> Callable [[Callable [Concatenate [_S , _P ], _T ]], Callable [Concatenate [_S , _P ], _T ]]: ...
206216 # @fun.register
207217 # def _(arg: int, verbose=False):
208218 @overload
209- def register (self , cls : Callable [_P , _T ], func : None = None ) -> Callable [_P , _T ]: ...
219+ def register (self , cls : Callable [Concatenate [_S , _P ], _T ], func : None = None ) -> Callable [Concatenate [_S , _P ], _T ]: ...
220+ if sys .version_info >= (3 , 11 ):
221+ # fun.register(int, lambda x: x)
222+ @overload
223+ def register (
224+ self , cls : types .UnionType , func : Callable [Concatenate [_S , _P ], _T ]
225+ ) -> Callable [Concatenate [_S , _P ], _T ]: ...
210226 # fun.register(int, lambda x: x)
211227 @overload
212- def register (self , cls : _RegType , func : Callable [_P , _T ]) -> Callable [_P , _T ]: ...
228+ def register (self , cls : _RegType , func : Callable [Concatenate [ _S , _P ] , _T ]) -> Callable [Concatenate [ _S , _P ] , _T ]: ...
213229 def _clear_cache (self ) -> None : ...
214- def __call__ (self , / , * args : _P .args , ** kwargs : _P .kwargs ) -> _T : ...
230+ def __call__ (self , arg : object , / , * args : _P .args , ** kwargs : _P .kwargs ) -> _T : ...
215231
216- def singledispatch (func : Callable [_P , _T ]) -> _SingleDispatchCallable [_P , _T ]: ...
232+ def singledispatch (func : Callable [Concatenate [ object , _P ] , _T ]) -> _SingleDispatchCallable [_P , _T ]: ...
217233
218234class singledispatchmethod (Generic [_P , _T ]):
219235 dispatcher : _SingleDispatchCallable [_P , _T ]
220236 func : Callable [_P , _T ]
221- def __init__ (self , func : Callable [_P , _T ]) -> None : ...
237+ def __init__ (self , func : Callable [Concatenate [ object , _P ] , _T ]) -> None : ...
222238 @property
223239 def __isabstractmethod__ (self ) -> bool : ...
240+ if sys .version_info >= (3 , 11 ):
241+ @overload
242+ def register (
243+ self , cls : types .UnionType , method : None = None
244+ ) -> Callable [[Callable [Concatenate [_S , _P ], _T ]], Callable [Concatenate [_S , _P ], _T ]]: ...
245+
224246 @overload
225- def register (self , cls : _RegType , method : None = None ) -> Callable [[Callable [_P , _T ]], Callable [_P , _T ]]: ...
247+ def register (
248+ self , cls : type [_S ], method : None = None
249+ ) -> Callable [[Callable [Concatenate [_S , _P ], _T ]], Callable [Concatenate [_S , _P ], _T ]]: ...
226250 @overload
227- def register (self , cls : Callable [_P , _T ], method : None = None ) -> Callable [_P , _T ]: ...
251+ def register (self , cls : Callable [Concatenate [_S , _P ], _T ], method : None = None ) -> Callable [Concatenate [_S , _P ], _T ]: ...
252+ if sys .version_info >= (3 , 11 ):
253+ @overload
254+ def register (
255+ self , cls : types .UnionType , method : Callable [Concatenate [_S , _P ], _T ]
256+ ) -> Callable [Concatenate [_S , _P ], _T ]: ...
257+
228258 @overload
229- def register (self , cls : _RegType , method : Callable [_P , _T ]) -> Callable [_P , _T ]: ...
230- def __get__ (self , obj : _S , cls : type [_S ] | None = None ) -> Callable [_P , _T ]: ...
259+ def register (self , cls : type [ _S ] , method : Callable [Concatenate [ _S , _P ] , _T ]) -> Callable [Concatenate [ _S , _P ] , _T ]: ...
260+ def __get__ (self , obj : _S , cls : type [_S ] | None = None ) -> Callable [Concatenate [ _S , _P ] , _T ]: ...
231261
232262class cached_property (Generic [_T_co ]):
233263 func : Callable [[Any ], _T_co ]
0 commit comments