Skip to content

Commit 34dabe3

Browse files
authored
Merge pull request #11667: [BEAM-9947] Store updated ParDoPayload for length-prefixed timer coders
Backport of #11658.
2 parents 75efec2 + 625bc3f commit 34dabe3

4 files changed

Lines changed: 77 additions & 44 deletions

File tree

runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ public abstract static class WindowedValueCoderComponents {
115115
}
116116

117117
public static KvCoderComponents getKvCoderComponents(Coder coder) {
118-
checkArgument(KV_CODER_URN.equals(coder.getSpec().getUrn()));
118+
checkArgument(
119+
KV_CODER_URN.equals(coder.getSpec().getUrn()),
120+
"Provided coder %s is not of type %s",
121+
coder.getSpec().getUrn(),
122+
KV_CODER_URN);
119123
return new AutoValue_ModelCoders_KvCoderComponents(
120124
coder.getComponentCoderIds(0), coder.getComponentCoderIds(1));
121125
}

runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -396,18 +396,16 @@ private static Map<String, Map<String, TimerSpec>> forTimerSpecs(
396396
.getTransformsOrThrow(timerReference.transform().getId())
397397
.toBuilder()
398398
.getSpecBuilder();
399-
RunnerApi.ParDoPayload updatedPayload =
400-
RunnerApi.ParDoPayload.parseFrom(updatedSpec.getPayload());
401-
updatedPayload
402-
.toBuilder()
403-
.putTimerFamilySpecs(
404-
timerReference.localName(),
405-
updatedPayload
406-
.getTimerFamilySpecsOrThrow(timerReference.localName())
407-
.toBuilder()
408-
.setTimerFamilyCoderId(sdkCoderId)
409-
.build());
410-
updatedSpec.setPayload(updatedPayload.toByteString());
399+
RunnerApi.ParDoPayload.Builder updatedPayload =
400+
RunnerApi.ParDoPayload.parseFrom(updatedSpec.getPayload()).toBuilder();
401+
updatedPayload.putTimerFamilySpecs(
402+
timerReference.localName(),
403+
updatedPayload
404+
.getTimerFamilySpecsOrThrow(timerReference.localName())
405+
.toBuilder()
406+
.setTimerFamilyCoderId(sdkCoderId)
407+
.build());
408+
updatedSpec.setPayload(updatedPayload.build().toByteString());
411409
components.putTransforms(
412410
timerReference.transform().getId(),
413411
// Since a transform can have more then one timer, update the transform inside components

runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.beam.runners.core.construction.graph.FusedPipeline;
3535
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
3636
import org.apache.beam.runners.core.construction.graph.PipelineNode;
37+
import org.apache.beam.runners.core.construction.graph.TimerReference;
3738
import org.apache.beam.sdk.Pipeline;
3839
import org.apache.beam.sdk.coders.Coder;
3940
import org.apache.beam.sdk.coders.KvCoder;
@@ -42,6 +43,10 @@
4243
import org.apache.beam.sdk.state.BagState;
4344
import org.apache.beam.sdk.state.StateSpec;
4445
import org.apache.beam.sdk.state.StateSpecs;
46+
import org.apache.beam.sdk.state.TimeDomain;
47+
import org.apache.beam.sdk.state.Timer;
48+
import org.apache.beam.sdk.state.TimerSpec;
49+
import org.apache.beam.sdk.state.TimerSpecs;
4550
import org.apache.beam.sdk.transforms.DoFn;
4651
import org.apache.beam.sdk.transforms.GroupByKey;
4752
import org.apache.beam.sdk.transforms.Impulse;
@@ -59,7 +64,7 @@ public class ProcessBundleDescriptorsTest implements Serializable {
5964
* LengthPrefixCoder.
6065
*/
6166
@Test
62-
public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder() throws Exception {
67+
public void testLengthPrefixingOfKeyCoderInStatefulExecutableStage() throws Exception {
6368
// Add another stateful stage with a non-standard key coder
6469
Pipeline p = Pipeline.create();
6570
Coder<Void> keycoder = VoidCoder.of();
@@ -82,16 +87,18 @@ public void process(ProcessContext ctxt) {}
8287
private final StateSpec<BagState<String>> bufferState =
8388
StateSpecs.bag(StringUtf8Coder.of());
8489

90+
@TimerId("timerId")
91+
private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
92+
8593
@ProcessElement
8694
public void processElement(
8795
@Element KV<Void, String> element,
8896
@StateId("stateId") BagState<String> state,
89-
OutputReceiver<KV<Void, String>> r) {
90-
for (String value : state.read()) {
91-
r.output(KV.of(element.getKey(), value));
92-
}
93-
state.add(element.getValue());
94-
}
97+
@TimerId("timerId") Timer timer,
98+
OutputReceiver<KV<Void, String>> r) {}
99+
100+
@OnTimer("timerId")
101+
public void onTimer() {}
95102
}))
96103
// Force the output to be materialized
97104
.apply("gbk", GroupByKey.create());
@@ -111,34 +118,46 @@ public void processElement(
111118

112119
// Ensure original key coder is not a LengthPrefixCoder
113120
Map<String, RunnerApi.Coder> stageCoderMap = stage.getComponents().getCodersMap();
114-
RunnerApi.Coder originalCoder =
121+
RunnerApi.Coder originalMainInputCoder =
115122
stageCoderMap.get(inputPCollection.getPCollection().getCoderId());
116-
String originalKeyCoderId = ModelCoders.getKvCoderComponents(originalCoder).keyCoderId();
117-
assertThat(
118-
stageCoderMap.get(originalKeyCoderId).getSpec().getUrn(),
119-
is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
123+
String originalKeyCoderId =
124+
ModelCoders.getKvCoderComponents(originalMainInputCoder).keyCoderId();
125+
RunnerApi.Coder originalKeyCoder = stageCoderMap.get(originalKeyCoderId);
126+
assertThat(originalKeyCoder.getSpec().getUrn(), is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
120127

121128
// Now create ProcessBundleDescriptor and check for the LengthPrefixCoder around the key coder
122-
BeamFnApi.ProcessBundleDescriptor pbDescriptor =
129+
BeamFnApi.ProcessBundleDescriptor pbd =
123130
ProcessBundleDescriptors.fromExecutableStage(
124131
"test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
125132
.getProcessBundleDescriptor();
133+
Map<String, RunnerApi.Coder> pbsCoderMap = pbd.getCodersMap();
126134

127-
String inputPCollectionId = inputPCollection.getId();
128-
String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputPCollectionId).getCoderId();
135+
RunnerApi.Coder pbsMainInputCoder =
136+
pbsCoderMap.get(pbd.getPcollectionsOrThrow(inputPCollection.getId()).getCoderId());
137+
String keyCoderId = ModelCoders.getKvCoderComponents(pbsMainInputCoder).keyCoderId();
138+
RunnerApi.Coder keyCoder = pbsCoderMap.get(keyCoderId);
139+
ensureLengthPrefixed(keyCoder, originalKeyCoder, pbsCoderMap);
129140

130-
Map<String, RunnerApi.Coder> pbCoderMap = pbDescriptor.getCodersMap();
131-
RunnerApi.Coder coder = pbCoderMap.get(inputCoderId);
132-
String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();
133-
134-
RunnerApi.Coder keyCoder = pbCoderMap.get(keyCoderId);
135-
// Ensure length prefix
136-
assertThat(keyCoder.getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
137-
String lengthPrefixWrappedCoderId = keyCoder.getComponentCoderIds(0);
141+
TimerReference timerRef = Iterables.getOnlyElement(stage.getTimers());
142+
String timerTransformId = timerRef.transform().getId();
143+
RunnerApi.ParDoPayload parDoPayload =
144+
RunnerApi.ParDoPayload.parseFrom(
145+
pbd.getTransformsOrThrow(timerTransformId).getSpec().getPayload());
146+
RunnerApi.TimerFamilySpec timerSpec =
147+
parDoPayload.getTimerFamilySpecsOrThrow(timerRef.localName());
148+
RunnerApi.Coder timerCoder = pbsCoderMap.get(timerSpec.getTimerFamilyCoderId());
149+
String timerKeyCoderId = timerCoder.getComponentCoderIds(0);
150+
RunnerApi.Coder timerKeyCoder = pbsCoderMap.get(timerKeyCoderId);
151+
ensureLengthPrefixed(timerKeyCoder, originalKeyCoder, pbsCoderMap);
152+
}
138153

154+
private static void ensureLengthPrefixed(
155+
RunnerApi.Coder coder,
156+
RunnerApi.Coder originalCoder,
157+
Map<String, RunnerApi.Coder> pbsCoderMap) {
158+
assertThat(coder.getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));
139159
// Check that the wrapped coder is unchanged
140-
assertThat(lengthPrefixWrappedCoderId, is(originalKeyCoderId));
141-
assertThat(
142-
pbCoderMap.get(lengthPrefixWrappedCoderId), is(stageCoderMap.get(originalKeyCoderId)));
160+
String lengthPrefixedWrappedCoderId = coder.getComponentCoderIds(0);
161+
assertThat(pbsCoderMap.get(lengthPrefixedWrappedCoderId), is(originalCoder));
143162
}
144163
}

sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
import unittest
3434
import uuid
3535
from builtins import range
36+
from typing import Any
3637
from typing import Dict
38+
from typing import Tuple
3739

3840
# patches unittest.TestCase to be python3 compatible
3941
import hamcrest # pylint: disable=ungrouped-imports
@@ -397,15 +399,24 @@ def process_clear_timer(self):
397399
assert_that(actual, equal_to(expected))
398400

399401
def test_pardo_state_timers(self):
400-
self._run_pardo_state_timers(False)
402+
self._run_pardo_state_timers(windowed=False)
401403

402-
def test_windowed_pardo_state_timers(self):
403-
self._run_pardo_state_timers(True)
404+
def test_pardo_state_timers_non_standard_coder(self):
405+
self._run_pardo_state_timers(windowed=False, key_type=Any)
404406

405-
def _run_pardo_state_timers(self, windowed):
407+
def test_windowed_pardo_state_timers(self):
408+
self._run_pardo_state_timers(windowed=True)
409+
410+
def _run_pardo_state_timers(self, windowed, key_type=None):
411+
"""
412+
:param windowed: If True, uses an interval window, otherwise a global window
413+
:param key_type: Allows to override the inferred key type. This is useful to
414+
test the use of non-standard coders, e.g. Python's FastPrimitivesCoder.
415+
"""
406416
state_spec = userstate.BagStateSpec('state', beam.coders.StrUtf8Coder())
407417
timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
408418
elements = list('abcdefgh')
419+
key = 'key'
409420
buffer_size = 3
410421

411422
class BufferDoFn(beam.DoFn):
@@ -456,7 +467,8 @@ def is_buffered_correctly(actual):
456467
| beam.Map(lambda e: window.TimestampedValue(e, ord(e) % 2))
457468
| beam.WindowInto(
458469
window.FixedWindows(1) if windowed else window.GlobalWindows())
459-
| beam.Map(lambda x: ('key', x))
470+
| beam.Map(lambda x: (key, x)).with_output_types(
471+
Tuple[key_type if key_type else type(key), Any])
460472
| beam.ParDo(BufferDoFn()))
461473

462474
assert_that(actual, is_buffered_correctly)

0 commit comments

Comments
 (0)