Skip to content

Commit 976dc10

Browse files
committed
Updates to include further testing
Changes
=======
* PYCO-58: Timeout tests
* PYCO-57: Cancellation tests
1 parent 102adbc commit 976dc10

11 files changed

Lines changed: 546 additions & 102 deletions

File tree

acouchbase_analytics/protocol/core/_request_context.py

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
List,
1111
Optional,
1212
Type,
13+
Union,
1314
TYPE_CHECKING)
1415
from uuid import uuid4
1516

1617
import anyio
17-
from httpx import Response as HttpCoreResponse
18+
from httpx import (Response as HttpCoreResponse,
19+
TimeoutException)
20+
1821

1922
from acouchbase_analytics.protocol.core._anyio_utils import (AsyncBackend,
2023
current_async_library,
@@ -45,8 +48,10 @@ def __init__(self,
4548
# self._response_task: Optional[Task] = None
4649
self._request_state = StreamingState.NotStarted
4750
self._stage_completed: Optional[anyio.Event] = None
48-
self._request_error: Optional[Exception] = None
49-
self._connect_timeout = self._client_adapter.connection_details.get_connect_timeout()
51+
self._request_error: Optional[Union[BaseException, Exception]] = None
52+
connect_timeout = self._client_adapter.connection_details.get_connect_timeout()
53+
self._connect_deadline = get_time() + connect_timeout
54+
self._cancel_scope_deadline_updated = False
5055

5156
@property
5257
def deserializer(self) -> Deserializer:
@@ -61,14 +66,16 @@ def has_stage_completed(self) -> bool:
6166

6267
@property
6368
def okay_to_iterate(self) -> bool:
69+
self._check_cancelled_or_timed_out()
6470
return StreamingState.okay_to_iterate(self._request_state)
6571

6672
@property
6773
def okay_to_stream(self) -> bool:
74+
self._check_cancelled_or_timed_out()
6875
return StreamingState.okay_to_stream(self._request_state)
6976

7077
@property
71-
def request_error(self) -> Optional[Exception]:
78+
def request_error(self) -> Optional[Union[BaseException, Exception]]:
7279
return self._request_error
7380

7481
@property
@@ -81,45 +88,98 @@ def request_state(self, state: StreamingState) -> None:
8188
raise TypeError('request_state must be an instance of StreamingState')
8289
self._request_state = state
8390

84-
# @property
85-
# def stage_completed(self) -> Optional[anyio.Event]:
86-
# return self._stage_completed
87-
8891
@property
8992
def timed_out(self) -> bool:
93+
self._check_cancelled_or_timed_out()
9094
return self._request_state == StreamingState.Timeout
9195

9296
@property
9397
def cancelled(self) -> bool:
94-
return self._request_state == StreamingState.Cancelled
98+
self._check_cancelled_or_timed_out()
99+
return self._request_state in [StreamingState.Cancelled, StreamingState.AsyncCancelledPriorToTimeout]
100+
101+
def _check_cancelled_or_timed_out(self) -> None:
102+
if self._request_state in [StreamingState.Timeout, StreamingState.Cancelled, StreamingState.Error]:
103+
return
104+
105+
if hasattr(self, '_request_deadline') is False:
106+
return
107+
108+
current_time = get_time()
109+
if self._cancel_scope_deadline_updated is False:
110+
timed_out = current_time >= self._connect_deadline
111+
else:
112+
timed_out = current_time >= self._request_deadline
113+
114+
if timed_out:
115+
if self._request_state == StreamingState.Cancelled:
116+
self._request_state = StreamingState.AsyncCancelledPriorToTimeout
117+
else:
118+
self._request_state = StreamingState.Timeout
95119

96120
async def _execute(self, fn: Callable[..., Awaitable[Any]], *args: object) -> None:
97121
await fn(*args)
98122
if self._stage_completed is not None:
99123
self._stage_completed.set()
100124

125+
def _maybe_set_request_error(self,
126+
exc_type: Optional[Type[BaseException]]=None,
127+
exc_val: Optional[BaseException]=None) -> None:
128+
self._check_cancelled_or_timed_out()
129+
# TODO: Do either of these conditions need to be checked? Does _check_cancelled_or_timed_out() already handle this?
130+
# if self._taskgroup.cancel_scope.cancelled_caught and get_time() >= self._taskgroup.cancel_scope.deadline:
131+
# if isinstance(exc_val, CancelledError):
132+
if exc_val is None:
133+
return
134+
if not StreamingState.is_timeout_or_cancelled(self._request_state):
135+
# This handles httpx timeouts
136+
if exc_type is not None and issubclass(exc_type, TimeoutException):
137+
self._request_state = StreamingState.Timeout
138+
elif issubclass(type(exc_val), TimeoutException):
139+
self._request_state = StreamingState.Timeout
140+
else:
141+
self._request_state = StreamingState.Error
142+
self._request_error = exc_val
143+
144+
101145
async def _trace_handler(self, event_name: str, _: str) -> None:
102146
if event_name == 'connection.connect_tcp.complete':
103147
# after connection is established, we need to update the cancel_scope deadline to match the query_timeout
104148
self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True)
149+
self._cancel_scope_deadline_updated = True
150+
elif self._cancel_scope_deadline_updated is False and event_name.endswith('send_request_headers.started'):
151+
# if the socket is reused, we won't get the connect_tcp.complete event, so update the deadline at the next closest event
152+
self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True)
153+
self._cancel_scope_deadline_updated = True
105154

