Skip to content

Commit 3e1aafc

Browse files
committed
Properly handle Quals in NewProtocol and Members
1 parent c6be832 commit 3e1aafc

3 files changed

Lines changed: 72 additions & 23 deletions

File tree

tests/test_type_dir.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
import textwrap
2-
from typing import Never, Literal, Union, TypeVar
2+
import typing
3+
from typing import Literal, Never, TypeVar, Union
34

45
from typemap.type_eval import eval_typing
56
from typemap.typing import (
6-
NewProtocol,
7-
Member,
7+
Attrs,
8+
FromUnion,
89
GetArg,
910
GetName,
11+
GetQuals,
1012
GetType,
13+
Is,
1114
Iter,
12-
Attrs,
15+
Member,
1316
Members,
14-
FromUnion,
17+
NewProtocol,
1518
Uppercase,
16-
Is,
1719
)
1820

1921
from . import format_helper
2022

21-
2223
type OrGotcha[K] = K | Literal["gotcha!"]
2324

2425
type StrForInt[X] = (str | OrGotcha[X]) if X is int else (X | OrGotcha[X])
@@ -39,6 +40,8 @@ class Base[T]:
3940
t: dict[str, StrForInt[T]]
4041
kkk: K
4142

43+
fin: typing.Final[int]
44+
4245
def foo(self, a: T | None, b: int = 0) -> dict[str, T]:
4346
pass
4447

@@ -75,14 +78,20 @@ class Final(Mine, Ordinary, Wrapper[float], AnotherBase[float], Last[int]):
7578

7679

7780
type AllOptional[T] = NewProtocol[
78-
*[Member[GetName[p], GetType[p] | None] for p in Iter[Attrs[T]]]
81+
*[
82+
Member[GetName[p], GetType[p] | None, GetQuals[p]]
83+
for p in Iter[Attrs[T]]
84+
]
7985
]
8086

8187
type OptionalFinal = AllOptional[Final]
8288

8389

8490
type Capitalize[T] = NewProtocol[
85-
*[Member[Uppercase[GetName[p]], GetType[p]] for p in Iter[Attrs[T]]]
91+
*[
92+
Member[Uppercase[GetName[p]], GetType[p], GetQuals[p]]
93+
for p in Iter[Attrs[T]]
94+
]
8695
]
8796

8897
type Prims[T] = NewProtocol[
@@ -102,6 +111,7 @@ class Final(Mine, Ordinary, Wrapper[float], AnotherBase[float], Last[int]):
102111
if not Is[t, Literal]
103112
]
104113
],
114+
GetQuals[p],
105115
]
106116
for p in Iter[Attrs[T]]
107117
]
@@ -137,6 +147,7 @@ class Final(Mine, Ordinary, Wrapper[float], AnotherBase[float], Last[int]):
137147
if not Is[IsLiteral[t], Literal[True]]
138148
]
139149
],
150+
GetQuals[p],
140151
]
141152
for p in Iter[Attrs[T]]
142153
]
@@ -152,6 +163,7 @@ class Final:
152163
iii: str | int | typing.Literal['gotcha!']
153164
t: dict[str, str | int | typing.Literal['gotcha!']]
154165
kkk: ~K
166+
fin: typing.Final[int]
155167
x: tests.test_type_dir.Wrapper[int | None]
156168
ordinary: str
157169
def foo(self, a: int | None, b: int = 0) -> dict[str, int]: ...
@@ -172,6 +184,7 @@ class AllOptional[tests.test_type_dir.Final]:
172184
iii: str | int | typing.Literal['gotcha!'] | None
173185
t: dict[str, str | int | typing.Literal['gotcha!']] | None
174186
kkk: ~K | None
187+
fin: typing.Final[int | None]
175188
x: tests.test_type_dir.Wrapper[int | None] | None
176189
ordinary: str | None
177190
""")
@@ -186,6 +199,7 @@ class Capitalize[tests.test_type_dir.Final]:
186199
III: str | int | typing.Literal['gotcha!']
187200
T: dict[str, str | int | typing.Literal['gotcha!']]
188201
KKK: ~K
202+
FIN: typing.Final[int]
189203
X: tests.test_type_dir.Wrapper[int | None]
190204
ORDINARY: str
191205
""")
@@ -197,6 +211,7 @@ def test_type_dir_4():
197211
assert format_helper.format_class(d) == textwrap.dedent("""\
198212
class Prims[tests.test_type_dir.Final]:
199213
last: int | typing.Literal[True]
214+
fin: typing.Final[int]
200215
ordinary: str
201216
""")
202217

