Skip to content

Commit 9969b94

Browse files
author
Kyle Weaver
authored
Merge pull request #11671 from ibzib/BEAM-9935
[BEAM-9935] [release-2.21.0] Respect allowed split points in Python.
2 parents 34dabe3 + 82fa39e commit 9969b94

3 files changed

Lines changed: 216 additions & 30 deletions

File tree

model/fn-execution/src/main/proto/beam_fn_api.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,9 @@ message ProcessBundleSplitRequest {
375375

376376
// A set of allowed element indices where the SDK may split. When this is
377377
// empty, there are no constraints on where to split.
378+
// Specifically, the first_residual_element of a split result must be an
379+
// allowed split point, and the last_primary_element must immediately
380+
// preceded an allowed split point.
378381
repeated int64 allowed_split_points = 3;
379382

380383
// (Required for GrpcRead operations) Number of total elements expected

sdks/python/apache_beam/runners/worker/bundle_processor.py

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from __future__ import print_function
2525

2626
import base64
27+
import bisect
2728
import collections
2829
import json
2930
import logging
@@ -216,15 +217,12 @@ def process_encoded(self, encoded_windowed_values):
216217
input_stream, True)
217218
self.output(decoded_value)
218219

219-
def try_split(self, fraction_of_remainder, total_buffer_size):
220+
def try_split(
221+
self, fraction_of_remainder, total_buffer_size, allowed_split_points):
220222
# type: (...) -> Optional[Tuple[int, Optional[operations.SdfSplitResultsPrimary], Optional[operations.SdfSplitResultsResidual], int]]
221223
with self.splitting_lock:
222224
if not self.started:
223225
return None
224-
if total_buffer_size < self.index + 1:
225-
total_buffer_size = self.index + 1
226-
elif self.stop and total_buffer_size > self.stop:
227-
total_buffer_size = self.stop
228226
if self.index == -1:
229227
# We are "finished" with the (non-existent) previous element.
230228
current_element_progress = 1.0
@@ -237,30 +235,72 @@ def try_split(self, fraction_of_remainder, total_buffer_size):
237235
current_element_progress = (
238236
current_element_progress_object.fraction_completed)
239237
# Now figure out where to split.
240-
# The units here (except for keep_of_element_remainder) are all in
241-
# terms of number of (possibly fractional) elements.
242-
remainder = total_buffer_size - self.index - current_element_progress
243-
keep = remainder * fraction_of_remainder
244-
if current_element_progress < 1:
245-
keep_of_element_remainder = keep / (1 - current_element_progress)
246-
# If it's less than what's left of the current element,
247-
# try splitting at the current element.
248-
if keep_of_element_remainder < 1:
249-
split = self.receivers[0].try_split(
250-
keep_of_element_remainder
251-
) # type: Optional[Tuple[operations.SdfSplitResultsPrimary, operations.SdfSplitResultsResidual]]
252-
if split:
253-
element_primary, element_residual = split
254-
self.stop = self.index + 1
255-
return self.index - 1, element_primary, element_residual, self.stop
256-
# Otherwise, split at the closest element boundary.
257-
# pylint: disable=round-builtin
258-
stop_index = (
259-
self.index + max(1, int(round(current_element_progress + keep))))
260-
if stop_index < self.stop:
261-
self.stop = stop_index
262-
return self.stop - 1, None, None, self.stop
263-
return None
238+
split = self._compute_split(
239+
self.index,
240+
current_element_progress,
241+
self.stop,
242+
fraction_of_remainder,
243+
total_buffer_size,
244+
allowed_split_points,
245+
self.receivers[0].try_split)
246+
if split:
247+
self.stop = split[-1]
248+
return split
249+
250+
@staticmethod
251+
def _compute_split(
252+
index,
253+
current_element_progress,
254+
stop,
255+
fraction_of_remainder,
256+
total_buffer_size,
257+
allowed_split_points=(),
258+
try_split=lambda fraction: None):
259+
def is_valid_split_point(index):
260+
return not allowed_split_points or index in allowed_split_points
261+
262+
if total_buffer_size < index + 1:
263+
total_buffer_size = index + 1
264+
elif total_buffer_size > stop:
265+
total_buffer_size = stop
266+
# The units here (except for keep_of_element_remainder) are all in
267+
# terms of number of (possibly fractional) elements.
268+
remainder = total_buffer_size - index - current_element_progress
269+
keep = remainder * fraction_of_remainder
270+
if current_element_progress < 1:
271+
keep_of_element_remainder = keep / (1 - current_element_progress)
272+
# If it's less than what's left of the current element,
273+
# try splitting at the current element.
274+
if (keep_of_element_remainder < 1 and is_valid_split_point(index) and
275+
is_valid_split_point(index + 1)):
276+
split = try_split(
277+
keep_of_element_remainder
278+
) # type: Optional[Tuple[operations.SdfSplitResultsPrimary, operations.SdfSplitResultsResidual]]
279+
if split:
280+
element_primary, element_residual = split
281+
return index - 1, element_primary, element_residual, index + 1
282+
# Otherwise, split at the closest element boundary.
283+
# pylint: disable=round-builtin
284+
stop_index = index + max(1, int(round(current_element_progress + keep)))
285+
if allowed_split_points and stop_index not in allowed_split_points:
286+
# Choose the closest allowed split point.
287+
allowed_split_points = sorted(allowed_split_points)
288+
closest = bisect.bisect(allowed_split_points, stop_index)
289+
if closest == 0:
290+
stop_index = allowed_split_points[0]
291+
elif closest == len(allowed_split_points):
292+
stop_index = allowed_split_points[-1]
293+
else:
294+
prev = allowed_split_points[closest - 1]
295+
next = allowed_split_points[closest]
296+
if index < prev and stop_index - prev < next - stop_index:
297+
stop_index = prev
298+
else:
299+
stop_index = next
300+
if index < stop_index < stop:
301+
return stop_index - 1, None, None, stop_index
302+
else:
303+
return None
264304

