Skip to content

Commit 8a39bdd

Browse files
committed
Handle default type params
1 parent fb6fbce commit 8a39bdd

4 files changed

Lines changed: 333 additions & 7 deletions

File tree

spec-draft.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,11 @@ It's important that there be a clearly specified type language for the type-leve
6565

6666
---
6767

68-
* ``GetArg[T, Base, Idx: Literal[str]]`` - returns the type argument number ``Idx`` to ``T`` when interpreted as ``Base``, or ``Never`` if it cannot be. (That is, if we have ``class A(B[C]): ...``, then ``GetArg[A, B, 0] == C`` while ``GetArg[A, A, 0] == Never``)
69-
* ``GetArgs[T, Base]`` - returns a tuple containing all of the type arguments of ``T`` when interpreted as ``Base``, or ``Never`` if it cannot be.
68+
* ``GetArg[T, Base, Idx: Literal[str]]`` - returns the type argument number ``Idx`` to ``T`` when interpreted as ``Base``, or ``Never`` if it cannot be. (That is, if we have ``class A(B[C]): ...``, then ``GetArg[A, B, 0] == C`` while ``GetArg[A, A, 0] == Never``).
69+
Special forms unfortunately require some special handling: the arguments list of a ``Callable`` will be packed in a tuple, and a ``...`` will become ``SpecialFormEllipsis``.
70+
71+
72+
* ``GetArgs[T, Base]`` - returns a tuple containing all of the type arguments of ``T`` when interpreted as ``Base``, or ``Never`` if it cannot be. (TODO: UNIMPLEMENTED)
7073
* ``FromUnion[T]`` - returns a tuple containing all of the union elements, or a 1-ary tuple containing T if it is not a union.
7174

7275

@@ -92,8 +95,6 @@ It's important that there be a clearly specified type language for the type-leve
9295
* ``GetAttr[T, S: Literal[str]]``
9396
TODO: How should GetAttr interact with descriptors/classmethod? I am leaning towards it should apply the descriptor...
9497

95-
# TODO: how to deal with special forms like Callable and tuple[T, ...]
96-
9798
* ``Length[T: tuple]`` - get the length of a tuple as an int literal (...or ``Literal[None]`` if it is unbounded)
9899

99100
String manipulation operations for string Literal types.

tests/test_type_eval.py

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import textwrap
22
import unittest
3-
from typing import Literal, Never, Tuple
3+
from typing import Any, Callable, Generic, List, Literal, Never, Tuple, TypeVar
44