@@ -211,6 +226,7 @@ class NoLiterals1[tests.test_type_dir.Final]:
211226
iii: str | int
212227
t: dict[str, str | int | typing.Literal['gotcha!']]
213228
kkk: ~K
229+
fin: typing.Final[int]
214230
x: tests.test_type_dir.Wrapper[int | None]
215231
ordinary: str
216232
""")
@@ -225,6 +241,7 @@ class NoLiterals2[tests.test_type_dir.Final]:
225241
iii: str | int
226242
t: dict[str, str | int | typing.Literal['gotcha!']]
227243
kkk: ~K
244+
fin: typing.Final[int]
228245
x: tests.test_type_dir.Wrapper[int | None]
229246
ordinary: str
230247
""")

typemap/type_eval/_eval_operators.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,24 @@ def get_annotated_type_hints(cls, **kwargs):
5656
continue
5757
for k in acls.__annotations__:
5858
if k not in hints:
59-
# XXX: TODO: Strip ClassVar/Final
60-
hints[k] = ohints[k], (), acls
59+
quals = set()
60+
ty = ohints[k]
61+
62+
# Strip ClassVar/Final from ty and add them to quals
63+
while True:
64+
for form in [typing.ClassVar, typing.Final]:
65+
if _typing_inspect.is_special_form(ty, form):
66+
quals.add(form.__name__)
67+
ty = (
68+
typing.get_args(ty)[0]
69+
if typing.get_args(ty)
70+
else typing.Any
71+
)
72+
break
73+
else:
74+
break
75+
76+
hints[k] = ty, tuple(sorted(quals)), acls
6177

6278
# Stop early if we are done.
6379
if len(hints) == len(ohints):
@@ -249,7 +265,6 @@ def _eval_GetAttr(lhs, prop, *, ctx):
249265

250266

251267
def _get_raw_args(tp, base_head, ctx) -> typing.Any:
252-
# XXX: check against base!!
253268
evaled = _eval_types(tp, ctx)
254269

255270
tp_head = _typing_inspect.get_head(tp)
@@ -444,15 +459,23 @@ def func(*args, ctx):
444459
##################################################################
445460

446461

462+
def _add_quals(typ, quals):
463+
for qual in (typing.ClassVar, typing.Final):
464+
if type_eval.issubsimilar(typing.Literal[qual.__name__], quals):
465+
typ = qual[typ]
466+
return typ
467+
468+
447469
@type_eval.register_evaluator(NewProtocol)
448470
def _eval_NewProtocol(*etyps: Member, ctx):
449471
dct: dict[str, object] = {}
450472
dct["__annotations__"] = {
451473
# XXX: Should eval_typing on the etyps evaluate the arguments??
452-
_from_literal(typing.get_args(prop)[0], ctx): _eval_types(
453-
typing.get_args(prop)[1], ctx
474+
_from_literal(name, ctx): _add_quals(
475+
_eval_types(typ, ctx),
476+
_eval_types(quals, ctx),
454477
)
455-
for prop in etyps
478+
for name, typ, quals, _ in (typing.get_args(prop) for prop in etyps)
456479
}
457480

458481
module_name = __name__

typemap/type_eval/_typing_inspect.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,37 @@
44

55

66
import typing
7-
8-
from typing import (
7+
from types import GenericAlias, UnionType
8+
from typing import ( # type: ignore [attr-defined] # noqa: PLC2701
99
Annotated,
1010
Any,
11-
ClassVar,
1211
ForwardRef,
1312
Literal,
1413
TypeGuard,
1514
TypeVar,
1615
Union,
16+
_GenericAlias,
17+
_SpecialGenericAlias,
1718
get_args,
1819
get_origin,
1920
)
20-
from typing import _GenericAlias, _SpecialGenericAlias # type: ignore [attr-defined] # noqa: PLC2701
21+
2122
from typing_extensions import TypeAliasType, TypeVarTuple, Unpack
22-
from types import GenericAlias, UnionType
2323

2424
from . import _eval_typing
2525

2626

27-
def is_classvar(t: Any) -> bool:
28-
return t is ClassVar or (is_generic_alias(t) and get_origin(t) is ClassVar) # type: ignore [comparison-overlap]
27+
def is_special_form(t: Any, form: Any) -> bool:
28+
"""Check if t is a special form or a generic alias of that form.
29+
30+
Args:
31+
t: The type to check
32+
form: The special form to check against (e.g., ClassVar, Final, Literal)
33+
34+
Returns:
35+
True if t is the special form or a generic alias with that origin
36+
"""
37+
return t is form or (is_generic_alias(t) and get_origin(t) is form) # type: ignore [comparison-overlap]
2938

3039

3140
def is_generic_alias(t: Any) -> TypeGuard[GenericAlias]:
@@ -142,12 +151,12 @@ def is_eval_proxy(t: Any) -> TypeGuard[type[_eval_typing._EvalProxy]]:
142151

143152
__all__ = (
144153
"is_annotated",
145-
"is_classvar",
146154
"is_forward_ref",
147155
"is_generic_alias",
148156
"is_generic_type_alias",
149157
"is_literal",
150158
"is_optional_type",
159+
"is_special_form",
151160
"is_type_alias",
152161
"is_union_type",
153162
)

0 commit comments

Comments
 (0)