11import asyncio
22import json
33import logging
4- from typing import Any , Dict , Optional , Tuple
4+ from typing import Any , Dict , Optional , Tuple , Union
55
66from graphql import DocumentNode , ExecutionResult , print_ast
7- from websockets .exceptions import ConnectionClosed
87
8+ from .common .adapters .websockets import WebSocketsAdapter
9+ from .common .base import SubscriptionTransportBase
910from .exceptions import (
11+ TransportConnectionClosed ,
1012 TransportProtocolError ,
1113 TransportQueryError ,
1214 TransportServerError ,
1315)
14- from .websockets_base import WebsocketsTransportBase
1516
1617log = logging .getLogger (__name__ )
1718
@@ -24,7 +25,7 @@ def __init__(self, query_id: int) -> None:
2425 self .unsubscribe_id : Optional [int ] = None
2526
2627
27- class PhoenixChannelWebsocketsTransport (WebsocketsTransportBase ):
28+ class PhoenixChannelWebsocketsTransport (SubscriptionTransportBase ):
2829 """The PhoenixChannelWebsocketsTransport is an async transport
2930 which allows you to execute queries and subscriptions against an `Absinthe`_
3031 backend using the `Phoenix`_ framework `channels`_.
@@ -36,23 +37,48 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase):
3637
3738 def __init__ (
3839 self ,
40+ url : str ,
41+ * ,
3942 channel_name : str = "__absinthe__:control" ,
4043 heartbeat_interval : float = 30 ,
41- * args ,
44+ ack_timeout : Optional [ Union [ int , float ]] = 10 ,
4245 ** kwargs ,
4346 ) -> None :
4447 """Initialize the transport with the given parameters.
4548
49+ :param url: The server URL.'.
4650 :param channel_name: Channel on the server this transport will join.
4751 The default for Absinthe servers is "__absinthe__:control"
4852 :param heartbeat_interval: Interval in second between each heartbeat messages
4953 sent by the client
54+ :param ack_timeout: Timeout in seconds to wait for the reply message
55+ from the server.
5056 """
5157 self .channel_name : str = channel_name
5258 self .heartbeat_interval : float = heartbeat_interval
5359 self .heartbeat_task : Optional [asyncio .Future ] = None
5460 self .subscriptions : Dict [str , Subscription ] = {}
55- super ().__init__ (* args , ** kwargs )
61+ self .ack_timeout : Optional [Union [int , float ]] = ack_timeout
62+
63+ # Instanciate a WebSocketAdapter to indicate the use
64+ # of the websockets dependency for this transport
65+ ws_adapter_args = {}
66+ for ws_arg in ["headers" , "ssl" , "connect_args" ]:
67+ try :
68+ ws_adapter_args [ws_arg ] = kwargs .pop (ws_arg )
69+ except KeyError :
70+ pass
71+
72+ self .adapter : WebSocketsAdapter = WebSocketsAdapter (
73+ url = url ,
74+ ** ws_adapter_args ,
75+ )
76+
77+ # Initialize the generic SubscriptionTransportBase parent class
78+ super ().__init__ (
79+ adapter = self .adapter ,
80+ ** kwargs ,
81+ )
5682
5783 async def _initialize (self ) -> None :
5884 """Join the specified channel and wait for the connection ACK.
@@ -101,7 +127,7 @@ async def heartbeat_coro():
101127 }
102128 )
103129 )
104- except ConnectionClosed : # pragma: no cover
130+ except TransportConnectionClosed : # pragma: no cover
105131 return
106132
107133 self .heartbeat_task = asyncio .ensure_future (heartbeat_coro ())
0 commit comments