Skip to content

Commit 4b30553

Browse files
udimyifanzou
authored and committed
[BEAM-7860] Python Datastore: fix key sort order
This is a regression from the v1 client.
1 parent 7c3e5ae commit 4b30553

3 files changed

Lines changed: 110 additions & 8 deletions

File tree

sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,16 @@ def check_get_splits(self, query, num_splits, num_entities, batch_size):
171171
batch_size: the number of entities returned by fake datastore in one req.
172172
"""
173173

174-
# Test for both random long ids and string ids.
175-
id_or_name = [True, False]
174+
# Test for random long ids, string ids, and a mix of both.
175+
id_or_name = [True, False, None]
176176

177177
for id_type in id_or_name:
178-
entities = fake_datastore.create_entities(num_entities, id_type)
178+
if id_type is None:
179+
entities = fake_datastore.create_entities(num_entities, False)
180+
entities.extend(fake_datastore.create_entities(num_entities, True))
181+
num_entities *= 2
182+
else:
183+
entities = fake_datastore.create_entities(num_entities, id_type)
179184
mock_datastore = MagicMock()
180185
# Assign a fake run_query method as a side_effect to the mock.
181186
mock_datastore.run_query.side_effect = \

sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from builtins import range
2727
from builtins import round
2828

29+
from past.builtins import long
30+
from past.builtins import unicode
31+
2932
from apache_beam.io.gcp.datastore.v1new import types
3033

3134
__all__ = ['QuerySplitterError', 'SplitNotPossibleError', 'get_splits']
@@ -123,10 +126,59 @@ def _create_scatter_query(query, num_splits):
123126
return scatter_query
124127

125128

129+
class IdOrName(object):
  """Represents an ID or name of a Datastore key.

  Implements sort ordering: by ID, then by name, keys with IDs before those
  with names.

  After construction exactly one of ``id`` and ``name`` is not None; the
  original value is also kept in ``id_or_name``.
  """

  def __init__(self, id_or_name):
    """Classifies ``id_or_name`` as either a numeric ID or a string name.

    Args:
      id_or_name: an integer ID or a string name.

    Raises:
      TypeError: if ``id_or_name`` is neither a string nor an integer.
    """
    self.id_or_name = id_or_name
    if isinstance(id_or_name, (str, unicode)):
      self.id = None
      self.name = id_or_name
    elif isinstance(id_or_name, (int, long)):
      self.id = id_or_name
      self.name = None
    else:
      # Wrap the value in a 1-tuple so that tuple arguments are formatted
      # instead of being unpacked as %-format arguments.
      raise TypeError('Unexpected type of id_or_name: %s' % (id_or_name,))

  def __lt__(self, other):
    """Returns True if this element sorts strictly before ``other``."""
    if not isinstance(other, IdOrName):
      # Let Python try the reflected operation / raise its own TypeError.
      return NotImplemented

    if self.id is not None:
      if other.id is None:
        # IDs always sort before names.
        return True
      return self.id < other.id

    if other.id is not None:
      # self is a name, other is an ID.
      return False

    return self.name < other.name

  def __eq__(self, other):
    """Two elements are equal when both their ID and name match."""
    if not isinstance(other, IdOrName):
      return NotImplemented
    return self.id == other.id and self.name == other.name

  def __ne__(self, other):
    """Inverse of ``__eq__`` (Python 2 does not derive it automatically)."""
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __hash__(self):
    # Hash over the same fields that __eq__ compares.
    return hash((self.id, self.name))
168+
169+
126170
def client_key_sort_key(client_key):
  """Key function for sorting lists of ``google.cloud.datastore.key.Key``.

  Args:
    client_key: object exposing ``project``, ``namespace`` and ``flat_path``.

  Returns:
    A list starting with the project and the namespace (``''`` when the
    namespace is falsy), followed by the elements of ``flat_path`` where every
    id_or_name element is wrapped in ``IdOrName`` and every kind is kept as is.
  """
  sort_key = [client_key.project, client_key.namespace or '']
  # A key path is made up of (kind, id_or_name) pairs. The last pair might be
  # missing an id_or_name. Even positions hold kinds, odd positions hold the
  # id_or_name; a single pass avoids copying the path and repeated pop(0).
  for position, element in enumerate(client_key.flat_path):
    if position % 2:
      sort_key.append(IdOrName(element))
    else:
      sort_key.append(element)

  return sort_key
130182

131183

132184
def _get_scatter_keys(client, query, num_splits):

sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter_test.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,16 @@ def check_get_splits(self, query, num_splits, num_entities,
103103
unused_batch_size: ignored in v1new since query results are entirely
104104
handled by the Datastore client.
105105
"""
106-
# Test for both random long ids and string ids.
107-
for id_or_name in [True, False]:
108-
client_entities = helper.create_client_entities(num_entities, id_or_name)
106+
# Test for random long ids, string ids, and a mix of both.
107+
for id_or_name in [True, False, None]:
108+
if id_or_name is None:
109+
client_entities = helper.create_client_entities(num_entities, False)
110+
client_entities.extend(helper.create_client_entities(num_entities,
111+
True))
112+
num_entities *= 2
113+
else:
114+
client_entities = helper.create_client_entities(num_entities,
115+
id_or_name)
109116

