Skip to content

Commit 33ff593

Browse files
isapegoivandasch
andauthored
GG-32946 [IGNITE-14432] Implement connection context managers for clients (#33)
(cherry picked from commit f00d70f) Co-authored-by: Ivan Dashchinskiy <ivandasch@gmail.com>
1 parent ec24b29 commit 33ff593

8 files changed

Lines changed: 167 additions & 61 deletions

File tree

pygridgain/aio_client.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@
3333
__all__ = ['AioClient']
3434

3535

36+
class _ConnectionContextManager:
37+
def __init__(self, client, nodes):
38+
self.client = client
39+
self.nodes = nodes
40+
41+
def __await__(self):
42+
return (yield from self.__aenter__().__await__())
43+
44+
async def __aenter__(self):
45+
await self.client._connect(self.nodes)
46+
return self
47+
48+
async def __aexit__(self, exc_type, exc_val, exc_tb):
49+
await self.client.close()
50+
51+
3652
class AioClient(BaseClient):
3753
"""
3854
Asynchronous Client implementation.
@@ -57,14 +73,16 @@ def __init__(self, compact_footer: bool = None, partition_aware: bool = False, *
5773
super().__init__(compact_footer, partition_aware, **kwargs)
5874
self._registry_mux = asyncio.Lock()
5975

60-
async def connect(self, *args):
76+
def connect(self, *args):
6177
"""
6278
Connect to Ignite cluster node(s).
6379
6480
:param args: (optional) host(s) and port(s) to connect to.
6581
"""
6682
nodes = self._process_connect_args(*args)
83+
return _ConnectionContextManager(self, nodes)
6784

85+
async def _connect(self, nodes):
6886
for i, node in enumerate(nodes):
6987
host, port = node
7088
conn = AioConnection(self, host, port, **self._connection_args)

pygridgain/client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,19 @@ def _get_from_registry(self, type_id, schema):
243243
return self._registry[type_id]
244244

245245

246+
class _ConnectionContextManager:
247+
def __init__(self, client, nodes):
248+
self.client = client
249+
self.nodes = nodes
250+
self.client._connect(self.nodes)
251+
252+
def __enter__(self):
253+
return self
254+
255+
def __exit__(self, exc_type, exc_val, exc_tb):
256+
self.client.close()
257+
258+
246259
class Client(BaseClient):
247260
"""
248261
This is a main `pygridgain` class, that is build upon the
@@ -280,7 +293,9 @@ def connect(self, *args):
280293
:param args: (optional) host(s) and port(s) to connect to.
281294
"""
282295
nodes = self._process_connect_args(*args)
296+
return _ConnectionContextManager(self, nodes)
283297

298+
def _connect(self, nodes):
284299
# the following code is quite twisted, because the protocol version
285300
# is initially unknown
286301

tests/affinity/conftest.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,25 @@ def server3():
3939

4040

4141
@pytest.fixture
42-
def client():
42+
def connection_param():
43+
return [('127.0.0.1', 10800 + i) for i in range(1, 4)]
44+
45+
46+
@pytest.fixture
47+
def client(connection_param):
4348
client = Client(partition_aware=True, timeout=CLIENT_SOCKET_TIMEOUT)
4449
try:
45-
client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
50+
client.connect(connection_param)
4651
yield client
4752
finally:
4853
client.close()
4954

5055

5156
@pytest.fixture
52-
async def async_client():
57+
async def async_client(connection_param):
5358
client = AioClient(partition_aware=True)
5459
try:
55-
await client.connect([('127.0.0.1', 10800 + i) for i in range(1, 4)])
60+
await client.connect(connection_param)
5661
yield client
5762
finally:
5863
await client.close()

tests/affinity/test_affinity_bad_servers.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
#
1616
import pytest
1717

18+
from pygridgain import Client, AioClient
1819
from pygridgain.exceptions import ReconnectError, connection_errors
1920
from tests.affinity.conftest import CLIENT_SOCKET_TIMEOUT
20-
from tests.util import start_ignite, kill_process_tree, get_client, get_client_async
21+
from tests.util import start_ignite, kill_process_tree
2122

2223
@pytest.fixture(params=['with-partition-awareness', 'without-partition-awareness'])
2324
def with_partition_awareness(request):
@@ -26,22 +27,24 @@ def with_partition_awareness(request):
2627

2728
def test_client_with_multiple_bad_servers(with_partition_awareness):
2829
with pytest.raises(ReconnectError, match="Can not connect."):
29-
with get_client(partition_aware=with_partition_awareness) as client:
30-
client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
30+
client = Client(partition_aware=with_partition_awareness)
31+
with client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]):
32+
pass
3133

3234

3335
@pytest.mark.asyncio
3436
async def test_client_with_multiple_bad_servers_async(with_partition_awareness):
3537
with pytest.raises(ReconnectError, match="Can not connect."):
36-
async with get_client_async(partition_aware=with_partition_awareness) as client:
37-
await client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)])
38+
client = AioClient(partition_aware=with_partition_awareness)
39+
async with client.connect([("127.0.0.1", 10900), ("127.0.0.1", 10901)]):
40+
pass
3841

3942

4043
def test_client_with_failed_server(request, with_partition_awareness):
4144
srv = start_ignite(idx=4)
4245
try:
43-
with get_client(partition_aware=with_partition_awareness) as client:
44-
client.connect([("127.0.0.1", 10804)])
46+
client = Client(partition_aware=with_partition_awareness)
47+
with client.connect([("127.0.0.1", 10804)]):
4548
cache = client.get_or_create_cache(request.node.name)
4649
cache.put(1, 1)
4750
kill_process_tree(srv.pid)
@@ -61,8 +64,8 @@ def test_client_with_failed_server(request, with_partition_awareness):
6164
async def test_client_with_failed_server_async(request, with_partition_awareness):
6265
srv = start_ignite(idx=4)
6366
try:
64-
async with get_client_async(partition_aware=with_partition_awareness) as client:
65-
await client.connect([("127.0.0.1", 10804)])
67+
client = AioClient(partition_aware=with_partition_awareness)
68+
async with client.connect([("127.0.0.1", 10804)]):
6669
cache = await client.get_or_create_cache(request.node.name)
6770
await cache.put(1, 1)
6871
kill_process_tree(srv.pid)
@@ -81,8 +84,8 @@ async def test_client_with_failed_server_async(request, with_partition_awareness
8184
def test_client_with_recovered_server(request, with_partition_awareness):
8285
srv = start_ignite(idx=4)
8386
try:
84-
with get_client(partition_aware=with_partition_awareness, timeout=CLIENT_SOCKET_TIMEOUT) as client:
85-
client.connect([("127.0.0.1", 10804)])
87+
client = Client(partition_aware=with_partition_awareness, timeout=CLIENT_SOCKET_TIMEOUT)
88+
with client.connect([("127.0.0.1", 10804)]):
8689
cache = client.get_or_create_cache(request.node.name)
8790
cache.put(1, 1)
8891

@@ -107,8 +110,8 @@ def test_client_with_recovered_server(request, with_partition_awareness):
107110
async def test_client_with_recovered_server_async(request, with_partition_awareness):
108111
srv = start_ignite(idx=4)
109112
try:
110-
async with get_client_async(partition_aware=with_partition_awareness) as client:
111-
await client.connect([("127.0.0.1", 10804)])
113+
client = AioClient(partition_aware=with_partition_awareness)
114+
async with client.connect([("127.0.0.1", 10804)]):
112115
cache = await client.get_or_create_cache(request.node.name)
113116
await cache.put(1, 1)
114117

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#
2+
# Copyright 2021 GridGain Systems, Inc. and Contributors.
3+
#
4+
# Licensed under the GridGain Community Edition License (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.gridgain.com/products/software/community-edition/gridgain-community-edition-license
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import pytest
18+
19+
from pygridgain import Client, AioClient
20+
21+
22+
@pytest.fixture
23+
def connection_param():
24+
return [('127.0.0.1', 10800 + i) for i in range(1, 4)]
25+
26+
27+
@pytest.mark.parametrize('partition_aware', ['with_partition_aware', 'wo_partition_aware'])
28+
def test_connection_context(connection_param, partition_aware):
29+
is_partition_aware = partition_aware == 'with_partition_aware'
30+
client = Client(partition_aware=is_partition_aware)
31+
32+
# Check context manager
33+
with client.connect(connection_param):
34+
__check_open(client, is_partition_aware)
35+
__check_closed(client)
36+
37+
# Check standard way
38+
try:
39+
client.connect(connection_param)
40+
__check_open(client, is_partition_aware)
41+
finally:
42+
client.close()
43+
__check_closed(client)
44+
45+
46+
@pytest.mark.asyncio
47+
@pytest.mark.parametrize('partition_aware', ['with_partition_aware', 'wo_partition_aware'])
48+
async def test_connection_context_async(connection_param, partition_aware):
49+
is_partition_aware = partition_aware == 'with_partition_aware'
50+
client = AioClient(partition_aware=is_partition_aware)
51+
52+
# Check async context manager.
53+
async with client.connect(connection_param):
54+
await __check_open(client, is_partition_aware)
55+
__check_closed(client)
56+
57+
# Check standard way.
58+
try:
59+
await client.connect(connection_param)
60+
await __check_open(client, is_partition_aware)
61+
finally:
62+
await client.close()
63+
__check_closed(client)
64+
65+
66+
def __check_open(client, is_partition_aware):
67+
def inner_sync():
68+
if is_partition_aware:
69+
assert client.random_node.alive
70+
else:
71+
all(n.alive for n in client._nodes)
72+
73+
async def inner_async():
74+
if is_partition_aware:
75+
random_node = await client.random_node()
76+
assert random_node.alive
77+
else:
78+
all(n.alive for n in client._nodes)
79+
80+
return inner_sync() if isinstance(client, Client) else inner_async()
81+
82+
83+
def __check_closed(client):
84+
assert all(not n.alive for n in client._nodes)

tests/security/test_auth.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
#
1616
import pytest
1717

18+
from pygridgain import Client, AioClient
1819
from pygridgain.exceptions import AuthenticationError
19-
from tests.util import start_ignite_gen, clear_ignite_work_dir, get_client, get_client_async
20+
from tests.util import start_ignite_gen, clear_ignite_work_dir
2021

2122
DEFAULT_IGNITE_USERNAME = 'ignite'
2223
DEFAULT_IGNITE_PASSWORD = 'ignite'
@@ -41,21 +42,16 @@ def cleanup():
4142

4243
def test_auth_success(with_ssl, ssl_params):
4344
ssl_params['use_ssl'] = with_ssl
44-
45-
with get_client(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params) as client:
46-
client.connect("127.0.0.1", 10801)
47-
45+
client = Client(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params)
46+
with client.connect("127.0.0.1", 10801):
4847
assert all(node.alive for node in client._nodes)
4948

5049

5150
@pytest.mark.asyncio
5251
async def test_auth_success_async(with_ssl, ssl_params):
5352
ssl_params['use_ssl'] = with_ssl
54-
55-
async with get_client_async(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD,
56-
**ssl_params) as client:
57-
await client.connect("127.0.0.1", 10801)
58-
53+
client = AioClient(username=DEFAULT_IGNITE_USERNAME, password=DEFAULT_IGNITE_PASSWORD, **ssl_params)
54+
async with client.connect("127.0.0.1", 10801):
5955
assert all(node.alive for node in client._nodes)
6056

6157

@@ -74,8 +70,9 @@ def test_auth_failed(username, password, with_ssl, ssl_params):
7470
ssl_params['use_ssl'] = with_ssl
7571

7672
with pytest.raises(AuthenticationError):
77-
with get_client(username=username, password=password, **ssl_params) as client:
78-
client.connect("127.0.0.1", 10801)
73+
client = Client(username=username, password=password, **ssl_params)
74+
with client.connect("127.0.0.1", 10801):
75+
pass
7976

8077

8178
@pytest.mark.parametrize(
@@ -87,5 +84,6 @@ async def test_auth_failed_async(username, password, with_ssl, ssl_params):
8784
ssl_params['use_ssl'] = with_ssl
8885

8986
with pytest.raises(AuthenticationError):
90-
async with get_client_async(username=username, password=password, **ssl_params) as client:
91-
await client.connect("127.0.0.1", 10801)
87+
client = AioClient(username=username, password=password, **ssl_params)
88+
async with client.connect("127.0.0.1", 10801):
89+
pass

tests/security/test_ssl.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
#
1616
import pytest
1717

18+
from pygridgain import Client, AioClient
1819
from pygridgain.exceptions import ReconnectError
19-
from tests.util import start_ignite_gen, get_client, get_or_create_cache, get_client_async, get_or_create_cache_async
20+
from tests.util import start_ignite_gen, get_or_create_cache, get_or_create_cache_async
2021

2122

2223
@pytest.fixture(scope='module', autouse=True)
@@ -46,18 +47,16 @@ def __test_connect_ssl(is_async=False, **kwargs):
4647
kwargs['use_ssl'] = True
4748

4849
def inner():
49-
with get_client(**kwargs) as client:
50-
client.connect("127.0.0.1", 10801)
51-
50+
client = Client(**kwargs)
51+
with client.connect("127.0.0.1", 10801):
5252
with get_or_create_cache(client, 'test-cache') as cache:
5353
cache.put(1, 1)
5454

5555
assert cache.get(1) == 1
5656

5757
async def inner_async():
58-
async with get_client_async(**kwargs) as client:
59-
await client.connect("127.0.0.1", 10801)
60-
58+
client = AioClient(**kwargs)
59+
async with client.connect("127.0.0.1", 10801):
6160
async with get_or_create_cache_async(client, 'test-cache') as cache:
6261
await cache.put(1, 1)
6362

@@ -76,13 +75,15 @@ async def inner_async():
7675
@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
7776
def test_connection_error_with_incorrect_config(invalid_ssl_params):
7877
with pytest.raises(ReconnectError):
79-
with get_client(**invalid_ssl_params) as client:
80-
client.connect([("127.0.0.1", 10801)])
78+
client = Client(**invalid_ssl_params)
79+
with client.connect([("127.0.0.1", 10801)]):
80+
pass
8181

8282

8383
@pytest.mark.parametrize('invalid_ssl_params', invalid_params)
8484
@pytest.mark.asyncio
8585
async def test_connection_error_with_incorrect_config_async(invalid_ssl_params):
8686
with pytest.raises(ReconnectError):
87-
async with get_client_async(**invalid_ssl_params) as client:
88-
await client.connect([("127.0.0.1", 10801)])
87+
client = AioClient(**invalid_ssl_params)
88+
async with client.connect([("127.0.0.1", 10801)]):
89+
pass

tests/util.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,6 @@
3636
from async_generator import asynccontextmanager
3737

3838

39-
@contextlib.contextmanager
40-
def get_client(**kwargs):
41-
client = Client(**kwargs)
42-
try:
43-
yield client
44-
finally:
45-
client.close()
46-
47-
48-
@asynccontextmanager
49-
async def get_client_async(**kwargs):
50-
client = AioClient(**kwargs)
51-
try:
52-
yield client
53-
finally:
54-
await client.close()
55-
56-
5739
@contextlib.contextmanager
5840
def get_or_create_cache(client, cache_name):
5941
cache = client.get_or_create_cache(cache_name)

0 commit comments

Comments
 (0)