55
from typemap.type_eval import eval_typing
66
from typemap.typing import (
@@ -15,6 +15,7 @@
1515
Length,
1616
Member,
1717
NewProtocol,
18+
SpecialFormEllipsis,
1819
StrConcat,
1920
StrSlice,
2021
Uppercase,
@@ -177,6 +178,211 @@ def test_getarg_never():
177178
assert d is Never
178179

179180

181+
def test_eval_getarg_callable():
182+
# oh hmmmmmmm -- yeah maybe callable could be fully bespoke if we
183+
# disallowed putting Callable here...!
184+
t = Callable[[int, str], str]
185+
args = eval_typing(GetArg[t, Callable, 0])
186+
assert args == tuple[int, str]
187+
188+
t = Callable[int, str]
189+
args = eval_typing(GetArg[t, Callable, 0])
190+
assert args == tuple[int]
191+
192+
t = Callable[[], str]
193+
args = eval_typing(GetArg[t, Callable, 0])
194+
assert args == tuple[()]
195+
196+
t = Callable[..., str]
197+
args = eval_typing(GetArg[t, Callable, 0])
198+
assert args == SpecialFormEllipsis
199+
200+
t = Callable
201+
args = eval_typing(GetArg[t, Callable, 0])
202+
assert args == SpecialFormEllipsis
203+
204+
t = Callable
205+
args = eval_typing(GetArg[t, Callable, 1])
206+
assert args == Any
207+
208+
209+
def test_eval_getarg_tuple():
210+
t = tuple[int, ...]
211+
args = eval_typing(GetArg[t, tuple, 1])
212+
assert args == SpecialFormEllipsis
213+
214+
t = tuple
215+
args = eval_typing(GetArg[t, tuple, 0])
216+
assert args == Any
217+
218+
args = eval_typing(GetArg[t, tuple, 1])
219+
assert args == SpecialFormEllipsis
220+
221+
222+
def test_eval_getarg_list():
223+
t = list[int]
224+
arg = eval_typing(GetArg[t, list, 0])
225+
assert arg is int
226+
227+
t = List[int]
228+
arg = eval_typing(GetArg[t, list, 0])
229+
assert arg is int
230+
231+
t = list
232+
arg = eval_typing(GetArg[t, list, 0])
233+
assert arg == Any
234+
235+
t = List
236+
arg = eval_typing(GetArg[t, list, 0])
237+
assert arg == Any
238+
239+
t = list[int]
240+
arg = eval_typing(GetArg[t, List, 0])
241+
assert arg is int
242+
243+
t = List[int]
244+
arg = eval_typing(GetArg[t, List, 0])
245+
assert arg is int
246+
247+
t = list
248+
arg = eval_typing(GetArg[t, List, 0])
249+
assert arg == Any
250+
251+
t = List
252+
arg = eval_typing(GetArg[t, List, 0])
253+
assert arg == Any
254+
255+
# indexing with -1 equivalent to 0
256+
t = list[int]
257+
arg = eval_typing(GetArg[t, list, -1])
258+
assert arg is int
259+
260+
t = List[int]
261+
arg = eval_typing(GetArg[t, list, -1])
262+
assert arg is int
263+
264+
t = list
265+
arg = eval_typing(GetArg[t, list, -1])
266+
assert arg == Any
267+
268+
t = List
269+
arg = eval_typing(GetArg[t, list, -1])
270+
assert arg == Any
271+
272+
t = list[int]
273+
arg = eval_typing(GetArg[t, List, -1])
274+
assert arg is int
275+
276+
t = List[int]
277+
arg = eval_typing(GetArg[t, List, -1])
278+
assert arg is int
279+
280+
t = list
281+
arg = eval_typing(GetArg[t, List, -1])
282+
assert arg == Any
283+
284+
t = List
285+
arg = eval_typing(GetArg[t, List, -1])
286+
assert arg == Any
287+
288+
# indexing with 1 always fails
289+
t = list[int]
290+
arg = eval_typing(GetArg[t, list, 1])
291+
assert arg == Never
292+
293+
t = List[int]
294+
arg = eval_typing(GetArg[t, list, 1])
295+
assert arg == Never
296+
297+
t = list
298+
arg = eval_typing(GetArg[t, list, 1])
299+
assert arg == Never
300+
301+
t = List
302+
arg = eval_typing(GetArg[t, list, 1])
303+
assert arg == Never
304+
305+
t = list[int]
306+
arg = eval_typing(GetArg[t, List, 1])
307+
assert arg == Never
308+
309+
t = List[int]
310+
arg = eval_typing(GetArg[t, List, 1])
311+
assert arg == Never
312+
313+
t = list
314+
arg = eval_typing(GetArg[t, List, 1])
315+
assert arg == Never
316+
317+
t = List
318+
arg = eval_typing(GetArg[t, List, 1])
319+
assert arg == Never
320+
321+
322+
def test_eval_getarg_custom_01():
323+
class A[T]:
324+
pass
325+
326+
t = A[int]
327+
assert eval_typing(GetArg[t, A, 0]) is int
328+
assert eval_typing(GetArg[t, A, -1]) is int
329+
assert eval_typing(GetArg[t, A, 1]) == Never
330+
331+
t = A
332+
assert eval_typing(GetArg[t, A, 0]) == Any
333+
assert eval_typing(GetArg[t, A, -1]) == Any
334+
assert eval_typing(GetArg[t, A, 1]) == Never
335+
336+
337+
def test_eval_getarg_custom_02():
338+
T = TypeVar("T")
339+
340+
class A(Generic[T]):
341+
pass
342+
343+
t = A[int]
344+
assert eval_typing(GetArg[t, A, 0]) is int
345+
assert eval_typing(GetArg[t, A, -1]) is int
346+
assert eval_typing(GetArg[t, A, 1]) == Never
347+
348+
t = A
349+
assert eval_typing(GetArg[t, A, 0]) == Any
350+
assert eval_typing(GetArg[t, A, -1]) == Any
351+
assert eval_typing(GetArg[t, A, 1]) == Never
352+
353+
354+
def test_eval_getarg_custom_03():
355+
class A[T = str]:
356+
pass
357+
358+
t = A[int]
359+
assert eval_typing(GetArg[t, A, 0]) is int
360+
assert eval_typing(GetArg[t, A, -1]) is int
361+
assert eval_typing(GetArg[t, A, 1]) == Never
362+
363+
t = A
364+
assert eval_typing(GetArg[t, A, 0]) is str
365+
assert eval_typing(GetArg[t, A, -1]) is str
366+
assert eval_typing(GetArg[t, A, 1]) == Never
367+
368+
369+
def test_eval_getarg_custom_04():
370+
T = TypeVar("T", default=str)
371+
372+
class A(Generic[T]):
373+
pass
374+
375+
t = A[int]
376+
assert eval_typing(GetArg[t, A, 0]) is int
377+
assert eval_typing(GetArg[t, A, -1]) is int
378+
assert eval_typing(GetArg[t, A, 1]) == Never
379+
380+
t = A
381+
assert eval_typing(GetArg[t, A, 0]) is str
382+
assert eval_typing(GetArg[t, A, -1]) is str
383+
assert eval_typing(GetArg[t, A, 1]) == Never
384+
385+
180386
def test_uppercase_never():
181387
d = eval_typing(Uppercase[Never])
182388
assert d is Never

typemap/type_eval/_eval_operators.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import collections
2+
import collections.abc
3+
import contextlib
14
import functools
25
import inspect
36
import itertools
7+
import re
48
import types
59
import typing
610

@@ -22,6 +26,7 @@
2226
Members,
2327
NewProtocol,
2428
Param,
29+
SpecialFormEllipsis,
2530
StrConcat,
2631
StrSlice,
2732
Uncapitalize,
@@ -266,15 +271,125 @@ def _get_args(tp, base, ctx) -> typing.Any:
266271
return None
267272

268273

274+
def _fix_type(tp):
275+
"""Fix up a type getting returned from GetArg
276+
277+
In particular, this means turning a list into a tuple of the list
278+
elements and turning ... into SpecialFormEllipsis.
279+
"""
280+
if isinstance(tp, (tuple, list)):
281+
return tuple[*tp]
282+
elif tp is ...:
283+
return SpecialFormEllipsis
284+
else:
285+
return tp
286+
287+
288+
# The number of generic parameters to all the builtin types that had
289+
# subscripting added in PEP 585.
290+
_BUILTIN_GENERIC_ARITIES = {
291+
tuple: 2, # variadic, like Callable...
292+
list: 1,
293+
dict: 2,
294+
set: 1,
295+
frozenset: 1,
296+
type: 1,
297+
collections.deque: 1,
298+
collections.defaultdict: 2,
299+
collections.OrderedDict: 2,
300+
collections.Counter: 1,
301+
collections.ChainMap: 2,
302+
collections.abc.Awaitable: 1,
303+
collections.abc.Coroutine: 3,
304+
collections.abc.AsyncIterable: 1,
305+
collections.abc.AsyncIterator: 1,
306+
collections.abc.AsyncGenerator: 2,
307+
collections.abc.Iterable: 1,
308+
collections.abc.Iterator: 1,
309+
collections.abc.Generator: 3,
310+
collections.abc.Reversible: 1,
311+
collections.abc.Container: 1,
312+
collections.abc.Collection: 1,
313+
collections.abc.Callable: 2, # special syntax
314+
collections.abc.Set: 1,
315+
collections.abc.MutableSet: 1,
316+
collections.abc.Mapping: 2,
317+
collections.abc.MutableMapping: 2,
318+
collections.abc.Sequence: 1,
319+
collections.abc.MutableSequence: 1,
320+
collections.abc.KeysView: 1,
321+
collections.abc.ItemsView: 2,
322+
collections.abc.ValuesView: 1,
323+
contextlib.AbstractContextManager: 1,
324+
contextlib.AbstractAsyncContextManager: 1,
325+
re.Pattern: 1,
326+
re.Match: 1,
327+
}
328+
329+
330+
def _get_params(base_head):
331+
if (params := getattr(base_head, "__parameters__", None)) is not None:
332+
return params
333+
elif (params := getattr(base_head, "__type_params__", None)) is not None:
334+
return params
335+
else:
336+
return None
337+
338+
339+
def _get_generic_arity(base_head):
340+
if (n := _BUILTIN_GENERIC_ARITIES.get(base_head)) is not None:
341+
return n
342+
# XXX: check the type?
343+
elif (n := getattr(base_head, "_nparams", None)) is not None:
344+
return n
345+
elif (params := _get_params(base_head)) is not None:
346+
# TODO: also check for TypeVarTuple!
347+
return len(params)
348+
else:
349+
return -1
350+
351+
352+
def _get_defaults(base_head):
353+
"""Get the *default* type params for a type
354+
355+
`list` is equivalent to `list[Any]`, so `GetArg[list, list, 0]
356+
ought to return `Any`, while `GetArg[list, list, 1]` ought to
357+
return `Never` because the index is invalid.
358+
359+
Annoyingly we need to consult a table for built-in arities for this.
360+
"""
361+
arity = _get_generic_arity(base_head)
362+
if arity < 0:
363+
return None
364+
365+
# Callable and tuple need to produce a SpecialFormEllipsis for arg
366+
# 0 and 1, respectively.
367+
if base_head is collections.abc.Callable:
368+
return (SpecialFormEllipsis, typing.Any)
369+
elif base_head is tuple:
370+
return (typing.Any, SpecialFormEllipsis)
371+
372+
if params := _get_params(base_head):
373+
return tuple(
374+
typing.Any if t.__default__ == typing.NoDefault else t.__default__
375+
for t in params
376+
)
377+
378+
return (typing.Any,) * arity
379+
380+
269381
@type_eval.register_evaluator(GetArg)
270382
@_lift_over_unions
271383
def _eval_GetArg(tp, base, idx, *, ctx) -> typing.Any:
272-
args = _get_args(tp, base, ctx)
384+
base_head = _typing_inspect.get_head(base)
385+
args = _get_args(tp, base_head, ctx)
386+
if args == ():
387+
args = _get_defaults(base_head)
273388
if args is None:
274389
return typing.Never
275390

276391
try:
277-
return args[_from_literal(idx, ctx)]
392+
return _fix_type(args[_from_literal(idx, ctx)])
278393
except IndexError:
279394
return typing.Never
280395

0 commit comments

Comments
 (0)