Skip to content

Commit 4a8493b

Browse files
committed
Using SubscriptionTransportBase instead of WebsocketsTransportBase for Phoenix transport
1 parent c369d2a commit 4a8493b

1 file changed

Lines changed: 33 additions & 7 deletions

File tree

gql/transport/phoenix_channel_websockets.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import asyncio
22
import json
33
import logging
4-
from typing import Any, Dict, Optional, Tuple
4+
from typing import Any, Dict, Optional, Tuple, Union
55

66
from graphql import DocumentNode, ExecutionResult, print_ast
7-
from websockets.exceptions import ConnectionClosed
87

8+
from .common.adapters.websockets import WebSocketsAdapter
9+
from .common.base import SubscriptionTransportBase
910
from .exceptions import (
11+
TransportConnectionClosed,
1012
TransportProtocolError,
1113
TransportQueryError,
1214
TransportServerError,
1315
)
14-
from .websockets_base import WebsocketsTransportBase
1516

1617
log = logging.getLogger(__name__)
1718

@@ -24,7 +25,7 @@ def __init__(self, query_id: int) -> None:
2425
self.unsubscribe_id: Optional[int] = None
2526

2627

27-
class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase):
28+
class PhoenixChannelWebsocketsTransport(SubscriptionTransportBase):
2829
"""The PhoenixChannelWebsocketsTransport is an async transport
2930
which allows you to execute queries and subscriptions against an `Absinthe`_
3031
backend using the `Phoenix`_ framework `channels`_.
@@ -36,23 +37,48 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase):
3637

3738
def __init__(
3839
self,
40+
url: str,
41+
*,
3942
channel_name: str = "__absinthe__:control",
4043
heartbeat_interval: float = 30,
41-
*args,
44+
ack_timeout: Optional[Union[int, float]] = 10,
4245
**kwargs,
4346
) -> None:
4447
"""Initialize the transport with the given parameters.
4548
49+
:param url: The server URL.'.
4650
:param channel_name: Channel on the server this transport will join.
4751
The default for Absinthe servers is "__absinthe__:control"
4852
:param heartbeat_interval: Interval in second between each heartbeat messages
4953
sent by the client
54+
:param ack_timeout: Timeout in seconds to wait for the reply message
55+
from the server.
5056
"""
5157
self.channel_name: str = channel_name
5258
self.heartbeat_interval: float = heartbeat_interval
5359
self.heartbeat_task: Optional[asyncio.Future] = None
5460
self.subscriptions: Dict[str, Subscription] = {}
55-
super().__init__(*args, **kwargs)
61+
self.ack_timeout: Optional[Union[int, float]] = ack_timeout
62+
63+
# Instanciate a WebSocketAdapter to indicate the use
64+
# of the websockets dependency for this transport
65+
ws_adapter_args = {}
66+
for ws_arg in ["headers", "ssl", "connect_args"]:
67+
try:
68+
ws_adapter_args[ws_arg] = kwargs.pop(ws_arg)
69+
except KeyError:
70+
pass
71+
72+
self.adapter: WebSocketsAdapter = WebSocketsAdapter(
73+
url=url,
74+
**ws_adapter_args,
75+
)
76+
77+
# Initialize the generic SubscriptionTransportBase parent class
78+
super().__init__(
79+
adapter=self.adapter,
80+
**kwargs,
81+
)
5682

5783
async def _initialize(self) -> None:
5884
"""Join the specified channel and wait for the connection ACK.
@@ -101,7 +127,7 @@ async def heartbeat_coro():
101127
}
102128
)
103129
)
104-
except ConnectionClosed: # pragma: no cover
130+
except TransportConnectionClosed: # pragma: no cover
105131
return
106132

107133
self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())

0 commit comments

Comments
 (0)