110117
mock_client = mock.MagicMock()
111118
mock_client_query = mock.MagicMock()
@@ -154,6 +161,19 @@ def check_get_splits(self, query, num_splits, num_entities,
154161
if lt_key is None:
155162
last_query_seen = True
156163

164+
def test_id_or_name(self):
  """Checks IdOrName attribute assignment, equality and ordering."""
  make = query_splitter.IdOrName

  # An integer is stored as an ID.
  numeric = make(1)
  self.assertEqual(1, numeric.id)
  self.assertIsNone(numeric.name)

  # A string is stored as a name.
  named = make('1')
  self.assertIsNone(named.id)
  self.assertEqual('1', named.name)

  # Equality holds between instances built from the same value.
  for value in (1, '1'):
    self.assertEqual(make(value), make(value))

  # (smaller, larger) pairs: ID before name, then ID and name ordering.
  for smaller, larger in ((2, '1'), (1, 2), ('1', '2')):
    self.assertLess(make(smaller), make(larger))
176+
157177
def test_client_key_sort_key(self):
158178
k = key.Key('kind1', 1, project=self._PROJECT, namespace=self._NAMESPACE)
159179
k2 = key.Key('kind2', 'a', parent=k)
@@ -165,6 +185,31 @@ def test_client_key_sort_key(self):
165185
keys.sort(key=query_splitter.client_key_sort_key)
166186
self.assertEqual(expected_sort, keys)
167187

188+
def test_client_key_sort_key_ids(self):
  """Keys of the same kind with integer IDs sort by ascending ID."""
  larger = key.Key('kind', 2, project=self._PROJECT)
  smaller = key.Key('kind', 1, project=self._PROJECT)
  ordered = sorted([larger, smaller], key=query_splitter.client_key_sort_key)
  self.assertEqual([smaller, larger], ordered)
195+
196+
def test_client_key_sort_key_names(self):
  """Keys of the same kind with string names sort by ascending name."""
  later = key.Key('kind', '2', project=self._PROJECT)
  earlier = key.Key('kind', '1', project=self._PROJECT)
  ordered = sorted([later, earlier], key=query_splitter.client_key_sort_key)
  self.assertEqual([earlier, later], ordered)
203+
204+
def test_client_key_sort_key_ids_vs_names(self):
  """A key with an ID sorts ahead of a key with a name."""
  # Keys with IDs always come before keys with names.
  with_name = key.Key('kind', '1', project=self._PROJECT)
  with_id = key.Key('kind', 2, project=self._PROJECT)
  ordered = sorted([with_name, with_id],
                   key=query_splitter.client_key_sort_key)
  self.assertEqual([with_id, with_name], ordered)
212+
168213

169214
# Hide base class from collection by nose.
170215
del QuerySplitterTestBase

0 commit comments

Comments
 (0)