1313# limitations under the License.
1414
1515import asyncio
16+ import contextvars
17+ import inspect
1618import sys
1719import traceback
1820from pathlib import Path
19- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
21+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union , cast
2022
2123from greenlet import greenlet
2224from pyee import AsyncIOEventEmitter , EventEmitter
2325
26+ import playwright
2427from playwright ._impl ._helper import ParsedMessagePayload , parse_error
2528from playwright ._impl ._transport import Transport
2629
@@ -36,10 +39,21 @@ def __init__(self, connection: "Connection", guid: str) -> None:
3639 self ._object : Optional [ChannelOwner ] = None
3740
3841 async def send (self , method : str , params : Dict = None ) -> Any :
39- return await self .inner_send (method , params , False )
42+ return await self ._connection .wrap_api_call (
43+ lambda : self .inner_send (method , params , False )
44+ )
4045
4146 async def send_return_as_dict (self , method : str , params : Dict = None ) -> Any :
42- return await self .inner_send (method , params , True )
47+ return await self ._connection .wrap_api_call (
48+ lambda : self .inner_send (method , params , True )
49+ )
50+
51+ def send_no_reply (self , method : str , params : Dict = None ) -> None :
52+ self ._connection .wrap_api_call (
53+ lambda : self ._connection ._send_message_to_server (
54+ self ._guid , method , {} if params is None else params
55+ )
56+ )
4357
4458 async def inner_send (
4559 self , method : str , params : Optional [Dict ], return_as_dict : bool
@@ -74,11 +88,6 @@ async def inner_send(
7488 key = next (iter (result ))
7589 return result [key ]
7690
77- def send_no_reply (self , method : str , params : Dict = None ) -> None :
78- if params is None :
79- params = {}
80- self ._connection ._send_message_to_server (self ._guid , method , params )
81-
8291
8392class ChannelOwner (AsyncIOEventEmitter ):
8493 def __init__ (
@@ -122,7 +131,7 @@ def _dispose(self) -> None:
122131
123132class ProtocolCallback :
124133 def __init__ (self , loop : asyncio .AbstractEventLoop ) -> None :
125- self .stack_trace : traceback .StackSummary = traceback . StackSummary ()
134+ self .stack_trace : traceback .StackSummary
126135 self .future = loop .create_future ()
127136 # The outer task can get cancelled by the user, this forwards the cancellation to the inner task.
128137 current_task = asyncio .current_task ()
@@ -181,6 +190,9 @@ def __init__(
181190 self ._error : Optional [BaseException ] = None
182191 self .is_remote = False
183192 self ._init_task : Optional [asyncio .Task ] = None
193+ self ._api_zone : contextvars .ContextVar [Optional [Dict ]] = contextvars .ContextVar (
194+ "ApiZone" , default = None
195+ )
184196
185197 def mark_as_remote (self ) -> None :
186198 self .is_remote = True
@@ -230,22 +242,17 @@ def _send_message_to_server(
230242 id = self ._last_id
231243 callback = ProtocolCallback (self ._loop )
232244 task = asyncio .current_task (self ._loop )
233- stack_trace : Optional [traceback .StackSummary ] = getattr (
234- task , "__pw_stack_trace__" , None
245+ callback .stack_trace = cast (
246+ traceback .StackSummary ,
247+ getattr (task , "__pw_stack_trace__" , traceback .extract_stack ()),
235248 )
236- callback .stack_trace = stack_trace or traceback .extract_stack ()
237249 self ._callbacks [id ] = callback
238- metadata = {"stack" : serialize_call_stack (callback .stack_trace )}
239- api_name = getattr (task , "__pw_api_name__" , None )
240- if api_name :
241- metadata ["apiName" ] = api_name
242-
243250 message = {
244251 "id" : id ,
245252 "guid" : guid ,
246253 "method" : method ,
247254 "params" : self ._replace_channels_with_guids (params ),
248- "metadata" : metadata ,
255+ "metadata" : self . _api_zone . get () ,
249256 }
250257 self ._transport .send (message )
251258 self ._callbacks [id ] = callback
@@ -337,6 +344,27 @@ def _replace_guids_with_channels(self, payload: Any) -> Any:
337344 return result
338345 return payload
339346
347+ def wrap_api_call (self , cb : Callable [[], Any ], is_internal : bool = False ) -> Any :
348+ if self ._api_zone .get ():
349+ return cb ()
350+ task = asyncio .current_task (self ._loop )
351+ st : List [inspect .FrameInfo ] = getattr (task , "__pw_stack__" , inspect .stack ())
352+ metadata = _extract_metadata_from_stack (st , is_internal )
353+ if metadata :
354+ self ._api_zone .set (metadata )
355+ result = cb ()
356+
357+ async def _ () -> None :
358+ try :
359+ return await result
360+ finally :
361+ self ._api_zone .set (None )
362+
363+ if asyncio .iscoroutine (result ):
364+ return _ ()
365+ self ._api_zone .set (None )
366+ return result
367+
340368
341369def from_channel (channel : Channel ) -> Any :
342370 return channel ._object
@@ -346,13 +374,40 @@ def from_nullable_channel(channel: Optional[Channel]) -> Optional[Any]:
346374 return channel ._object if channel else None
347375
348376
349- def serialize_call_stack (stack_trace : traceback .StackSummary ) -> List [Dict ]:
377+ def _extract_metadata_from_stack (
378+ st : List [inspect .FrameInfo ], is_internal : bool
379+ ) -> Optional [Dict ]:
380+ playwright_module_path = str (Path (playwright .__file__ ).parents [0 ])
381+ last_internal_api_name = ""
382+ api_name = ""
350383 stack : List [Dict ] = []
351- for frame in stack_trace :
352- if "_generated.py" in frame .filename :
353- break
354- stack .append (
355- {"file" : frame .filename , "line" : frame .lineno , "function" : frame .name }
356- )
357- stack .reverse ()
358- return stack
384+ for frame in st :
385+ is_playwright_internal = frame .filename .startswith (playwright_module_path )
386+
387+ method_name = ""
388+ if "self" in frame [0 ].f_locals :
389+ method_name = frame [0 ].f_locals ["self" ].__class__ .__name__ + "."
390+ method_name += frame [0 ].f_code .co_name
391+
392+ if not is_playwright_internal :
393+ stack .append (
394+ {
395+ "file" : frame .filename ,
396+ "line" : frame .lineno ,
397+ "function" : method_name ,
398+ }
399+ )
400+ if is_playwright_internal :
401+ last_internal_api_name = method_name
402+ elif last_internal_api_name :
403+ api_name = last_internal_api_name
404+ last_internal_api_name = ""
405+ if not api_name :
406+ api_name = last_internal_api_name
407+ if api_name :
408+ return {
409+ "apiName" : api_name ,
410+ "stack" : stack ,
411+ "isInternal" : is_internal ,
412+ }
413+ return None
0 commit comments