155 changes: 146 additions & 9 deletions cassandra/io/asyncioreactor.py
@@ -67,13 +67,56 @@ def finish(self):
'does not implement .finish()')


+class _AsyncioProtocol(asyncio.Protocol):
Collaborator

Since this protocol owns the transport callbacks, I think it would be useful for it to also own flow-control state. Implementing pause_writing() / resume_writing() here would let the writer respect asyncio's high/low watermarks. It may also be worth unblocking any paused writer from connection_lost() so shutdown cannot leave the write coroutine waiting indefinitely.

Collaborator Author

Done — added pause_writing()/resume_writing() with a write_ready Event, and connection_lost() sets it to prevent the writer from hanging on shutdown.
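
For illustration, a minimal, self-contained sketch of the flow-control behaviour discussed in this thread. `FlowControlProtocol` and `demo()` are hypothetical names and not part of the driver; only `pause_writing()`, `resume_writing()` and `connection_lost()` are real `asyncio.Protocol` callbacks. The sketch shows a paused writer being released when the connection goes away.

```python
import asyncio


class FlowControlProtocol(asyncio.Protocol):
    """Hypothetical stand-in for _AsyncioProtocol; only flow control is shown."""

    def __init__(self):
        self.write_ready = asyncio.Event()
        self.write_ready.set()  # writable until the transport says otherwise

    def pause_writing(self):
        # Called by the transport when its buffer passes the high watermark.
        self.write_ready.clear()

    def resume_writing(self):
        # Called by the transport once the buffer drains below the low watermark.
        self.write_ready.set()

    def connection_lost(self, exc):
        # Release any writer parked on write_ready so shutdown cannot hang.
        self.write_ready.set()


async def demo():
    protocol = FlowControlProtocol()
    protocol.pause_writing()
    waiter = asyncio.ensure_future(protocol.write_ready.wait())
    await asyncio.sleep(0)            # let the waiter start and block
    was_blocked = not waiter.done()
    protocol.connection_lost(None)    # simulate the transport going away
    await waiter
    return was_blocked, waiter.done()
```

Running `asyncio.run(demo())` exercises the pause, then the release from `connection_lost()`, without needing a real socket.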

+    """
+    Protocol adapter for asyncio SSL connections. Bridges asyncio's
+    transport/protocol API back to AsyncioConnection's buffer processing.
+    """
+
+    def __init__(self, connection, loop_args=None):
+        self._connection = connection
+        self.transport = None
+        self.write_ready = asyncio.Event(**(loop_args or {}))
+        self.write_ready.set()
+
+    def connection_made(self, transport):
+        self.transport = transport
+
+    def data_received(self, data):
+        conn = self._connection
+        conn._iobuf.write(data)
+        if conn._iobuf.tell():
+            conn.process_io_buffer()
+
+    def pause_writing(self):
+        self.write_ready.clear()
+
+    def resume_writing(self):
+        self.write_ready.set()
+
+    def connection_lost(self, exc):
+        # Unblock any paused writer so shutdown does not hang
+        self.write_ready.set()
+        conn = self._connection
+        if exc:
+            log.debug("Connection %s lost: %s", conn, exc)
+            conn.defunct(exc)
+        else:
+            log.debug("Connection %s closed by server", conn)
+            conn.close()
+
+    def eof_received(self):
+        return False
+
+
 class AsyncioConnection(Connection):
     """
-    An experimental implementation of :class:`.Connection` that uses the
-    ``asyncio`` module in the Python standard library for its event loop.
+    An implementation of :class:`.Connection` that uses the ``asyncio``
+    module in the Python standard library for its event loop.

-    Note that it requires ``asyncio`` features that were only introduced in the
-    3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
+    Supports SSL connections via asyncio's native TLS transport, which
+    avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's
+    low-level socket methods (``sock_sendall``, ``sock_recv``).
     """

     _loop = None
@@ -88,26 +131,106 @@ class AsyncioConnection(Connection):
     def __init__(self, *args, **kwargs):
         Connection.__init__(self, *args, **kwargs)
         self._background_tasks = set()
+        self._transport = None
+        self._using_ssl = bool(self.ssl_context)

         self._connect_socket()
         self._socket.setblocking(0)
         loop_args = dict()
         if sys.version_info[0] == 3 and sys.version_info[1] < 10:
             loop_args['loop'] = self._loop
+        self._protocol = _AsyncioProtocol(self, loop_args) if self._using_ssl else None
+        self._ssl_ready = asyncio.Event(**loop_args) if self._using_ssl else None
         self._write_queue = asyncio.Queue(**loop_args)
         self._write_queue_lock = asyncio.Lock(**loop_args)

         # see initialize_reactor -- loop is running in a separate thread, so we
         # have to use a threadsafe call
-        self._read_watcher = asyncio.run_coroutine_threadsafe(
-            self.handle_read(), loop=self._loop
-        )
+        if self._using_ssl:
Collaborator

Would it make sense to use the transport/protocol path for both TLS and non-TLS connections here? With the current split, asyncio has two different I/O models in the same reactor: TLS uses asyncio.Protocol/transport, while plain CQL stays on sock_recv()/sock_sendall().

I think unifying them in this PR may be cleaner because the work needed for TLS is mostly the same work needed for a correct transport-based reactor: setup, read callbacks, write flow control, close handling, and tests. If those pieces are implemented only for TLS now, we end up with two connection state machines to maintain and a higher chance of future fixes applying to one path but not the other.

A shared create_connection(sock=..., ssl=self.ssl_context or None) path could make lifecycle, reads, writes, and flow control consistent across both modes.
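
A minimal sketch of what such a shared path might look like, assuming the driver keeps connecting the socket itself and then hands it to asyncio. `EchoCollector`, `_echo`, `connect_via_transport` and `demo` are hypothetical names for illustration only, and the local echo server merely stands in for a CQL node. Note that asyncio rejects a non-`None` `server_hostname` when `ssl` is `None`, so the plain-TCP case must leave it unset.

```python
import asyncio
import socket


class EchoCollector(asyncio.Protocol):
    """Hypothetical protocol that records whatever the peer sends back."""

    def __init__(self, done):
        self.received = b""
        self._done = done

    def data_received(self, data):
        self.received += data
        if not self._done.done():
            self._done.set_result(None)


async def _echo(reader, writer):
    # Tiny local echo server standing in for a CQL node.
    writer.write(await reader.read(100))
    await writer.drain()
    writer.close()


async def connect_via_transport(sock, ssl_context=None, server_hostname=None):
    loop = asyncio.get_running_loop()
    done = loop.create_future()
    # Same call for both modes: ssl=None yields a plain TCP transport,
    # ssl=<SSLContext> makes asyncio perform the TLS handshake on `sock`.
    # server_hostname must stay None when ssl_context is None.
    transport, protocol = await loop.create_connection(
        lambda: EchoCollector(done),
        sock=sock,
        ssl=ssl_context,
        server_hostname=server_hostname,
    )
    return transport, protocol, done


async def demo():
    server = await asyncio.start_server(_echo, "127.0.0.1", 0)
    host, port = server.sockets[0].getsockname()[:2]
    # The socket is connected up front, as _connect_socket() does in the PR.
    sock = socket.create_connection((host, port))
    sock.setblocking(False)
    transport, protocol, done = await connect_via_transport(sock)
    transport.write(b"OPTIONS")
    await asyncio.wait_for(done, timeout=5)
    transport.close()
    server.close()
    return protocol.received
```

Only the plain-TCP branch is exercised here; the TLS branch would pass an `ssl.SSLContext` and a hostname through the same function.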

Collaborator Author

Good suggestion for a follow-up, but I'd rather keep it out of scope for this PR. The non-SSL sock_recv/sock_sendall path has been working in production for years — rewriting it to use transport/protocol introduces risk with no immediate benefit. The two paths are isolated, so future fixes to one won't accidentally break the other. Happy to do this as a separate PR if we want to modernize the reactor.

+            # For SSL: set up asyncio transport/protocol, then start writer
+            self._read_watcher = asyncio.run_coroutine_threadsafe(
+                self._setup_ssl_and_run(), loop=self._loop
+            )
+        else:
+            # For non-SSL: use low-level sock_sendall/sock_recv as before
+            self._read_watcher = asyncio.run_coroutine_threadsafe(
+                self.handle_read(), loop=self._loop
+            )
         self._write_watcher = asyncio.run_coroutine_threadsafe(
             self.handle_write(), loop=self._loop
         )
         self._send_options_message()

+    def _connect_socket(self):
Collaborator

This override looks like it accidentally skips the base class sockopts handling. It would be good to preserve Cluster(sockopts=...) for asyncio connections:

if self.sockopts:
    for args in self.sockopts:
        self._socket.setsockopt(*args)

Collaborator Author

Fixed. I added the sockopts loop at the end of _connect_socket(), matching the base class.

+        """
+        Override base class to skip SSL wrapping of the socket.
+        For SSL connections, the plain TCP socket is connected here, and TLS
+        is set up later via asyncio's native SSL transport in _setup_ssl_and_run().
+        """
+        sockerr = None
+        addresses = self._get_socket_addresses()
+        for af, socktype, proto, _, sockaddr in addresses:
+            try:
+                self._socket = self._socket_impl.socket(af, socktype, proto)
+                # Do NOT wrap with ssl_context here -- asyncio will handle TLS
+                self._socket.settimeout(self.connect_timeout)
+                self._initiate_connection(sockaddr)
+                self._socket.settimeout(None)
+
+                local_addr = self._socket.getsockname()
+                log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)
+                sockerr = None
+                break
+            except socket.error as err:
+                if self._socket:
+                    self._socket.close()
+                    self._socket = None
+                sockerr = err
+
+        if sockerr:
+            raise socket.error(
+                sockerr.errno,
+                "Tried connecting to %s. Last error: %s"
+                % ([a[4] for a in addresses], sockerr.strerror or sockerr),
+            )
+
+        if self.sockopts:
+            for args in self.sockopts:
+                self._socket.setsockopt(*args)
+
+    async def _setup_ssl_and_run(self):
+        """
+        Upgrade the plain TCP connection to TLS using asyncio's native SSL
+        transport, then continuously read data via the protocol callbacks.
+        """
+        try:
+            ssl_context = self.ssl_context
+            server_hostname = None
+            if self.ssl_options:
+                server_hostname = self.ssl_options.get("server_hostname", None)
+            if server_hostname is None:
+                # asyncio's create_connection requires server_hostname when
+                # ssl= is set, even if check_hostname is False
+                server_hostname = self.endpoint.address
Collaborator

The old wrap_socket() path only auto-populated server_hostname when hostname verification was enabled, or when the caller passed it explicitly. This unconditional fallback changes that behavior by sending SNI for every TLS connection, which can change certificate selection behind SNI-routing proxies even when check_hostname is off. To preserve the previous semantics here, only default to endpoint.address when the context is doing hostname verification; otherwise pass "".

Suggested change
- server_hostname = self.endpoint.address
+ server_hostname = self.endpoint.address if ssl_context.check_hostname else ""
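
To make the intended semantics concrete, here is a small sketch of the selection rule in isolation. `pick_server_hostname` is a hypothetical helper, not something the PR defines; it assumes `ssl_options` is a dict or `None`, as in the diff. Per the asyncio documentation, an empty `server_hostname` disables hostname matching.

```python
import ssl


def pick_server_hostname(ssl_context, ssl_options, endpoint_address):
    """Hypothetical helper mirroring the suggested change."""
    explicit = (ssl_options or {}).get("server_hostname")
    if explicit is not None:
        return explicit  # the caller asked for a specific name
    # Default to the endpoint only when the context verifies hostnames;
    # otherwise pass "" so asyncio neither matches nor defaults a hostname.
    return endpoint_address if ssl_context.check_hostname else ""
```

With a default context (`check_hostname` enabled) the endpoint address is used; once `check_hostname` is turned off, the helper returns `""` unless the caller supplied `server_hostname` explicitly.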


+            transport, protocol = await self._loop.create_connection(
+                lambda: self._protocol,
+                sock=self._socket,
+                ssl=ssl_context,
+                server_hostname=server_hostname,
+            )
+            self._transport = transport
+
+            if self._check_hostname:
+                self._validate_hostname()
+
+            self._ssl_ready.set()
+        except Exception as exc:
+            log.debug("SSL setup failed for %s: %s", self, exc)
+            self.defunct(exc)
+            # Unblock handle_write so it can observe the defunct state and exit
+            self._ssl_ready.set()
+            return
+
     @classmethod
     def initialize_reactor(cls):
@@ -152,7 +275,10 @@ async def _close(self):
             self._write_watcher.cancel()
         if self._read_watcher:
             self._read_watcher.cancel()
-        if self._socket:
+        if self._transport:
+            self._transport.close()
+            self._transport = None
+        elif self._socket:
             self._loop.remove_writer(self._socket.fileno())
             self._loop.remove_reader(self._socket.fileno())
             self._socket.close()
@@ -196,11 +322,22 @@ async def _push_msg(self, chunks):


     async def handle_write(self):
+        # For SSL connections, wait until the TLS handshake completes
+        if self._ssl_ready:
+            await self._ssl_ready.wait()
+            if self.is_defunct:
+                return
         while True:
             try:
                 next_msg = await self._write_queue.get()
                 if next_msg:
-                    await self._loop.sock_sendall(self._socket, next_msg)
+                    if self._transport:
+                        # SSL: use asyncio transport (handles TLS transparently)
+                        await self._protocol.write_ready.wait()
+                        self._transport.write(next_msg)
Collaborator

Could we account for transport backpressure here? transport.write() queues into asyncio's transport buffer and returns immediately, unlike await sock_sendall(). Under sustained writes, this loop can drain _write_queue faster than the socket can actually send.

One possible shape is to have the protocol expose a write-ready event:

def pause_writing(self):
    self.write_ready.clear()

def resume_writing(self):
    self.write_ready.set()

and then wait before writing:

await self._protocol.write_ready.wait()
self._transport.write(next_msg)

Collaborator Author

Good catch. Implemented pause_writing()/resume_writing() on _AsyncioProtocol with a write_ready Event. The write loop now awaits write_ready.wait() before each transport.write(). Also unblocks from connection_lost() to prevent shutdown hangs.

Comment on lines +336 to +337
Collaborator

connection_lost() now sets write_ready to unblock a paused writer during shutdown. That means this await can resume before _close() has canceled _write_watcher; in that window _transport may already be closing or set to None, and the next line will still call write() on it. Please recheck the state after the wait before writing.

Suggested change
- await self._protocol.write_ready.wait()
- self._transport.write(next_msg)
+ await self._protocol.write_ready.wait()
+ if self.is_closed or self.is_defunct or not self._transport:
+     return
+ self._transport.write(next_msg)
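
The window described above can be reproduced in a small sketch. `FakeConnection` and `demo` are hypothetical stand-ins (a list plays the role of the transport), so this only illustrates the ordering problem, not the driver's actual classes. Without the re-check after the wait, `write_one()` would call `.append()` on `None` and raise.

```python
import asyncio


class FakeConnection:
    """Hypothetical stand-in for AsyncioConnection; names loosely mirror the diff."""

    def __init__(self):
        self.write_ready = asyncio.Event()  # starts cleared, i.e. writer is paused
        self.transport = []                 # a list plays the role of the transport
        self.is_closed = False

    async def write_one(self, msg):
        await self.write_ready.wait()
        # Re-check state: the event may have been set by connection_lost().
        if self.is_closed or self.transport is None:
            return False
        self.transport.append(msg)
        return True

    def connection_lost(self):
        self.is_closed = True
        self.transport = None
        self.write_ready.set()  # unblocks the writer during shutdown


async def demo():
    conn = FakeConnection()
    writer = asyncio.ensure_future(conn.write_one(b"frame"))
    await asyncio.sleep(0)      # the writer is now parked on write_ready
    conn.connection_lost()
    return await writer
```

Here `demo()` resolves to `False`: the writer wakes up, sees the closed state, and exits instead of writing.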

+                    else:
+                        # Non-SSL: use low-level socket API
+                        await self._loop.sock_sendall(self._socket, next_msg)
             except socket.error as err:
                 log.debug("Exception in send for %s: %s", self, err)
                 self.defunct(err)