5252"""
5353
5454from __future__ import absolute_import
55+ from __future__ import division
5556
5657import logging
57-
58- from bson import objectid
59- from pymongo import MongoClient
60- from pymongo import ReplaceOne
58+ import struct
6159
6260import apache_beam as beam
6361from apache_beam .io import iobase
64- from apache_beam .io .range_trackers import OffsetRangeTracker
62+ from apache_beam .io .range_trackers import OrderedPositionRangeTracker
6563from apache_beam .transforms import DoFn
6664from apache_beam .transforms import PTransform
6765from apache_beam .transforms import Reshuffle
6866from apache_beam .utils .annotations import experimental
6967
68+ try :
69+ # Mongodb has its own bundled bson, which is not compatible with bson pakcage.
70+ # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
71+ # it fails because bson package is installed, MongoDB IO will not work but at
72+ # least rest of the SDK will work.
73+ from bson import objectid
74+
75+ # pymongo also internally depends on bson.
76+ from pymongo import ASCENDING
77+ from pymongo import DESCENDING
78+ from pymongo import MongoClient
79+ from pymongo import ReplaceOne
80+ except ImportError :
81+ objectid = None
82+ logging .warning ("Could not find a compatible bson package." )
83+
7084__all__ = ['ReadFromMongoDB' , 'WriteToMongoDB' ]
7185
7286
@@ -139,50 +153,49 @@ def __init__(self,
139153 self .filter = filter
140154 self .projection = projection
141155 self .spec = extra_client_params
142- self .doc_count = self ._get_document_count ()
143- self .avg_doc_size = self ._get_avg_document_size ()
144- self .client = None
145156
146157 def estimate_size (self ):
147- return self .avg_doc_size * self .doc_count
158+ with MongoClient (self .uri , ** self .spec ) as client :
159+ return client [self .db ].command ('collstats' , self .coll ).get ('size' )
148160
149161 def split (self , desired_bundle_size , start_position = None , stop_position = None ):
150- # use document cursor index as the start and stop positions
151- if start_position is None :
152- start_position = 0
153- if stop_position is None :
154- stop_position = self .doc_count
162+ start_position , stop_position = self ._replace_none_positions (
163+ start_position , stop_position )
155164
156- # get an estimate on how many documents should be included in a split batch
157- desired_bundle_count = desired_bundle_size // self .avg_doc_size
165+ desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
166+ split_keys = self ._get_split_keys (desired_bundle_size_in_mb , start_position ,
167+ stop_position )
158168
159169 bundle_start = start_position
160- while bundle_start < stop_position :
161- bundle_end = min (stop_position , bundle_start + desired_bundle_count )
162- yield iobase .SourceBundle (weight = bundle_end - bundle_start ,
170+ for split_key_id in split_keys :
171+ if bundle_start >= stop_position :
172+ break
173+ bundle_end = min (stop_position , split_key_id )
174+ yield iobase .SourceBundle (weight = desired_bundle_size_in_mb ,
163175 source = self ,
164176 start_position = bundle_start ,
165177 stop_position = bundle_end )
166178 bundle_start = bundle_end
179+ # add range of last split_key to stop_position
180+ if bundle_start < stop_position :
181+ yield iobase .SourceBundle (weight = desired_bundle_size_in_mb ,
182+ source = self ,
183+ start_position = bundle_start ,
184+ stop_position = stop_position )
167185
168186 def get_range_tracker (self , start_position , stop_position ):
169- if start_position is None :
170- start_position = 0
171- if stop_position is None :
172- stop_position = self .doc_count
173- return OffsetRangeTracker (start_position , stop_position )
187+ start_position , stop_position = self ._replace_none_positions (
188+ start_position , stop_position )
189+ return _ObjectIdRangeTracker (start_position , stop_position )
174190
175191 def read (self , range_tracker ):
176192 with MongoClient (self .uri , ** self .spec ) as client :
177- # docs is a MongoDB Cursor
178- docs = client [self .db ][self .coll ].find (
179- filter = self .filter , projection = self .projection
180- )[range_tracker .start_position ():range_tracker .stop_position ()]
181- for index in range (range_tracker .start_position (),
182- range_tracker .stop_position ()):
183- if not range_tracker .try_claim (index ):
193+ all_filters = self ._merge_id_filter (range_tracker )
194+ docs_cursor = client [self .db ][self .coll ].find (filter = all_filters )
195+ for doc in docs_cursor :
196+ if not range_tracker .try_claim (doc ['_id' ]):
184197 return
185- yield docs [ index - range_tracker . start_position ()]
198+ yield doc
186199
187200 def display_data (self ):
188201 res = super (_BoundedMongoSource , self ).display_data ()
@@ -194,18 +207,146 @@ def display_data(self):
194207 res ['mongo_client_spec' ] = self .spec
195208 return res
196209
197- def _get_avg_document_size (self ):
210+ def _get_split_keys (self , desired_chunk_size_in_mb , start_pos , end_pos ):
211+ # calls mongodb splitVector command to get document ids at split position
212+ # for desired bundle size, if desired chunk size smaller than 1mb, use
213+ # mongodb default split size of 1mb.
214+ if desired_chunk_size_in_mb < 1 :
215+ desired_chunk_size_in_mb = 1
216+ if start_pos >= end_pos :
217+ # single document not splittable
218+ return []
198219 with MongoClient (self .uri , ** self .spec ) as client :
199- size = client [self .db ].command ('collstats' , self .coll ).get ('avgObjSize' )
200- if size is None or size <= 0 :
201- raise ValueError (
202- 'Collection %s not found or average doc size is '
203- 'incorrect' , self .coll )
204- return size
205-
206- def _get_document_count (self ):
220+ name_space = '%s.%s' % (self .db , self .coll )
221+ return (client [self .db ].command (
222+ 'splitVector' ,
223+ name_space ,
224+ keyPattern = {'_id' : 1 }, # Ascending index
225+ min = {'_id' : start_pos },
226+ max = {'_id' : end_pos },
227+ maxChunkSize = desired_chunk_size_in_mb )['splitKeys' ])
228+
229+ def _merge_id_filter (self , range_tracker ):
230+ # Merge the default filter with refined _id field range of range_tracker.
231+ # see more at https://docs.mongodb.com/manual/reference/operator/query/and/
232+ all_filters = {
233+ '$and' : [
234+ self .filter .copy (),
235+ # add additional range filter to query. $gte specifies start
236+ # position(inclusive) and $lt specifies the end position(exclusive),
237+ # see more at
238+ # https://docs.mongodb.com/manual/reference/operator/query/gte/ and
239+ # https://docs.mongodb.com/manual/reference/operator/query/lt/
240+ {
241+ '_id' : {
242+ '$gte' : range_tracker .start_position (),
243+ '$lt' : range_tracker .stop_position ()
244+ }
245+ },
246+ ]
247+ }
248+
249+ return all_filters
250+
251+ def _get_head_document_id (self , sort_order ):
207252 with MongoClient (self .uri , ** self .spec ) as client :
208- return max (client [self .db ][self .coll ].count_documents (self .filter ), 0 )
253+ cursor = client [self .db ][self .coll ].find (filter = {}, projection = []).sort ([
254+ ('_id' , sort_order )
255+ ]).limit (1 )
256+ try :
257+ return cursor [0 ]['_id' ]
258+ except IndexError :
259+ raise ValueError ('Empty Mongodb collection' )
260+
261+ def _replace_none_positions (self , start_position , stop_position ):
262+ if start_position is None :
263+ start_position = self ._get_head_document_id (ASCENDING )
264+ if stop_position is None :
265+ last_doc_id = self ._get_head_document_id (DESCENDING )
266+ # increment last doc id binary value by 1 to make sure the last document
267+ # is not excluded
268+ stop_position = _ObjectIdHelper .increment_id (last_doc_id , 1 )
269+ return start_position , stop_position
270+
271+
272+ class _ObjectIdHelper (object ):
273+ """A Utility class to manipulate bson object ids."""
274+
275+ @classmethod
276+ def id_to_int (cls , id ):
277+ """
278+ Args:
279+ id: ObjectId required for each MongoDB document _id field.
280+
281+ Returns: Converted integer value of ObjectId's 12 bytes binary value.
282+
283+ """
284+ # converts object id binary to integer
285+ # id object is bytes type with size of 12
286+ ints = struct .unpack ('>III' , id .binary )
287+ return (ints [0 ] << 64 ) + (ints [1 ] << 32 ) + ints [2 ]
288+
289+ @classmethod
290+ def int_to_id (cls , number ):
291+ """
292+ Args:
293+ number(int): The integer value to be used to convert to ObjectId.
294+
295+ Returns: The ObjectId that has the 12 bytes binary converted from the
296+ integer value.
297+
298+ """
299+ # converts integer value to object id. Int value should be less than
300+ # (2 ^ 96) so it can be convert to 12 bytes required by object id.
301+ if number < 0 or number >= (1 << 96 ):
302+ raise ValueError ('number value must be within [0, %s)' % (1 << 96 ))
303+ ints = [(number & 0xffffffff0000000000000000 ) >> 64 ,
304+ (number & 0x00000000ffffffff00000000 ) >> 32 ,
305+ number & 0x0000000000000000ffffffff ]
306+
307+ bytes = struct .pack ('>III' , * ints )
308+ return objectid .ObjectId (bytes )
309+
310+ @classmethod
311+ def increment_id (cls , object_id , inc ):
312+ """
313+ Args:
314+ object_id: The ObjectId to change.
315+ inc(int): The incremental int value to be added to ObjectId.
316+
317+ Returns:
318+
319+ """
320+ # increment object_id binary value by inc value and return new object id.
321+ id_number = _ObjectIdHelper .id_to_int (object_id )
322+ new_number = id_number + inc
323+ if new_number < 0 or new_number >= (1 << 96 ):
324+ raise ValueError ('invalid incremental, inc value must be within ['
325+ '%s, %s)' % (0 - id_number , 1 << 96 - id_number ))
326+ return _ObjectIdHelper .int_to_id (new_number )
327+
328+
329+ class _ObjectIdRangeTracker (OrderedPositionRangeTracker ):
330+ """RangeTracker for tracking mongodb _id of bson ObjectId type."""
331+
332+ def position_to_fraction (self , pos , start , end ):
333+ pos_number = _ObjectIdHelper .id_to_int (pos )
334+ start_number = _ObjectIdHelper .id_to_int (start )
335+ end_number = _ObjectIdHelper .id_to_int (end )
336+ return (pos_number - start_number ) / (end_number - start_number )
337+
338+ def fraction_to_position (self , fraction , start , end ):
339+ start_number = _ObjectIdHelper .id_to_int (start )
340+ end_number = _ObjectIdHelper .id_to_int (end )
341+ total = end_number - start_number
342+ pos = int (total * fraction + start_number )
343+ # make sure split position is larger than start position and smaller than
344+ # end position.
345+ if pos <= start_number :
346+ return _ObjectIdHelper .increment_id (start , 1 )
347+ if pos >= end_number :
348+ return _ObjectIdHelper .increment_id (end , - 1 )
349+ return _ObjectIdHelper .int_to_id (pos )
209350
210351
211352@experimental ()
0 commit comments