106155
def _update_cancel_scope_deadline(self, deadline: float, is_absolute: Optional[bool]=False) -> None:
107156
# TODO: confirm scenario of get_time() < self._taskgroup.cancel_scope.deadline is handled by anyio
108-
109157
new_deadline = deadline if is_absolute else get_time() + deadline
158+
# TODO: Useful debug log message
159+
# print(f'Updating cancel scope deadline: {self._taskgroup.cancel_scope.deadline} -> {new_deadline}')
110160
if get_time() >= new_deadline:
111161
self._taskgroup.cancel_scope.cancel()
112162
else:
113163
self._taskgroup.cancel_scope.deadline = new_deadline
114164

165+
def cancel_request(self,
166+
fn: Optional[Callable[..., Awaitable[Any]]]=None,
167+
*args: object) -> None:
168+
if fn is not None:
169+
self._taskgroup.start_soon(fn, *args)
170+
if self._request_state == StreamingState.Timeout:
171+
return
172+
self._taskgroup.cancel_scope.cancel()
173+
self._request_state = StreamingState.Cancelled
174+
115175
async def initialize(self) -> None:
116176
await self.__aenter__()
117177
self._request_state = StreamingState.Started
118178
# we set the request timeout once the context is initialized in order to create the deadline
119179
# closer to when the upstream logic will begin to use the request context
120180
timeouts = self._request.get_request_timeouts() or {}
121181
self._request_deadline = get_time() + (timeouts.get('read', None) or DEFAULT_TIMEOUTS['query_timeout'])
122-
self._update_cancel_scope_deadline(self._connect_timeout)
182+
self._update_cancel_scope_deadline(self._connect_deadline, is_absolute=True)
123183

124184
async def send_request(self, enable_trace_handling: Optional[bool]=False) -> HttpCoreResponse:
125185
ip = await get_request_ip_async(self._request.host, self._request.port, self._request.previous_ips)
@@ -133,6 +193,7 @@ async def send_request(self, enable_trace_handling: Optional[bool]=False) -> Htt
133193
.update_previous_ips(ip))
134194
else:
135195
self._request.update_url(ip, self._client_adapter.analytics_path).update_previous_ips(ip)
196+
# TODO: add logging; provide request details (to/from, deadlines, etc.)
136197
response = await self._client_adapter.send_request(self._request)
137198
self._request.set_client_server_addrs(response)
138199
if response.status_code == 401:
@@ -150,10 +211,8 @@ async def shutdown(self,
150211
exc_tb: Optional[TracebackType]=None) -> None:
151212
if hasattr(self, '_taskgroup'):
152213
await self.__aexit__(exc_type, exc_val, exc_tb)
153-
elif isinstance(exc_val, CancelledError):
154-
self._request_state = StreamingState.Cancelled
155-
elif exc_val is not None:
156-
self._request_state = StreamingState.Error
214+
else:
215+
self._maybe_set_request_error()
157216

