55from typing import (Any ,
66 Awaitable ,
77 Callable ,
8+ Coroutine ,
89 Dict ,
910 List ,
1011 Optional ,
2021 get_time )
2122from couchbase_analytics .common .core .net_utils import get_request_ip_async
2223from couchbase_analytics .common .deserializer import Deserializer
23- from couchbase_analytics .common .errors import AnalyticsError
24+ from couchbase_analytics .common .errors import AnalyticsError , InvalidCredentialError
2425from couchbase_analytics .common .streaming import StreamingState
2526from couchbase_analytics .protocol .connection import DEFAULT_TIMEOUTS
2627from couchbase_analytics .protocol .errors import ErrorMapper
@@ -41,7 +42,7 @@ def __init__(self,
4142 self ._client_adapter = client_adapter
4243 self ._request = request
4344 self ._backend = backend or current_async_library ()
44- self ._response_task : Optional [Task ] = None
45+ # self._response_task: Optional[Task] = None
4546 self ._request_state = StreamingState .NotStarted
4647 self ._stage_completed : Optional [anyio .Event ] = None
4748 self ._request_error : Optional [Exception ] = None
@@ -80,9 +81,9 @@ def request_state(self, state: StreamingState) -> None:
8081 raise TypeError ('request_state must be an instance of StreamingState' )
8182 self ._request_state = state
8283
83- @property
84- def stage_completed (self ) -> anyio .Event :
85- return self ._stage_completed
84+ # @property
85+ # def stage_completed(self) -> Optional[ anyio.Event] :
86+ # return self._stage_completed
8687
8788 @property
8889 def timed_out (self ) -> bool :
@@ -94,9 +95,10 @@ def cancelled(self) -> bool:
9495
9596 async def _execute (self , fn : Callable [..., Awaitable [Any ]], * args : object ) -> None :
9697 await fn (* args )
97- self ._stage_completed .set ()
98+ if self ._stage_completed is not None :
99+ self ._stage_completed .set ()
98100
99- async def _trace_handler (self , event_name , _ ) -> None :
101+ async def _trace_handler (self , event_name : str , _ : str ) -> None :
100102 if event_name == 'connection.connect_tcp.complete' :
101103 # after connection is established, we need to update the cancel_scope deadline to match the query_timeout
102104 self ._update_cancel_scope_deadline (self ._request_deadline , is_absolute = True )
@@ -115,24 +117,31 @@ async def initialize(self) -> None:
115117 self ._request_state = StreamingState .Started
116118 # we set the request timeout once the context is initialized in order to create the deadline
117119 # closer to when the upstream logic will begin to use the request context
118- timeouts = self ._request .get_request_timeouts ()
119- self ._request_deadline = get_time () + timeouts .get ('read' , DEFAULT_TIMEOUTS ['query_timeout' ])
120+ timeouts = self ._request .get_request_timeouts () or {}
121+ self ._request_deadline = get_time () + ( timeouts .get ('read' , None ) or DEFAULT_TIMEOUTS ['query_timeout' ])
120122 self ._update_cancel_scope_deadline (self ._connect_timeout )
121123
122124 async def send_request (self , enable_trace_handling : Optional [bool ]= False ) -> HttpCoreResponse :
123125 ip = await get_request_ip_async (self ._request .host , self ._request .port , self ._request .previous_ips )
124126 if ip is None :
125127 attempted_ips = ', ' .join (self ._request .previous_ips or [])
126- raise AnalyticsError (f'Connect failure. Attempted to connect to resolved IPs: { attempted_ips } .' )
128+ raise AnalyticsError (message = f'Connect failure. Attempted to connect to resolved IPs: { attempted_ips } .' )
127129
128130 if enable_trace_handling is True :
129131 (self ._request .update_url (ip , self ._client_adapter .analytics_path )
130- .update_extensions ({ 'trace' : self ._trace_handler } )
132+ .add_trace_to_extensions ( self ._trace_handler )
131133 .update_previous_ips (ip ))
132134 else :
133135 self ._request .update_url (ip , self ._client_adapter .analytics_path ).update_previous_ips (ip )
134136 response = await self ._client_adapter .send_request (self ._request )
135137 self ._request .set_client_server_addrs (response )
138+ if response .status_code == 401 :
139+ context = {
140+ 'client_addr' : self ._request .client_addr ,
141+ 'server_addr' : self ._request .server_addr ,
142+ 'http_status' : response .status_code ,
143+ }
144+ raise InvalidCredentialError (str (context ))
136145 return response
137146
138147 async def shutdown (self ,
@@ -149,14 +158,16 @@ async def shutdown(self,
149158 if StreamingState .is_okay (self ._request_state ):
150159 self ._request_state = StreamingState .Completed
151160
152- def create_response_task (self , fn : Callable [..., Awaitable [Any ]], * args : object ) -> Task :
161+ def create_response_task (self , fn : Callable [..., Coroutine [Any , Any , Any ]], * args : object ) -> Task [ Any ] :
153162 if self ._backend is None or self ._backend .backend_lib != 'asyncio' :
154163 raise RuntimeError ('Must use the asyncio backend to create a response task.' )
164+ if self ._backend .loop is None :
165+ raise RuntimeError ('Async backend loop is not initialized.' )
155166 task_name = f'{ self ._id } -response-task'
156167 print (f'Creating response task: { task_name } ' )
157- task = self ._backend .loop .create_task (fn (* args ), name = task_name )
168+ task : Task [ Any ] = self ._backend .loop .create_task (fn (* args ), name = task_name )
158169 # TODO: I don't think this callback is necessary...need to add more tests to confirm
159- def task_done (t : Task ) -> None :
170+ def task_done (t : Task [ Any ] ) -> None :
160171 print (f'Task ({ t .get_name ()} ) done: { t .done ()} , cancelled: { t .cancelled ()} ' )
161172
162173 task .add_done_callback (task_done )
@@ -170,15 +181,23 @@ def start_next_stage(self,
170181 fn : Callable [..., Awaitable [Any ]],
171182 * args : object ,
172183 reset_previous_stage : Optional [bool ]= False ) -> None :
173- if reset_previous_stage is True :
174- if self ._stage_completed is not None :
184+ # if reset_previous_stage is True:
185+ # if self._stage_completed is not None:
186+ # self._stage_completed = None
187+ if self ._stage_completed is not None :
188+ if reset_previous_stage is True :
175189 self ._stage_completed = None
176- elif self . _stage_completed is not None :
177- raise RuntimeError ('Task already running in this context.' )
190+ else :
191+ raise RuntimeError ('Task already running in this context.' )
178192
179193 self ._stage_completed = anyio .Event ()
180194 self ._taskgroup .start_soon (self ._execute , fn , * args )
181195
196+ async def wait_for_stage_to_complete (self ) -> None :
197+ if self ._stage_completed is None :
198+ return
199+ await self ._stage_completed .wait ()
200+
182201 async def process_error (self , json_data : List [Dict [str , Any ]]) -> None :
183202 self ._request_state = StreamingState .Error
184203 if not isinstance (json_data , list ):
@@ -209,4 +228,6 @@ async def __aexit__(self,
209228 self ._request_state = StreamingState .Cancelled
210229 elif exc_val is not None :
211230 self ._request_state = StreamingState .Error
212- del self ._taskgroup
231+ del self ._taskgroup
232+ # TODO: should we suppress here (e.g., return True)
233+ return None
0 commit comments