265305
def finish(self):
266306
# type: () -> None
@@ -955,7 +995,8 @@ def try_split(self, bundle_split_request):
955995
if desired_split:
956996
split = op.try_split(
957997
desired_split.fraction_of_remainder,
958-
desired_split.estimated_input_elements)
998+
desired_split.estimated_input_elements,
999+
desired_split.allowed_split_points)
9591000
if split:
9601001
(
9611002
primary_end,
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""Unit tests for bundle processing."""
19+
# pytype: skip-file
20+
21+
from __future__ import absolute_import
22+
23+
import unittest
24+
25+
from apache_beam.runners.worker.bundle_processor import DataInputOperation
26+
27+
28+
def simple_split(first_residual_index):
29+
return first_residual_index - 1, None, None, first_residual_index
30+
31+
32+
def element_split(frac, index):
33+
return (
34+
index - 1,
35+
'Primary(%0.1f)' % frac,
36+
'Residual(%0.1f)' % (1 - frac),
37+
index + 1)
38+
39+
40+
class SplitTest(unittest.TestCase):
41+
def split(
42+
self,
43+
index,
44+
current_element_progress,
45+
fraction_of_remainder,
46+
buffer_size,
47+
allowed=(),
48+
sdf=False):
49+
return DataInputOperation._compute_split(
50+
index,
51+
current_element_progress,
52+
float('inf'),
53+
fraction_of_remainder,
54+
buffer_size,
55+
allowed_split_points=allowed,
56+
try_split=lambda frac: element_split(frac, 0)[1:3] if sdf else None)
57+
58+
def sdf_split(self, *args, **kwargs):
59+
return self.split(*args, sdf=True, **kwargs)
60+
61+
def test_simple_split(self):
62+
# Split as close to the beginning as possible.
63+
self.assertEqual(self.split(0, 0, 0, 16), simple_split(1))
64+
# The closest split is at 4, even when just above or below it.
65+
self.assertEqual(self.split(0, 0, 0.24, 16), simple_split(4))
66+
self.assertEqual(self.split(0, 0, 0.25, 16), simple_split(4))
67+
self.assertEqual(self.split(0, 0, 0.26, 16), simple_split(4))
68+
# Split the *remainder* in half.
69+
self.assertEqual(self.split(0, 0, 0.5, 16), simple_split(8))
70+
self.assertEqual(self.split(2, 0, 0.5, 16), simple_split(9))
71+
self.assertEqual(self.split(6, 0, 0.5, 16), simple_split(11))
72+
73+
def test_split_with_element_progress(self):
74+
# Progress into the active element influences where the split of the
75+
# remainder falls.
76+
self.assertEqual(self.split(0, 0.5, 0.25, 4), simple_split(1))
77+
self.assertEqual(self.split(0, 0.9, 0.25, 4), simple_split(2))
78+
self.assertEqual(self.split(1, 0.0, 0.25, 4), simple_split(2))
79+
self.assertEqual(self.split(1, 0.1, 0.25, 4), simple_split(2))
80+
81+
def test_split_with_element_allowed_splits(self):
82+
# The desired split point is at 4.
83+
self.assertEqual(
84+
self.split(0, 0, 0.25, 16, allowed=(2, 3, 4, 5)), simple_split(4))
85+
# If we can't split at 4, choose the closest possible split point.
86+
self.assertEqual(
87+
self.split(0, 0, 0.25, 16, allowed=(2, 3, 5)), simple_split(5))
88+
self.assertEqual(
89+
self.split(0, 0, 0.25, 16, allowed=(2, 3, 6)), simple_split(3))
90+
91+
# Also test the case where all possible split points lie above or below
92+
# the desired split point.
93+
self.assertEqual(
94+
self.split(0, 0, 0.25, 16, allowed=(5, 6, 7)), simple_split(5))
95+
self.assertEqual(
96+
self.split(0, 0, 0.25, 16, allowed=(1, 2, 3)), simple_split(3))
97+
98+
# We have progressed beyond all possible split points, so can't split.
99+
self.assertEqual(self.split(5, 0, 0.25, 16, allowed=(1, 2, 3)), None)
100+
101+
def test_sdf_split(self):
102+
# Split between future elements at element boundaries.
103+
self.assertEqual(self.sdf_split(0, 0, 0.51, 4), simple_split(2))
104+
self.assertEqual(self.sdf_split(0, 0, 0.49, 4), simple_split(2))
105+
self.assertEqual(self.sdf_split(0, 0, 0.26, 4), simple_split(1))
106+
self.assertEqual(self.sdf_split(0, 0, 0.25, 4), simple_split(1))
107+
108+
# If the split falls inside the first, splittable element, split there.
109+
self.assertEqual(
110+
self.sdf_split(0, 0, 0.20, 4), (-1, 'Primary(0.8)', 'Residual(0.2)', 1))
111+
# The choice of split depends on the progress into the first element.
112+
self.assertEqual(
113+
self.sdf_split(0, 0, .125, 4), (-1, 'Primary(0.5)', 'Residual(0.5)', 1))
114+
# Here we are far enough into the first element that splitting at 0.2 of the
115+
# remainder falls outside the first element.
116+
self.assertEqual(self.sdf_split(0, .5, 0.2, 4), simple_split(1))
117+
118+
# Verify the above logic when we are partially throug the stream.
119+
self.assertEqual(self.sdf_split(2, 0, 0.6, 4), simple_split(3))
120+
self.assertEqual(self.sdf_split(2, 0.9, 0.6, 4), simple_split(4))
121+
self.assertEqual(
122+
self.sdf_split(2, 0.5, 0.2, 4), (1, 'Primary(0.6)', 'Residual(0.4)', 3))
123+
124+
def test_sdf_split_with_allowed_splits(self):
125+
# This is where we would like to split, when all split points are available.
126+
self.assertEqual(
127+
self.sdf_split(2, 0, 0.2, 5, allowed=(1, 2, 3, 4, 5)),
128+
(1, 'Primary(0.6)', 'Residual(0.4)', 3))
129+
# We can't split element at index 2, because 3 is not a split point.
130+
self.assertEqual(
131+
self.sdf_split(2, 0, 0.2, 5, allowed=(1, 2, 4, 5)), simple_split(4))
132+
# We can't even split element at index 4 as above, because 4 is also not a
133+
# split point.
134+
self.assertEqual(
135+
self.sdf_split(2, 0, 0.2, 5, allowed=(1, 2, 5)), simple_split(5))
136+
# We can't split element at index 2, because 2 is not a split point.
137+
self.assertEqual(
138+
self.sdf_split(2, 0, 0.2, 5, allowed=(1, 3, 4, 5)), simple_split(3))
139+
140+
141+
if __name__ == '__main__':
142+
unittest.main()

0 commit comments

Comments
 (0)