158217
if StreamingState.is_okay(self._request_state):
159218
self._request_state = StreamingState.Completed
@@ -168,7 +227,7 @@ def create_response_task(self, fn: Callable[..., Coroutine[Any, Any, Any]], *arg
168227
task: Task[Any] = self._backend.loop.create_task(fn(*args), name=task_name)
169228
# TODO: I don't think this callback is necessary...need to add more tests to confirm
170229
def task_done(t: Task[Any]) -> None:
171-
print(f'Task ({t.get_name()}) done: {t.done()}, cancelled: {t.cancelled()}')
230+
print(f'Task done callback task=({t.get_name()}); done: {t.done()}, cancelled: {t.cancelled()}')
172231

173232
task.add_done_callback(task_done)
174233
self._response_task = task
@@ -181,9 +240,6 @@ def start_next_stage(self,
181240
fn: Callable[..., Awaitable[Any]],
182241
*args: object,
183242
reset_previous_stage: Optional[bool]=False) -> None:
184-
# if reset_previous_stage is True:
185-
# if self._stage_completed is not None:
186-
# self._stage_completed = None
187243
if self._stage_completed is not None:
188244
if reset_previous_stage is True:
189245
self._stage_completed = None
@@ -222,12 +278,7 @@ async def __aexit__(self,
222278
except BaseException as ex:
223279
pass # we handle the error when the context is shutdown (which is what calls __aexit__())
224280
finally:
225-
if self._taskgroup.cancel_scope.cancelled_caught and get_time() >= self._taskgroup.cancel_scope.deadline:
226-
self._request_state = StreamingState.Timeout
227-
elif isinstance(exc_val, CancelledError):
228-
self._request_state = StreamingState.Cancelled
229-
elif exc_val is not None:
230-
self._request_state = StreamingState.Error
281+
self._maybe_set_request_error()
231282
del self._taskgroup
232283
# TODO: should we suppress here (e.g., return True)
233284
return None

acouchbase_analytics/protocol/streaming.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,13 @@ async def wrapped_fn(self: AsyncHttpStreamingResponse) -> None:
7272
raise ex
7373
except BaseException as ex:
7474
await self._request_context.shutdown(type(ex), ex, ex.__traceback__)
75-
if self._request_context.request_error is not None:
76-
raise self._request_context.request_error from None
7775
if self._request_context.timed_out:
78-
raise TimeoutError(message='Request timed out.') from None
76+
raise TimeoutError(cause=self._request_context.request_error,
77+
message='Request timed out.') from None
7978
if self._request_context.cancelled:
8079
raise CancelledError('Request was cancelled.') from None
80+
if self._request_context.request_error is not None:
81+
raise self._request_context.request_error from None
8182
raise InternalSDKError(ex) from None
8283
finally:
8384
if not StreamingState.is_okay(self._request_context.request_state):
@@ -106,6 +107,15 @@ async def _finish_processing_stream(self) -> None:
106107
self._request_context.start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True)
107108
await self._request_context.wait_for_stage_to_complete()
108109

110+
async def _handle_iteration_abort(self) -> None:
111+
await self.close()
112+
if self._request_context.cancelled:
113+
raise StopAsyncIteration
114+
elif self._request_context.timed_out:
115+
raise TimeoutError(message='Request timeout.')
116+
else:
117+
raise StopAsyncIteration
118+
109119
def _maybe_continue_to_process_stream(self) -> None:
110120
if not self._request_context.has_stage_completed:
111121
return
@@ -124,12 +134,13 @@ async def _process_response(self, raw_response: Optional[ParsedResult]=None) ->
124134
if raw_response.value is None:
125135
raise AnalyticsError(message='Received unexpected empty result from JsonStream.')
126136

137+
# we have all the data, close the core response/stream
138+
await self.close()
139+
127140
json_response = json.loads(raw_response.value)
128141
if 'errors' in json_response:
129142
await self._request_context.process_error(json_response['errors'])
130143
self.set_metadata(json_data=json_response)
131-
# we have all the data, close the core response/stream
132-
await self.close()
133144

134145
def _start(self) -> None:
135146
"""
@@ -142,6 +153,9 @@ def _start(self) -> None:
142153
self._json_stream = AsyncJsonStream(self._core_response.aiter_bytes(), stream_config=self._stream_config)
143154
self._request_context.start_next_stage(self._json_stream.start_parsing)
144155

156+
async def _close_in_background(self) -> None:
157+
await self.close()
158+
145159
async def close(self) -> None:
146160
"""
147161
**INTERNAL**
@@ -150,14 +164,26 @@ async def close(self) -> None:
150164
await self._core_response.aclose()
151165
del self._core_response
152166

153-
async def cancel(self) -> None:
167+
def cancel(self) -> None:
168+
"""
169+
**INTERNAL**
170+
"""
171+
self._request_context.cancel_request(self._close_in_background)
172+
173+
async def cancel_async(self) -> None:
154174
"""
155175
**INTERNAL**
156176
"""
157177
await self.close()
178+
self._request_context.cancel_request()
179+
await self._request_context.shutdown()
158180

159181
def get_metadata(self) -> QueryMetadata:
160182
if self._metadata is None:
183+
if self._request_context.cancelled:
184+
raise CancelledError('Request was cancelled.')
185+
elif self._request_context.timed_out:
186+
raise TimeoutError(message='Request timeout.')
161187
raise RuntimeError('Query metadata is only available after all rows have been iterated.')
162188
return self._metadata
163189

@@ -175,8 +201,10 @@ async def get_next_row(self) -> Any:
175201
"""
176202
**INTERNAL**
177203
"""
178-
if self._core_response is None or not self._request_context.okay_to_iterate:
179-
raise StopAsyncIteration
204+
if not (hasattr(self, '_core_response')
205+
and self._core_response is not None
206+
and self._request_context.okay_to_iterate):
207+
await self._handle_iteration_abort()
180208

181209
self._maybe_continue_to_process_stream()
182210
raw_response = await self._json_stream.get_result()
@@ -211,6 +239,10 @@ async def send_request(self) -> None:
211239
await self._finish_processing_stream()
212240
await self._process_response()
213241

242+
async def shutdown(self) -> None:
243+
await self.close()
244+
await self._request_context.shutdown()
245+
214246
SendRequestFunc: TypeAlias = Callable[[AsyncHttpStreamingResponse], Coroutine[Any, Any, None]]
215247
# Although SendRequestFunc is the same type as WrappedSendRequestFunc, keep them separate for clarity and to indicate
216248
# WrappedSendRequestFunc is a decorator

acouchbase_analytics/scope.pyi

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
# limitations under the License.
1515

1616
import sys
17-
from asyncio import Future
18-
from typing import overload
17+
from typing import Awaitable, overload
1918

2019
if sys.version_info < (3, 11):
2120
from typing_extensions import Unpack
@@ -33,36 +32,36 @@ class AsyncScope:
3332
def name(self) -> str: ...
3433

3534
@overload
36-
def execute_query(self, statement: str) -> Future[AsyncQueryResult]: ...
35+
def execute_query(self, statement: str) -> Awaitable[AsyncQueryResult]: ...
3736

3837
@overload
39-
def execute_query(self, statement: str, options: QueryOptions) -> Future[AsyncQueryResult]: ...
38+
def execute_query(self, statement: str, options: QueryOptions) -> Awaitable[AsyncQueryResult]: ...
4039

4140
@overload
42-
def execute_query(self, statement: str, **kwargs: Unpack[QueryOptionsKwargs]) -> Future[AsyncQueryResult]: ...
41+
def execute_query(self, statement: str, **kwargs: Unpack[QueryOptionsKwargs]) -> Awaitable[AsyncQueryResult]: ...
4342

4443
@overload
4544
def execute_query(self,
4645
statement: str,
4746
options: QueryOptions,
48-
**kwargs: Unpack[QueryOptionsKwargs]) -> Future[AsyncQueryResult]: ...
47+
**kwargs: Unpack[QueryOptionsKwargs]) -> Awaitable[AsyncQueryResult]: ...
4948

5049
@overload
5150
def execute_query(self,
5251
statement: str,
5352
options: QueryOptions,
5453
*args: str,
55-
**kwargs: Unpack[QueryOptionsKwargs]) -> Future[AsyncQueryResult]: ...
54+
**kwargs: Unpack[QueryOptionsKwargs]) -> Awaitable[AsyncQueryResult]: ...
5655

5756
@overload
5857
def execute_query(self,
5958
statement: str,
6059
options: QueryOptions,
6160
*args: str,
62-
**kwargs: str) -> Future[AsyncQueryResult]: ...
61+
**kwargs: str) -> Awaitable[AsyncQueryResult]: ...
6362

6463
@overload
6564
def execute_query(self,
6665
statement: str,
6766
*args: str,
68-
**kwargs: str) -> Future[AsyncQueryResult]: ...
67+
**kwargs: str) -> Awaitable[AsyncQueryResult]: ...

0 commit comments

Comments
 (0)