Skip to content

Commit f00d70f

Browse files
ivandaschisapego
authored andcommitted
IGNITE-14432 Implement connection context managers for clients
This closes #23
1 parent 0bcd77f commit f00d70f

8 files changed

Lines changed: 166 additions & 61 deletions

File tree

pyignite/aio_client.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@
3333
__all__ = ['AioClient']
3434

3535

36+
class _ConnectionContextManager:
37+
def __init__(self, client, nodes):
38+
self.client = client
39+
self.nodes = nodes
40+
41+
def __await__(self):
42+
return (yield from self.__aenter__().__await__())
43+
44+
async def __aenter__(self):
45+
await self.client._connect(self.nodes)
46+
return self
47+
48+
async def __aexit__(self, exc_type, exc_val, exc_tb):
49+
await self.client.close()
50+
51+
3652
class AioClient(BaseClient):
3753
"""
3854
Asynchronous Client implementation.
@@ -57,14 +73,16 @@ def __init__(self, compact_footer: bool = None, partition_aware: bool = False, *
5773
super().__init__(compact_footer, partition_aware, **kwargs)
5874
self._registry_mux = asyncio.Lock()
5975

60-
async def connect(self, *args):
76+
def connect(self, *args):
6177
"""
6278
Connect to Ignite cluster node(s).
6379
6480
:param args: (optional) host(s) and port(s) to connect to.
6581
"""
6682
nodes = self._process_connect_args(*args)
83+
return _ConnectionContextManager(self, nodes)
6784

85+
async def _connect(self, nodes):
6886
for i, node in enumerate(nodes):
6987
host, port = node
7088
conn = AioConnection(self, host, port, **self._connection_args)

pyignite/client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,19 @@ def _get_from_registry(self, type_id, schema):
243243
return self._registry[type_id]
244244

245245

246+
class _ConnectionContextManager:
247+
def __init__(self, client, nodes):
248+
self.client = client
249+
self.nodes = nodes
250+
self.client._connect(self.nodes)
251+
252+
def __enter__(self):
253+
return self
254+
255+
def __exit__(self, exc_type, exc_val, exc_tb):
256+
self.client.close()
257+
258+
246259
class Client(BaseClient):
247260
"""
248261
This is a main `pyignite` class, that is build upon the
@@ -280,7 +293,9 @@ def connect(self, *args):
280293
:param args: (optional) host(s) and port(s) to connect to.
281294
"""
282295
nodes = self._process_connect_args(*args)
296+
return _ConnectionContextManager(self, nodes)
283297

298+
def _connect(self, nodes):
284299
# the following code is quite twisted, because the protocol version
285300
# is initially unknown
286301

tests/affinity/conftest.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,25 @@ def server3():
3939

4040

4141
@pytest.fixture
42-
def client():
42+
def connection_param():
43+
return [('127.0.0.1', 10800 + i) for i in range(1, 4)]
44+
45+
46+
@pytest.fixture
47+
def client(connection_param):
4348
client = Client(partition_aware=True, timeout=CLIENT_SOCKET_TIMEOUT)
4449
try:
45-
client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
50+
client.connect(connection_param)
4651
yield client
4752
finally:
4853
client.close()
4954

5055

5156
@pytest.fixture
52-
async def async_client():
57+
async def async_client(connection_param):
5358
client = AioClient(partition_aware=True)
5459
try:
55-
await client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
60+
await client.connect(connection_param)
5661
yield client
5762
finally:
5863
await client.close()

tests/affinity/test_affinity_bad_servers.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
import pytest
1717

18+
from pyignite import Client, AioClient
1819
from pyignite.exceptions import ReconnectError, connection_errors
1920
from tests.affinity.conftest import CLIENT_SOCKET_TIMEOUT
20-
from tests.util import start_ignite, kill_process_tree, get_client, get_client_async
21+
from tests.util import start_ignite, kill_process_tree
2122

2223

2324
@pytest.fixture(params=['with-partition-awareness', 'without-partition-awareness'])
@@ -27,22 +28,24 @@ def with_partition_awareness(request):
2728

2829
def test_client_with_multiple_bad_servers(with_partition_awareness):
2930
with pytest.raises(ReconnectError, match="Can not connect."):
30-
with get_client(partition_aware=with_partition_awareness) as client:
31-
client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
31+
client = Client(partition_aware=with_partition_awareness)
32+
with client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]):
33+
pass
3234

3335

3436
@pytest.mark.asyncio
3537
async def test_client_with_multiple_bad_servers_async(with_partition_awareness):
3638
with pytest.raises(ReconnectError, match="Can not connect."):
37-
async with get_client_async(partition_aware=with_partition_awareness) as client:
38-
await client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
39+
client = AioClient(partition_aware=with_partition_awareness)
40+
async with client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]):
41+
pass
3942

4043

4144
def test_client_with_failed_server(request, with_partition_awareness):
4245
srv = start_ignite(idx=4)
4346
try:
44-
with get_client(partition_aware=with_partition_awareness) as client:
45-
client.connect([("127.0.0.1", 10804)])
47+
client = Client(partition_aware=with_partition_awareness)
48+
with client.connect([("127.0.0.1", 10804)]):
4649
cache = client.get_or_create_cache(request.node.name)
4750
cache.put(1, 1)
4851
kill_process_tree(srv.pid)
@@ -62,8 +65,8 @@ def test_client_with_failed_server(request, with_partition_awareness):
6265
async def test_client_with_failed_server_async(request, with_partition_awareness):
6366
srv = start_ignite(idx=4)
6467
try:
65-
async with get_client_async(partition_aware=with_partition_awareness) as client:
66-
await client.connect([("127.0.0.1", 10804)])
68+
client = AioClient(partition_aware=with_partition_awareness)
69+
async with client.connect([("127.0.0.1", 10804)]):
6770
cache = await client.get_or_create_cache(request.node.name)
6871
await cache.put(1, 1)
6972
kill_process_tree(srv.pid)
@@ -82,8 +85,8 @@ async def test_client_with_failed_server_async(request, with_partition_awareness
8285
def test_client_with_recovered_server(request, with_partition_awareness):
8386
srv = start_ignite(idx=4)
8487
try:
85-
with get_client(partition_aware=with_partition_awareness, timeout=CLIENT_SOCKET_TIMEOUT) as client:
86-
client.connect([("127.0.0.1", 10804)])
88+
client = Client(partition_aware=with_partition_awareness, timeout=CLIENT_SOCKET_TIMEOUT)
89+
with client.connect([("127.0.0.1", 10804)]):
8790
cache = client.get_or_create_cache(request.node.name)
8891
cache.put(1, 1)
8992

@@ -108,8 +111,8 @@ def test_client_with_recovered_server(request, with_partition_awareness):
108111
async def test_client_with_recovered_server_async(request, with_partition_awareness):
109112
srv = start_ignite(idx=4)
110113
try:
111-
async with get_client_async(partition_aware=with_partition_awareness) as client:
112-
await client.connect([("127.0.0.1", 10804)])
114+
client = AioClient(partition_aware=with_partition_awareness)
115+
async with client.connect([("127.0.0.1", 10804)]):
113116
cache = await client.get_or_create_cache(request.node.name)
114117
await cache.put(1, 1)
115118

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one or more
2+
# contributor license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright ownership.
4+
# The ASF licenses this file to You under the Apache License, Version 2.0
5+
# (the "License"); you may not use this file except in compliance with
6+
# the License. You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import pytest
17+
18+
from pyignite import Client, AioClient
19+
20+
21+
@pytest.fixture
22+
def connection_param():
23+
return [('127.0.0.1', 10800 + i) for i in range(1, 4)]
24+
25+
26+
@pytest.mark.parametrize('partition_aware', ['with_partition_aware', 'wo_partition_aware'])
27+
def test_connection_context(connection_param, partition_aware):
28+
is_partition_aware = partition_aware == 'with_partition_aware'
29+
client = Client(partition_aware=is_partition_aware)
30+
31+
# Check context manager
32+
with client.connect(connection_param):
33+
__check_open(client, is_partition_aware)
34+
__check_closed(client)
35+
36+
# Check standard way
37+
try:
38+
client.connect(connection_param)
39+
__check_open(client, is_partition_aware)
40+
finally:
41+
client.close()
42+
__check_closed(client)
43+
44+
45+
@pytest.mark.asyncio
46+
@pytest.mark.parametrize('partition_aware', ['with_partition_aware', 'wo_partition_aware'])
47+
async def test_connection_context_async(connection_param, partition_aware):
48+
is_partition_aware = partition_aware == 'with_partition_aware'
49+
client = AioClient(partition_aware=is_partition_aware)
50+
51+
# Check async context manager.
52+
async with client.connect(connection_param):
53+
await __check_open(client, is_partition_aware)
54+
__check_closed(client)
55+
56+
# Check standard way.
57+
try:
58+
await client.connect(connection_param)
59+
await __check_open(client, is_partition_aware)
60+
finally:
61+
await client.close()
62+
__check_closed(client)
63+
64+
65+
def __check_open(client, is_partition_aware):
66+
def inner_sync():
67+
if is_partition_aware:
68+
assert client.random_node.alive
69+
else:
70+
all(n.alive for n in client._nodes)
71+
72+
async def inner_async():
73+
if is_partition_aware:
74+
random_node = await client.random_node()
75+
assert random_node.alive
76+
else:
77+
all(n.alive for n in client._nodes)
78+
79+
return inner_sync() if isinstance(client, Client) else inner_async()
80+
81+
82+
def __check_closed(client):
83+
assert all(not n.alive for n in client._nodes)

tests/security/test_auth.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
# limitations under the License.
1515
import pytest
1616

17+
from pyignite import Client, AioClient
1718
from pyignite.exceptions import AuthenticationError
18-
from tests.util import start_ignite_gen, clear_ignite_work_dir, get_client, get_client_async
19+
from tests.util import start_ignite_gen, clear_ignite_work_dir
1920

2021
DEFAULT_IGNITE_USERNAME = 'ignite'
2122
DEFAULT_IGNITE_PASSWORD = 'ignite'
@@ -40,21 +41,16 @@ def cleanup():
4041

4142
def test_auth_success(with_ssl, ssl_params):
4243
ssl_params['use_ssl'] = with_ssl
43-
44-
with get_client(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params) as client:
45-
client.connect("127.0.0.1", 10801)
46-
44+
client = Client(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params)
45+
with client.connect("127.0.0.1", 10801):
4746
assert all(node.alive for node in client._nodes)
4847

4948

5049
@pytest.mark.asyncio
5150
async def test_auth_success_async(with_ssl, ssl_params):
5251
ssl_params['use_ssl'] = with_ssl
53-
54-
async with get_client_async(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD,
55-
**ssl_params) as client:
56-
await client.connect("127.0.0.1", 10801)
57-
52+
client = AioClient(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params)
53+
async with client.connect("127.0.0.1", 10801):
5854
assert all(node.alive for node in client._nodes)
5955

6056

@@ -73,8 +69,9 @@ def test_auth_failed(username, password, with_ssl, ssl_params):
7369
ssl_params['use_ssl'] = with_ssl
7470

7571
with pytest.raises(AuthenticationError):
76-
with get_client(username=username, password=password, **ssl_params) as client:
77-
client.connect("127.0.0.1", 10801)
72+
client = Client(username=username, password=password, **ssl_params)
73+
with client.connect("127.0.0.1", 10801):
74+
pass
7875

7976

8077
@pytest.mark.parametrize(
@@ -86,5 +83,6 @@ async def test_auth_failed_async(username, password, with_ssl, ssl_params):
8683
ssl_params['use_ssl'] = with_ssl
8784

8885
with pytest.raises(AuthenticationError):
89-
async with get_client_async(username=username, password=password, **ssl_params) as client:
90-
await client.connect("127.0.0.1", 10801)
86+
client = AioClient(username=username, password=password, **ssl_params)
87+
async with client.connect("127.0.0.1", 10801):
88+
pass

tests/security/test_ssl.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
# limitations under the License.
1515
import pytest
1616

17+
from pyignite import Client, AioClient
1718
from pyignite.exceptions import ReconnectError
18-
from tests.util import start_ignite_gen, get_client, get_or_create_cache, get_client_async, get_or_create_cache_async
19+
from tests.util import start_ignite_gen, get_or_create_cache, get_or_create_cache_async
1920

2021

2122
@pytest.fixture(scope='module', autouse=True)
@@ -45,18 +46,16 @@ def __test_connect_ssl(is_async=False, **kwargs):
4546
kwargs['use_ssl'] = True
4647

4748
def inner():
48-
with get_client(**kwargs) as client:
49-
client.connect("127.0.0.1", 10801)
50-
49+
client = Client(**kwargs)
50+
with client.connect("127.0.0.1", 10801):
5151
with get_or_create_cache(client, 'test-cache') as cache:
5252
cache.put(1, 1)
5353

5454
assert cache.get(1) == 1
5555

5656
async def inner_async():
57-
async with get_client_async(**kwargs) as client:
58-
await client.connect("127.0.0.1", 10801)
59-
57+
client = AioClient(**kwargs)
58+
async with client.connect("127.0.0.1", 10801):
6059
async with get_or_create_cache_async(client, 'test-cache') as cache:
6160
await cache.put(1, 1)
6261

@@ -75,13 +74,15 @@ async def inner_async():
7574
@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
7675
def test_connection_error_with_incorrect_config(invalid_ssl_params):
7776
with pytest.raises(ReconnectError):
78-
with get_client(**invalid_ssl_params) as client:
79-
client.connect([("127.0.0.1", 10801)])
77+
client = Client(**invalid_ssl_params)
78+
with client.connect([("127.0.0.1", 10801)]):
79+
pass
8080

8181

8282
@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
8383
@pytest.mark.asyncio
8484
async def test_connection_error_with_incorrect_config_async(invalid_ssl_params):
8585
with pytest.raises(ReconnectError):
86-
async with get_client_async(**invalid_ssl_params) as client:
87-
await client.connect([("127.0.0.1", 10801)])
86+
client = AioClient(**invalid_ssl_params)
87+
async with client.connect([("127.0.0.1", 10801)]):
88+
pass

tests/util.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,6 @@
3434
from async_generator import asynccontextmanager
3535

3636

37-
@contextlib.contextmanager
38-
def get_client(**kwargs):
39-
client = Client(**kwargs)
40-
try:
41-
yield client
42-
finally:
43-
client.close()
44-
45-
46-
@asynccontextmanager
47-
async def get_client_async(**kwargs):
48-
client = AioClient(**kwargs)
49-
try:
50-
yield client
51-
finally:
52-
await client.close()
53-
54-
5537
@contextlib.contextmanager
5638
def get_or_create_cache(client, cache_name):
5739
cache = client.get_or_create_cache(cache_name)

0 commit comments

Comments
 (0)