-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathrequest_context.py
More file actions
459 lines (400 loc) · 19.6 KB
/
request_context.py
File metadata and controls
459 lines (400 loc) · 19.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
# Copyright 2016-2025. Couchbase, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import math
from asyncio import CancelledError, Task
from types import TracebackType
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Type, Union
from uuid import uuid4
import anyio
from httpx import Response as HttpCoreResponse
from httpx import TimeoutException
from acouchbase_analytics.protocol._core.anyio_utils import AsyncBackend, current_async_library, get_time
from acouchbase_analytics.protocol._core.async_json_stream import AsyncJsonStream
from acouchbase_analytics.protocol._core.net_utils import get_request_ip_async
from couchbase_analytics.common._core import JsonStreamConfig, ParsedResult, ParsedResultType
from couchbase_analytics.common._core.error_context import ErrorContext
from couchbase_analytics.common.backoff_calculator import DefaultBackoffCalculator
from couchbase_analytics.common.errors import AnalyticsError
from couchbase_analytics.common.logging import LogLevel
from couchbase_analytics.common.request import RequestState
from couchbase_analytics.protocol.connection import DEFAULT_TIMEOUTS
from couchbase_analytics.protocol.errors import ErrorMapper
if TYPE_CHECKING:
from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter
from couchbase_analytics.protocol._core.request import QueryRequest
class AsyncRequestContext:
def __init__(
self,
client_adapter: _AsyncClientAdapter,
request: QueryRequest,
stream_config: Optional[JsonStreamConfig] = None,
backend: Optional[AsyncBackend] = None,
) -> None:
self._id = str(uuid4())
self._client_adapter = client_adapter
self._request = request
self._backend = backend or current_async_library()
self._backoff_calc = DefaultBackoffCalculator()
self._error_ctx = ErrorContext(num_attempts=0, method=request.method, statement=request.get_request_statement())
self._request_state = RequestState.NotStarted
self._stream_config = stream_config or JsonStreamConfig()
self._json_stream: AsyncJsonStream
self._stage_completed: Optional[anyio.Event] = None
self._request_error: Optional[Union[BaseException, Exception]] = None
connect_timeout = self._client_adapter.connection_details.get_connect_timeout()
self._connect_deadline = get_time() + connect_timeout
self._cancel_scope_deadline_updated = False
self._shutdown = False
self._request_deadline = math.inf
@property
def cancelled(self) -> bool:
self._check_timed_out()
return self._request_state in [RequestState.Cancelled, RequestState.AsyncCancelledPriorToTimeout]
@property
def error_context(self) -> ErrorContext:
return self._error_ctx
@property
def has_stage_completed(self) -> bool:
return self._stage_completed is not None and self._stage_completed.is_set()
@property
def is_shutdown(self) -> bool:
return self._shutdown
@property
def okay_to_iterate(self) -> bool:
self._check_timed_out()
return RequestState.okay_to_iterate(self._request_state)
@property
def okay_to_stream(self) -> bool:
self._check_timed_out()
return RequestState.okay_to_stream(self._request_state)
@property
def request_error(self) -> Optional[Union[BaseException, Exception]]:
return self._request_error
@property
def request_state(self) -> RequestState:
return self._request_state
@property
def retry_limit_exceeded(self) -> bool:
return self.error_context.num_attempts > self._request.max_retries
@property
def results_or_errors_type(self) -> ParsedResultType:
return self._json_stream.results_or_errors_type
@property
def timed_out(self) -> bool:
self._check_timed_out()
return self._request_state == RequestState.Timeout
def _check_timed_out(self) -> None:
if self._request_state in [RequestState.Timeout, RequestState.Cancelled, RequestState.Error]:
return
if hasattr(self, '_request_deadline') is False:
return
current_time = get_time()
timed_out = current_time >= self._request_deadline
if timed_out:
message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'}
self.log_message('Request has timed out', LogLevel.DEBUG, message_data=message_data)
if self._request_state == RequestState.Cancelled:
self._request_state = RequestState.AsyncCancelledPriorToTimeout
else:
self._request_state = RequestState.Timeout
async def _execute(self, fn: Callable[..., Awaitable[Any]], *args: object) -> None:
await fn(*args)
if self._stage_completed is not None:
self._stage_completed.set()
def _maybe_set_request_error(
self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None
) -> None:
self._check_timed_out()
if exc_val is None:
return
if not RequestState.is_timeout_or_cancelled(self._request_state):
# This handles httpx timeouts
if exc_type is not None and issubclass(exc_type, TimeoutException):
self._request_state = RequestState.Timeout
elif issubclass(type(exc_val), TimeoutException):
self._request_state = RequestState.Timeout
elif isinstance(exc_val, CancelledError):
self._request_state = RequestState.Cancelled
else:
self._request_state = RequestState.Error
self._request_error = exc_val
async def _process_error(
self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False
) -> None:
self._request_state = RequestState.Error
if isinstance(json_data, str):
self._request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_ctx)
elif not isinstance(json_data, list):
self._request_error = AnalyticsError(
'Cannot parse error response; expected JSON array', context=str(self._error_ctx)
)
else:
self._request_error = ErrorMapper.build_error_from_json(json_data, self._error_ctx)
if handle_context_shutdown is True:
await self.reraise_after_shutdown(self._request_error)
raise self._request_error
def _reset_stream(self) -> None:
if hasattr(self, '_json_stream'):
del self._json_stream
self._request_state = RequestState.ResetAndNotStarted
self._stage_completed = None
self._cancel_scope_deadline_updated = False
def _start_next_stage(
self, fn: Callable[..., Awaitable[Any]], *args: object, reset_previous_stage: Optional[bool] = False
) -> None:
if self._stage_completed is not None:
if reset_previous_stage is True:
self._stage_completed = None
else:
raise RuntimeError('Task already running in this context.')
self._stage_completed = anyio.Event()
self._taskgroup.start_soon(self._execute, fn, *args)
async def _trace_handler(self, event_name: str, _: str) -> None:
if event_name == 'connection.connect_tcp.complete':
# after connection is established, we need to update the cancel_scope deadline to match the query_timeout
self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True)
self._cancel_scope_deadline_updated = True
elif self._cancel_scope_deadline_updated is False and event_name.endswith('send_request_headers.started'):
# if the socket is reused, we won't get the connect_tcp.complete event,
# so the deadline at the next closest event
self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True)
self._cancel_scope_deadline_updated = True
def _update_cancel_scope_deadline(self, deadline: float, is_absolute: Optional[bool] = False) -> None:
new_deadline = deadline if is_absolute else get_time() + deadline
current_time = get_time()
if current_time >= new_deadline:
self.log_message(
'Deadline already exceeded, cancelling request',
LogLevel.DEBUG,
message_data={
'current_time': f'{current_time}',
'new_deadline': f'{new_deadline}',
},
)
self._taskgroup.cancel_scope.cancel()
else:
self.log_message(
f'Updating cancel scope deadline: {self._taskgroup.cancel_scope.deadline} -> {new_deadline}',
LogLevel.DEBUG,
)
self._taskgroup.cancel_scope.deadline = new_deadline
async def _wait_for_stage_to_complete(self) -> None:
if self._stage_completed is None:
return
await self._stage_completed.wait()
def calculate_backoff(self) -> float:
return self._backoff_calc.calculate_backoff(self._error_ctx.num_attempts) / 1000
def cancel_request(self, fn: Optional[Callable[..., Awaitable[Any]]] = None, *args: object) -> None:
if fn is not None:
self._taskgroup.start_soon(fn, *args)
if self._request_state == RequestState.Timeout:
return
self._taskgroup.cancel_scope.cancel()
self._request_state = RequestState.Cancelled
def create_response_task(self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object) -> Task[Any]:
if self._backend is None or self._backend.backend_lib != 'asyncio':
raise RuntimeError('Must use the asyncio backend to create a response task.')
if self._backend.loop is None:
raise RuntimeError('Async backend loop is not initialized.')
task_name = f'{self._id}-response-task'
task: Task[Any] = self._backend.loop.create_task(fn(*args), name=task_name)
self._response_task = task
return task
def deserialize_result(self, result: bytes) -> Any:
return self._request.deserializer.deserialize(result)
async def finish_processing_stream(self) -> None:
if not self.has_stage_completed:
await self._wait_for_stage_to_complete()
while not self._json_stream.token_stream_exhausted:
self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True)
await self._wait_for_stage_to_complete()
async def get_result_from_stream(self) -> ParsedResult:
return await self._json_stream.get_result()
async def initialize(self) -> None:
if self._request_state == RequestState.ResetAndNotStarted:
current_time = get_time()
self.log_message(
'Request is a retry, skipping initialization',
LogLevel.DEBUG,
message_data={'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'},
)
return
await self.__aenter__()
self._request_state = RequestState.Started
# we set the request timeout once the context is initialized in order to create the deadline
# closer to when the upstream logic will begin to use the request context
timeouts = self._request.get_request_timeouts() or {}
current_time = get_time()
self._request_deadline = current_time + (timeouts.get('read', None) or DEFAULT_TIMEOUTS['query_timeout'])
message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'}
self.log_message('Request context initialized', LogLevel.DEBUG, message_data=message_data)
def log_message(
self,
message: str,
log_level: LogLevel,
message_data: Optional[Dict[str, str]] = None,
append_ctx: Optional[bool] = True,
) -> None:
if append_ctx is True:
message = f'{message}: ctx={self._id}'
if message_data is not None:
message_data_str = ', '.join(f'{k}={v}' for k, v in message_data.items())
message = f'{message}, {message_data_str}'
self._client_adapter.log_message(message, log_level)
def maybe_continue_to_process_stream(self) -> None:
if not self.has_stage_completed:
return
if self._json_stream.token_stream_exhausted:
return
self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True)
def okay_to_delay_and_retry(self, delay: float) -> bool:
self._check_timed_out()
if self._request_state in [RequestState.Timeout, RequestState.Cancelled]:
return False
current_time = get_time()
delay_time = current_time + delay
will_time_out = self._request_deadline < delay_time
if will_time_out:
self._request_state = RequestState.Timeout
message_data = {
'current_time': f'{current_time}',
'delay_time': f'{delay_time}',
'request_deadline': f'{self._request_deadline}',
}
self.log_message('Request will timeout after delay', LogLevel.DEBUG, message_data=message_data)
return False
elif self.retry_limit_exceeded:
self._request_state = RequestState.Error
message_data = {
'num_attempts': f'{self.error_context.num_attempts}',
'max_retries': f'{self._request.max_retries}',
}
self.log_message('Request has exceeded max retries', LogLevel.DEBUG, message_data=message_data)
return False
else:
self._reset_stream()
return True
async def process_response(
self,
close_handler: Callable[[], Coroutine[Any, Any, None]],
raw_response: Optional[ParsedResult] = None,
handle_context_shutdown: Optional[bool] = False,
) -> Any:
if raw_response is None:
raw_response = await self._json_stream.get_result()
if raw_response is None:
await close_handler()
raise AnalyticsError(
message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx)
)
if raw_response.value is None:
await close_handler()
raise AnalyticsError(
message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx)
)
# we have all the data, close the core response/stream
await close_handler()
try:
json_response = json.loads(raw_response.value)
except json.JSONDecodeError:
await self._process_error(str(raw_response.value), handle_context_shutdown=handle_context_shutdown)
else:
if 'errors' in json_response:
await self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown)
return json_response
async def reraise_after_shutdown(self, err: Exception) -> None:
try:
raise err
except Exception as ex:
await self.shutdown(type(ex), ex, ex.__traceback__)
raise ex from None
async def send_request(self, enable_trace_handling: Optional[bool] = False) -> HttpCoreResponse:
self._error_ctx.update_num_attempts()
ip = await get_request_ip_async(self._request.url.host, self._request.url.port, self.log_message)
if enable_trace_handling is True:
(
self._request.update_url(ip, self._client_adapter.analytics_path).add_trace_to_extensions(
self._trace_handler
)
)
else:
self._request.update_url(ip, self._client_adapter.analytics_path)
self._error_ctx.update_request_context(self._request)
message_data = {
'url': f'{self._request.url.get_formatted_url()}',
'body': f'{self._request.body}',
'request_deadline': f'{self._request_deadline}',
}
self.log_message('HTTP request', LogLevel.DEBUG, message_data=message_data)
response = await self._client_adapter.send_request(self._request)
self._error_ctx.update_response_context(response)
message_data = {
'status_code': f'{response.status_code}',
'last_dispatched_to': f'{self._error_ctx.last_dispatched_to}',
'last_dispatched_from': f'{self._error_ctx.last_dispatched_from}',
'request_deadline': f'{self._request_deadline}',
}
self.log_message('HTTP response', LogLevel.DEBUG, message_data=message_data)
return response
async def shutdown(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self.is_shutdown:
self.log_message('Request context already shutdown', LogLevel.WARNING)
return
if hasattr(self, '_taskgroup'):
await self.__aexit__(exc_type, exc_val, exc_tb)
else:
self._maybe_set_request_error(exc_type, exc_val)
if RequestState.is_okay(self._request_state):
self._request_state = RequestState.Completed
self._shutdown = True
self.log_message('Request context shutdown complete', LogLevel.INFO)
def start_stream(self, core_response: HttpCoreResponse) -> None:
if hasattr(self, '_json_stream'):
self.log_message('JSON stream already exists', LogLevel.WARNING)
return
self._json_stream = AsyncJsonStream(
core_response.aiter_bytes(), stream_config=self._stream_config, logger_handler=self.log_message
)
self._start_next_stage(self._json_stream.start_parsing)
async def wait_for_results_or_errors(self) -> None:
await self._json_stream.has_results_or_errors.wait()
if self._json_stream.results_or_errors_type == ParsedResultType.ROW:
# we move to iterating rows
self._request_state = RequestState.StreamingResults
async def __aenter__(self) -> AsyncRequestContext:
self._taskgroup = anyio.create_task_group()
message_data = {'cancel_scope': f'{id(self._taskgroup.cancel_scope):x}'}
self.log_message('Task group created', LogLevel.DEBUG, message_data=message_data)
await self._taskgroup.__aenter__()
return self
# TODO(PYCO-72): Possible improvement to handling async RequestContext.__aexit__
async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> Optional[bool]:
try:
await self._taskgroup.__aexit__(exc_type, exc_val, exc_tb)
except BaseException:
pass # we handle the error when the context is shutdown (which is what calls __aexit__())
finally:
self._maybe_set_request_error(exc_type, exc_val)
del self._taskgroup
return None # noqa: B012