Skip to content

Commit 3bf1cc1

Browse files
committed
IGNITE-15479 Fix incorrect partial read from socket in sync client - Fixes #50.
1 parent ef8687e commit 3bf1cc1

2 files changed

Lines changed: 57 additions & 16 deletions

File tree

pyignite/connection/connection.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ def _connection_listener(self):
156156
return self.client._event_listeners
157157

158158

159+
DEFAULT_INITIAL_BUF_SIZE = 1024
160+
161+
159162
class Connection(BaseConnection):
160163
"""
161164
This is a `pyignite` class, that represents a connection to Ignite
@@ -348,39 +351,35 @@ def recv(self, flags=None, reconnect=True) -> bytearray:
348351
if flags is not None:
349352
kwargs['flags'] = flags
350353

351-
data = bytearray(1024)
354+
data = bytearray(DEFAULT_INITIAL_BUF_SIZE)
352355
buffer = memoryview(data)
353-
bytes_total_received, bytes_to_receive = 0, 0
356+
total_rcvd, packet_len = 0, 0
354357
while True:
355358
try:
356-
bytes_received = self._socket.recv_into(buffer, len(buffer), **kwargs)
357-
if bytes_received == 0:
359+
bytes_rcvd = self._socket.recv_into(buffer, len(buffer), **kwargs)
360+
if bytes_rcvd == 0:
358361
raise SocketError('Connection broken.')
359-
bytes_total_received += bytes_received
362+
total_rcvd += bytes_rcvd
360363
except connection_errors as e:
361364
self.failed = True
362365
if reconnect:
363366
self._on_connection_lost(e)
364367
self.reconnect()
365368
raise e
366369

367-
if bytes_total_received < 4:
368-
continue
369-
elif bytes_to_receive == 0:
370-
response_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER)
371-
bytes_to_receive = response_len
372-
373-
if response_len + 4 > len(data):
370+
if packet_len == 0 and total_rcvd > 4:
371+
packet_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER, signed=True) + 4
372+
if packet_len > len(data):
374373
buffer.release()
375-
data.extend(bytearray(response_len + 4 - len(data)))
376-
buffer = memoryview(data)[bytes_total_received:]
374+
data.extend(bytearray(packet_len - len(data)))
375+
buffer = memoryview(data)[total_rcvd:]
377376
continue
378377

379-
if bytes_total_received >= bytes_to_receive:
378+
if 0 < packet_len <= total_rcvd:
380379
buffer.release()
381380
break
382381

383-
buffer = buffer[bytes_received:]
382+
buffer = buffer[bytes_rcvd:]
384383

385384
return data
386385

tests/common/test_sync_socket.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one or more
2+
# contributor license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright ownership.
4+
# The ASF licenses this file to You under the Apache License, Version 2.0
5+
# (the "License"); you may not use this file except in compliance with
6+
# the License. You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import secrets
17+
import socket
18+
import unittest.mock as mock
19+
20+
import pytest
21+
22+
from pyignite import Client
23+
from tests.util import get_or_create_cache
24+
25+
old_recv_into = socket.socket.recv_into
26+
27+
28+
def patched_recv_into_factory(buf_len):
29+
def patched_recv_into(self, buffer, nbytes, **kwargs):
30+
return old_recv_into(self, buffer, min(nbytes, buf_len) if buf_len else nbytes, **kwargs)
31+
return patched_recv_into
32+
33+
34+
@pytest.mark.parametrize('buf_len', [0, 1, 4, 16, 32, 64, 128, 256, 512, 1024])
35+
def test_get_large_value(buf_len):
36+
with mock.patch.object(socket.socket, 'recv_into', new=patched_recv_into_factory(buf_len)):
37+
c = Client()
38+
with c.connect("127.0.0.1", 10801):
39+
with get_or_create_cache(c, 'test') as cache:
40+
value = secrets.token_hex((1 << 16) + 1)
41+
cache.put(1, value)
42+
assert value == cache.get(1)

0 commit comments

Comments
 (0)