Skip to content

Commit 7931ec0

Browse files
authored
Merge pull request #9342 from [BEAM-7866][BEAM-5148] Cherry-picks mongodb fixes to 2.15.0 release branch
2 parents 45de258 + cc9e966 commit 7931ec0

3 files changed

Lines changed: 436 additions & 112 deletions

File tree

sdks/python/apache_beam/io/mongodbio.py

Lines changed: 183 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,35 @@
5252
"""
5353

5454
from __future__ import absolute_import
55+
from __future__ import division
5556

5657
import logging
57-
58-
from bson import objectid
59-
from pymongo import MongoClient
60-
from pymongo import ReplaceOne
58+
import struct
6159

6260
import apache_beam as beam
6361
from apache_beam.io import iobase
64-
from apache_beam.io.range_trackers import OffsetRangeTracker
62+
from apache_beam.io.range_trackers import OrderedPositionRangeTracker
6563
from apache_beam.transforms import DoFn
6664
from apache_beam.transforms import PTransform
6765
from apache_beam.transforms import Reshuffle
6866
from apache_beam.utils.annotations import experimental
6967

68+
try:
69+
# Mongodb has its own bundled bson, which is not compatible with bson package.
70+
# (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
71+
# it fails because bson package is installed, MongoDB IO will not work but at
72+
# least rest of the SDK will work.
73+
from bson import objectid
74+
75+
# pymongo also internally depends on bson.
76+
from pymongo import ASCENDING
77+
from pymongo import DESCENDING
78+
from pymongo import MongoClient
79+
from pymongo import ReplaceOne
80+
except ImportError:
81+
objectid = None
82+
logging.warning("Could not find a compatible bson package.")
83+
7084
__all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
7185

7286

@@ -139,50 +153,49 @@ def __init__(self,
139153
self.filter = filter
140154
self.projection = projection
141155
self.spec = extra_client_params
142-
self.doc_count = self._get_document_count()
143-
self.avg_doc_size = self._get_avg_document_size()
144-
self.client = None
145156

146157
def estimate_size(self):
147-
return self.avg_doc_size * self.doc_count
158+
with MongoClient(self.uri, **self.spec) as client:
159+
return client[self.db].command('collstats', self.coll).get('size')
148160

149161
def split(self, desired_bundle_size, start_position=None, stop_position=None):
150-
# use document cursor index as the start and stop positions
151-
if start_position is None:
152-
start_position = 0
153-
if stop_position is None:
154-
stop_position = self.doc_count
162+
start_position, stop_position = self._replace_none_positions(
163+
start_position, stop_position)
155164

156-
# get an estimate on how many documents should be included in a split batch
157-
desired_bundle_count = desired_bundle_size // self.avg_doc_size
165+
desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
166+
split_keys = self._get_split_keys(desired_bundle_size_in_mb, start_position,
167+
stop_position)
158168

159169
bundle_start = start_position
160-
while bundle_start < stop_position:
161-
bundle_end = min(stop_position, bundle_start + desired_bundle_count)
162-
yield iobase.SourceBundle(weight=bundle_end - bundle_start,
170+
for split_key_id in split_keys:
171+
if bundle_start >= stop_position:
172+
break
173+
bundle_end = min(stop_position, split_key_id)
174+
yield iobase.SourceBundle(weight=desired_bundle_size_in_mb,
163175
source=self,
164176
start_position=bundle_start,
165177
stop_position=bundle_end)
166178
bundle_start = bundle_end
179+
# add range of last split_key to stop_position
180+
if bundle_start < stop_position:
181+
yield iobase.SourceBundle(weight=desired_bundle_size_in_mb,
182+
source=self,
183+
start_position=bundle_start,
184+
stop_position=stop_position)
167185

168186
def get_range_tracker(self, start_position, stop_position):
169-
if start_position is None:
170-
start_position = 0
171-
if stop_position is None:
172-
stop_position = self.doc_count
173-
return OffsetRangeTracker(start_position, stop_position)
187+
start_position, stop_position = self._replace_none_positions(
188+
start_position, stop_position)
189+
return _ObjectIdRangeTracker(start_position, stop_position)
174190

175191
def read(self, range_tracker):
176192
with MongoClient(self.uri, **self.spec) as client:
177-
# docs is a MongoDB Cursor
178-
docs = client[self.db][self.coll].find(
179-
filter=self.filter, projection=self.projection
180-
)[range_tracker.start_position():range_tracker.stop_position()]
181-
for index in range(range_tracker.start_position(),
182-
range_tracker.stop_position()):
183-
if not range_tracker.try_claim(index):
193+
all_filters = self._merge_id_filter(range_tracker)
194+
docs_cursor = client[self.db][self.coll].find(filter=all_filters)
195+
for doc in docs_cursor:
196+
if not range_tracker.try_claim(doc['_id']):
184197
return
185-
yield docs[index - range_tracker.start_position()]
198+
yield doc
186199

187200
def display_data(self):
188201
res = super(_BoundedMongoSource, self).display_data()
@@ -194,18 +207,146 @@ def display_data(self):
194207
res['mongo_client_spec'] = self.spec
195208
return res
196209

197-
def _get_avg_document_size(self):
210+
def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
211+
# calls mongodb splitVector command to get document ids at split position
212+
# for desired bundle size; if the desired chunk size is smaller than 1mb, use
213+
# mongodb default split size of 1mb.
214+
if desired_chunk_size_in_mb < 1:
215+
desired_chunk_size_in_mb = 1
216+
if start_pos >= end_pos:
217+
# single document not splittable
218+
return []
198219
with MongoClient(self.uri, **self.spec) as client:
199-
size = client[self.db].command('collstats', self.coll).get('avgObjSize')
200-
if size is None or size <= 0:
201-
raise ValueError(
202-
'Collection %s not found or average doc size is '
203-
'incorrect', self.coll)
204-
return size
205-
206-
def _get_document_count(self):
220+
name_space = '%s.%s' % (self.db, self.coll)
221+
return (client[self.db].command(
222+
'splitVector',
223+
name_space,
224+
keyPattern={'_id': 1}, # Ascending index
225+
min={'_id': start_pos},
226+
max={'_id': end_pos},
227+
maxChunkSize=desired_chunk_size_in_mb)['splitKeys'])
228+
229+
def _merge_id_filter(self, range_tracker):
230+
# Merge the default filter with refined _id field range of range_tracker.
231+
# see more at https://docs.mongodb.com/manual/reference/operator/query/and/
232+
all_filters = {
233+
'$and': [
234+
self.filter.copy(),
235+
# add additional range filter to query. $gte specifies start
236+
# position(inclusive) and $lt specifies the end position(exclusive),
237+
# see more at
238+
# https://docs.mongodb.com/manual/reference/operator/query/gte/ and
239+
# https://docs.mongodb.com/manual/reference/operator/query/lt/
240+
{
241+
'_id': {
242+
'$gte': range_tracker.start_position(),
243+
'$lt': range_tracker.stop_position()
244+
}
245+
},
246+
]
247+
}
248+
249+
return all_filters
250+
251+
def _get_head_document_id(self, sort_order):
207252
with MongoClient(self.uri, **self.spec) as client:
208-
return max(client[self.db][self.coll].count_documents(self.filter), 0)
253+
cursor = client[self.db][self.coll].find(filter={}, projection=[]).sort([
254+
('_id', sort_order)
255+
]).limit(1)
256+
try:
257+
return cursor[0]['_id']
258+
except IndexError:
259+
raise ValueError('Empty Mongodb collection')
260+
261+
def _replace_none_positions(self, start_position, stop_position):
262+
if start_position is None:
263+
start_position = self._get_head_document_id(ASCENDING)
264+
if stop_position is None:
265+
last_doc_id = self._get_head_document_id(DESCENDING)
266+
# increment last doc id binary value by 1 to make sure the last document
267+
# is not excluded
268+
stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
269+
return start_position, stop_position
270+
271+
272+
class _ObjectIdHelper(object):
273+
"""A Utility class to manipulate bson object ids."""
274+
275+
@classmethod
276+
def id_to_int(cls, id):
277+
"""
278+
Args:
279+
id: ObjectId required for each MongoDB document _id field.
280+
281+
Returns: Converted integer value of ObjectId's 12 bytes binary value.
282+
283+
"""
284+
# converts object id binary to integer
285+
# id object is bytes type with size of 12
286+
ints = struct.unpack('>III', id.binary)
287+
return (ints[0] << 64) + (ints[1] << 32) + ints[2]
288+
289+
@classmethod
290+
def int_to_id(cls, number):
291+
"""
292+
Args:
293+
number(int): The integer value to be used to convert to ObjectId.
294+
295+
Returns: The ObjectId that has the 12 bytes binary converted from the
296+
integer value.
297+
298+
"""
299+
# converts integer value to object id. Int value should be less than
300+
# (2 ^ 96) so it can be converted to the 12 bytes required by object id.
301+
if number < 0 or number >= (1 << 96):
302+
raise ValueError('number value must be within [0, %s)' % (1 << 96))
303+
ints = [(number & 0xffffffff0000000000000000) >> 64,
304+
(number & 0x00000000ffffffff00000000) >> 32,
305+
number & 0x0000000000000000ffffffff]
306+
307+
bytes = struct.pack('>III', *ints)
308+
return objectid.ObjectId(bytes)
309+
310+
@classmethod
311+
def increment_id(cls, object_id, inc):
312+
"""
313+
Args:
314+
object_id: The ObjectId to change.
315+
inc(int): The incremental int value to be added to ObjectId.
316+
317+
Returns:
318+
319+
"""
320+
# increment object_id binary value by inc value and return new object id.
321+
id_number = _ObjectIdHelper.id_to_int(object_id)
322+
new_number = id_number + inc
323+
if new_number < 0 or new_number >= (1 << 96):
324+
raise ValueError('invalid incremental, inc value must be within ['
325+
'%s, %s)' % (0 - id_number, 1 << 96 - id_number))
326+
return _ObjectIdHelper.int_to_id(new_number)
327+
328+
329+
class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
330+
"""RangeTracker for tracking mongodb _id of bson ObjectId type."""
331+
332+
def position_to_fraction(self, pos, start, end):
333+
pos_number = _ObjectIdHelper.id_to_int(pos)
334+
start_number = _ObjectIdHelper.id_to_int(start)
335+
end_number = _ObjectIdHelper.id_to_int(end)
336+
return (pos_number - start_number) / (end_number - start_number)
337+
338+
def fraction_to_position(self, fraction, start, end):
339+
start_number = _ObjectIdHelper.id_to_int(start)
340+
end_number = _ObjectIdHelper.id_to_int(end)
341+
total = end_number - start_number
342+
pos = int(total * fraction + start_number)
343+
# make sure split position is larger than start position and smaller than
344+
# end position.
345+
if pos <= start_number:
346+
return _ObjectIdHelper.increment_id(start, 1)
347+
if pos >= end_number:
348+
return _ObjectIdHelper.increment_id(end, -1)
349+
return _ObjectIdHelper.int_to_id(pos)
209350

210351

211352
@experimental()

sdks/python/apache_beam/io/mongodbio_it_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,32 +42,33 @@ def run(argv=None):
4242
default=default_coll,
4343
help='mongo uri string for connection')
4444
parser.add_argument('--num_documents',
45-
default=1000,
45+
default=100000,
4646
help='The expected number of documents to be generated '
4747
'for write or read',
4848
type=int)
4949
parser.add_argument('--batch_size',
50-
default=100,
50+
default=10000,
5151
help=('batch size for writing to mongodb'))
5252
known_args, pipeline_args = parser.parse_known_args(argv)
5353

5454
# Test Write to MongoDB
5555
with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
56+
start_time = time.time()
5657
logging.info('Writing %d documents to mongodb' % known_args.num_documents)
5758
docs = [{
5859
'number': x,
5960
'number_mod_2': x % 2,
6061
'number_mod_3': x % 3
6162
} for x in range(known_args.num_documents)]
6263

63-
start_time = time.time()
6464
_ = p | 'Create documents' >> beam.Create(docs) \
6565
| 'WriteToMongoDB' >> beam.io.WriteToMongoDB(known_args.mongo_uri,
6666
known_args.mongo_db,
6767
known_args.mongo_coll,
6868
known_args.batch_size)
69-
logging.info('Writing %d documents to mongodb finished in %.3f seconds' %
70-
(known_args.num_documents, time.time() - start_time))
69+
elapsed = time.time() - start_time
70+
logging.info('Writing %d documents to mongodb finished in %.3f seconds' %
71+
(known_args.num_documents, elapsed))
7172

7273
# Test Read from MongoDB
7374
with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
@@ -80,11 +81,12 @@ def run(argv=None):
8081
known_args.mongo_coll,
8182
projection=['number']) \
8283
| 'Map' >> beam.Map(lambda doc: doc['number'])
83-
8484
assert_that(
8585
r, equal_to([number for number in range(known_args.num_documents)]))
86-
logging.info('Read %d documents from mongodb finished in %.3f seconds' %
87-
(known_args.num_documents, time.time() - start_time))
86+
87+
elapsed = time.time() - start_time
88+
logging.info('Read %d documents from mongodb finished in %.3f seconds' %
89+
(known_args.num_documents, elapsed))
8890

8991

9092
if __name__ == "__main__":

0 commit comments

Comments
 (0)