Skip to content

Commit 175b6bc

Browse files
committed
fix: enhance type casting for response headers handling in Secure class
1 parent 5f719a7 commit 175b6bc

1 file changed

Lines changed: 22 additions & 16 deletions

File tree

secure/secure.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import logging
88
import re
99
from types import MappingProxyType
10-
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias
10+
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast
1111

1212
if TYPE_CHECKING:
13-
from collections.abc import Iterable, Mapping
13+
from collections.abc import Callable, Iterable, Mapping, MutableMapping
1414

1515
from .headers import (
1616
BaseHeader,
@@ -84,8 +84,8 @@ class HeaderSetError(RuntimeError):
8484

8585

8686
class HeadersProtocol(Protocol):
87-
# Intentionally broad: frameworks type headers differently.
88-
headers: object
87+
@property
88+
def headers(self) -> object: ...
8989

9090

9191
class SetHeaderProtocol(Protocol):
@@ -853,19 +853,20 @@ def set_headers(self, response: ResponseProtocol) -> None: # noqa: PLR0912
853853

854854
# Path 2: response.headers...
855855
if hasattr(response, "headers"):
856-
hdrs = response.headers
856+
hdrs = cast("object", response.headers)
857857

858-
# Prefer Werkzeug-style: response.headers.set(name, value)
859858
set_fn = getattr(hdrs, "set", None)
860859
if callable(set_fn):
861-
if inspect.iscoroutinefunction(set_fn):
860+
set_fn_typed = cast("Callable[[str, str], object]", set_fn)
861+
862+
if inspect.iscoroutinefunction(set_fn_typed):
862863
raise RuntimeError(
863864
"Async headers setter detected in sync context. Use 'await set_headers_async(response)'."
864865
)
865866

866867
try:
867868
for name, value in items:
868-
result = set_fn(name, value)
869+
result = set_fn_typed(name, value)
869870
if inspect.isawaitable(result):
870871
raise RuntimeError(
871872
"Async headers setter returned awaitable in sync context. "
@@ -876,17 +877,20 @@ def set_headers(self, response: ResponseProtocol) -> None: # noqa: PLR0912
876877

877878
return
878879

879-
# Fallback: response.headers[name] = value # noqa: ERA001
880880
setitem = getattr(hdrs, "__setitem__", None)
881881
if callable(setitem):
882-
if inspect.iscoroutinefunction(setitem):
882+
setitem_typed = cast("Callable[[str, str], object]", setitem)
883+
884+
if inspect.iscoroutinefunction(setitem_typed):
883885
raise RuntimeError(
884886
"Async headers mapping detected in sync context. Use 'await set_headers_async(response)'."
885887
)
886888

887889
try:
890+
# Use mapping assignment for the common case.
891+
hdrs_map = cast("MutableMapping[str, str]", hdrs)
888892
for name, value in items:
889-
hdrs[name] = value
893+
hdrs_map[name] = value
890894
except (TypeError, ValueError, AttributeError) as e:
891895
raise HeaderSetError(f"Failed to set headers: {e}") from e
892896

@@ -944,27 +948,29 @@ async def set_headers_async(self, response: ResponseProtocol) -> None: # noqa:
944948

945949
# Path 2: response.headers...
946950
if hasattr(response, "headers"):
947-
hdrs = response.headers
951+
hdrs = cast("object", response.headers)
948952

949-
# Prefer Werkzeug-style: response.headers.set(name, value)
950953
set_fn = getattr(hdrs, "set", None)
951954
if callable(set_fn):
955+
set_fn_typed = cast("Callable[[str, str], object]", set_fn)
956+
952957
try:
953958
for name, value in items:
954-
result = set_fn(name, value)
959+
result = set_fn_typed(name, value)
955960
if inspect.isawaitable(result):
956961
await result
957962
except (TypeError, ValueError, AttributeError) as e:
958963
raise HeaderSetError(f"Failed to set headers: {e}") from e
959964

960965
return
961966

962-
# Fallback: response.headers.__setitem__(name, value) # noqa: ERA001
963967
setitem = getattr(hdrs, "__setitem__", None)
964968
if callable(setitem):
969+
setitem_typed = cast("Callable[[str, str], object]", setitem)
970+
965971
try:
966972
for name, value in items:
967-
result = setitem(name, value)
973+
result = setitem_typed(name, value)
968974
if inspect.isawaitable(result):
969975
await result
970976
except (TypeError, ValueError, AttributeError) as e:

0 commit comments

Comments
 (0)