1010 List ,
1111 Optional ,
1212 Type ,
13+ Union ,
1314 TYPE_CHECKING )
1415from uuid import uuid4
1516
1617import anyio
17- from httpx import Response as HttpCoreResponse
18+ from httpx import (Response as HttpCoreResponse ,
19+ TimeoutException )
20+
1821
1922from acouchbase_analytics .protocol .core ._anyio_utils import (AsyncBackend ,
2023 current_async_library ,
@@ -45,8 +48,10 @@ def __init__(self,
4548 # self._response_task: Optional[Task] = None
4649 self ._request_state = StreamingState .NotStarted
4750 self ._stage_completed : Optional [anyio .Event ] = None
48- self ._request_error : Optional [Exception ] = None
49- self ._connect_timeout = self ._client_adapter .connection_details .get_connect_timeout ()
51+ self ._request_error : Optional [Union [BaseException , Exception ]] = None
52+ connect_timeout = self ._client_adapter .connection_details .get_connect_timeout ()
53+ self ._connect_deadline = get_time () + connect_timeout
54+ self ._cancel_scope_deadline_updated = False
5055
5156 @property
5257 def deserializer (self ) -> Deserializer :
@@ -61,14 +66,16 @@ def has_stage_completed(self) -> bool:
6166
6267 @property
6368 def okay_to_iterate (self ) -> bool :
69+ self ._check_cancelled_or_timed_out ()
6470 return StreamingState .okay_to_iterate (self ._request_state )
6571
6672 @property
6773 def okay_to_stream (self ) -> bool :
74+ self ._check_cancelled_or_timed_out ()
6875 return StreamingState .okay_to_stream (self ._request_state )
6976
7077 @property
71- def request_error (self ) -> Optional [Exception ]:
78+ def request_error (self ) -> Optional [Union [ BaseException , Exception ] ]:
7279 return self ._request_error
7380
7481 @property
@@ -81,45 +88,98 @@ def request_state(self, state: StreamingState) -> None:
8188 raise TypeError ('request_state must be an instance of StreamingState' )
8289 self ._request_state = state
8390
84- # @property
85- # def stage_completed(self) -> Optional[anyio.Event]:
86- # return self._stage_completed
87-
8891 @property
8992 def timed_out (self ) -> bool :
93+ self ._check_cancelled_or_timed_out ()
9094 return self ._request_state == StreamingState .Timeout
9195
9296 @property
9397 def cancelled (self ) -> bool :
94- return self ._request_state == StreamingState .Cancelled
98+ self ._check_cancelled_or_timed_out ()
99+ return self ._request_state in [StreamingState .Cancelled , StreamingState .AsyncCancelledPriorToTimeout ]
100+
101+ def _check_cancelled_or_timed_out (self ) -> None :
102+ if self ._request_state in [StreamingState .Timeout , StreamingState .Cancelled , StreamingState .Error ]:
103+ return
104+
105+ if hasattr (self , '_request_deadline' ) is False :
106+ return
107+
108+ current_time = get_time ()
109+ if self ._cancel_scope_deadline_updated is False :
110+ timed_out = current_time >= self ._connect_deadline
111+ else :
112+ timed_out = current_time >= self ._request_deadline
113+
114+ if timed_out :
115+ if self ._request_state == StreamingState .Cancelled :
116+ self ._request_state = StreamingState .AsyncCancelledPriorToTimeout
117+ else :
118+ self ._request_state = StreamingState .Timeout
95119
96120 async def _execute (self , fn : Callable [..., Awaitable [Any ]], * args : object ) -> None :
97121 await fn (* args )
98122 if self ._stage_completed is not None :
99123 self ._stage_completed .set ()
100124
125+ def _maybe_set_request_error (self ,
126+ exc_type : Optional [Type [BaseException ]]= None ,
127+ exc_val : Optional [BaseException ]= None ) -> None :
128+ self ._check_cancelled_or_timed_out ()
129+ # TODO: Do either of these conditions need to be checked? Does _check_cancelled_or_timed_out() already handle this
130+ # if self._taskgroup.cancel_scope.cancelled_caught and get_time() >= self._taskgroup.cancel_scope.deadline:
131+ # if isinstance(exc_val, CancelledError):
132+ if exc_val is None :
133+ return
134+ if not StreamingState .is_timeout_or_cancelled (self ._request_state ):
135+ # This handles httpx timeouts
136+ if exc_type is not None and issubclass (exc_type , TimeoutException ):
137+ self ._request_state = StreamingState .Timeout
138+ elif issubclass (type (exc_val ), TimeoutException ):
139+ self ._request_state = StreamingState .Timeout
140+ else :
141+ self ._request_state = StreamingState .Error
142+ self ._request_error = exc_val
143+
144+
101145 async def _trace_handler (self , event_name : str , _ : str ) -> None :
102146 if event_name == 'connection.connect_tcp.complete' :
103147 # after connection is established, we need to update the cancel_scope deadline to match the query_timeout
104148 self ._update_cancel_scope_deadline (self ._request_deadline , is_absolute = True )
149+ self ._cancel_scope_deadline_updated = True
150+ elif self ._cancel_scope_deadline_updated is False and event_name .endswith ('send_request_headers.started' ):
151+ # if the socket is reused, we won't get the connect_tcp.complete event, so the deadline at the next closest event
152+ self ._update_cancel_scope_deadline (self ._request_deadline , is_absolute = True )
153+ self ._cancel_scope_deadline_updated = True
105154
106155 def _update_cancel_scope_deadline (self , deadline : float , is_absolute : Optional [bool ]= False ) -> None :
107156 # TODO: confirm scenario of get_time() < self._taskgroup.cancel_scope.deadline is handled by anyio
108-
109157 new_deadline = deadline if is_absolute else get_time () + deadline
158+ # TODO: Useful debug log message
159+ # print(f'Updating cancel scope deadline: {self._taskgroup.cancel_scope.deadline} -> {new_deadline}')
110160 if get_time () >= new_deadline :
111161 self ._taskgroup .cancel_scope .cancel ()
112162 else :
113163 self ._taskgroup .cancel_scope .deadline = new_deadline
114164
165+ def cancel_request (self ,
166+ fn : Optional [Callable [..., Awaitable [Any ]]]= None ,
167+ * args : object ) -> None :
168+ if fn is not None :
169+ self ._taskgroup .start_soon (fn , * args )
170+ if self ._request_state == StreamingState .Timeout :
171+ return
172+ self ._taskgroup .cancel_scope .cancel ()
173+ self ._request_state = StreamingState .Cancelled
174+
115175 async def initialize (self ) -> None :
116176 await self .__aenter__ ()
117177 self ._request_state = StreamingState .Started
118178 # we set the request timeout once the context is initialized in order to create the deadline
119179 # closer to when the upstream logic will begin to use the request context
120180 timeouts = self ._request .get_request_timeouts () or {}
121181 self ._request_deadline = get_time () + (timeouts .get ('read' , None ) or DEFAULT_TIMEOUTS ['query_timeout' ])
122- self ._update_cancel_scope_deadline (self ._connect_timeout )
182+ self ._update_cancel_scope_deadline (self ._connect_deadline , is_absolute = True )
123183
124184 async def send_request (self , enable_trace_handling : Optional [bool ]= False ) -> HttpCoreResponse :
125185 ip = await get_request_ip_async (self ._request .host , self ._request .port , self ._request .previous_ips )
@@ -133,6 +193,7 @@ async def send_request(self, enable_trace_handling: Optional[bool]=False) -> Htt
133193 .update_previous_ips (ip ))
134194 else :
135195 self ._request .update_url (ip , self ._client_adapter .analytics_path ).update_previous_ips (ip )
196+ # TODO: add logging; provide request details (to/from, deadlines, etc.)
136197 response = await self ._client_adapter .send_request (self ._request )
137198 self ._request .set_client_server_addrs (response )
138199 if response .status_code == 401 :
@@ -150,10 +211,8 @@ async def shutdown(self,
150211 exc_tb : Optional [TracebackType ]= None ) -> None :
151212 if hasattr (self , '_taskgroup' ):
152213 await self .__aexit__ (exc_type , exc_val , exc_tb )
153- elif isinstance (exc_val , CancelledError ):
154- self ._request_state = StreamingState .Cancelled
155- elif exc_val is not None :
156- self ._request_state = StreamingState .Error
214+ else :
215+ self ._maybe_set_request_error ()
157216
158217 if StreamingState .is_okay (self ._request_state ):
159218 self ._request_state = StreamingState .Completed
@@ -168,7 +227,7 @@ def create_response_task(self, fn: Callable[..., Coroutine[Any, Any, Any]], *arg
168227 task : Task [Any ] = self ._backend .loop .create_task (fn (* args ), name = task_name )
169228 # TODO: I don't think this callback is necessary...need to add more tests to confirm
170229 def task_done (t : Task [Any ]) -> None :
171- print (f'Task ({ t .get_name ()} ) done: { t .done ()} , cancelled: { t .cancelled ()} ' )
230+ print (f'Task done callback task= ({ t .get_name ()} ); done: { t .done ()} , cancelled: { t .cancelled ()} ' )
172231
173232 task .add_done_callback (task_done )
174233 self ._response_task = task
@@ -181,9 +240,6 @@ def start_next_stage(self,
181240 fn : Callable [..., Awaitable [Any ]],
182241 * args : object ,
183242 reset_previous_stage : Optional [bool ]= False ) -> None :
184- # if reset_previous_stage is True:
185- # if self._stage_completed is not None:
186- # self._stage_completed = None
187243 if self ._stage_completed is not None :
188244 if reset_previous_stage is True :
189245 self ._stage_completed = None
@@ -222,12 +278,7 @@ async def __aexit__(self,
222278 except BaseException as ex :
223279 pass # we handle the error when the context is shutdown (which is what calls __aexit__())
224280 finally :
225- if self ._taskgroup .cancel_scope .cancelled_caught and get_time () >= self ._taskgroup .cancel_scope .deadline :
226- self ._request_state = StreamingState .Timeout
227- elif isinstance (exc_val , CancelledError ):
228- self ._request_state = StreamingState .Cancelled
229- elif exc_val is not None :
230- self ._request_state = StreamingState .Error
281+ self ._maybe_set_request_error ()
231282 del self ._taskgroup
232283 # TODO: should we suppress here (e.g., return True)
233284 return None
0 commit comments