3434import org .apache .beam .runners .core .construction .graph .FusedPipeline ;
3535import org .apache .beam .runners .core .construction .graph .GreedyPipelineFuser ;
3636import org .apache .beam .runners .core .construction .graph .PipelineNode ;
37+ import org .apache .beam .runners .core .construction .graph .TimerReference ;
3738import org .apache .beam .sdk .Pipeline ;
3839import org .apache .beam .sdk .coders .Coder ;
3940import org .apache .beam .sdk .coders .KvCoder ;
4243import org .apache .beam .sdk .state .BagState ;
4344import org .apache .beam .sdk .state .StateSpec ;
4445import org .apache .beam .sdk .state .StateSpecs ;
46+ import org .apache .beam .sdk .state .TimeDomain ;
47+ import org .apache .beam .sdk .state .Timer ;
48+ import org .apache .beam .sdk .state .TimerSpec ;
49+ import org .apache .beam .sdk .state .TimerSpecs ;
4550import org .apache .beam .sdk .transforms .DoFn ;
4651import org .apache .beam .sdk .transforms .GroupByKey ;
4752import org .apache .beam .sdk .transforms .Impulse ;
@@ -59,7 +64,7 @@ public class ProcessBundleDescriptorsTest implements Serializable {
5964 * LengthPrefixCoder.
6065 */
6166 @ Test
62- public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder () throws Exception {
67+ public void testLengthPrefixingOfKeyCoderInStatefulExecutableStage () throws Exception {
6368 // Add another stateful stage with a non-standard key coder
6469 Pipeline p = Pipeline .create ();
6570 Coder <Void > keycoder = VoidCoder .of ();
@@ -82,16 +87,18 @@ public void process(ProcessContext ctxt) {}
8287 private final StateSpec <BagState <String >> bufferState =
8388 StateSpecs .bag (StringUtf8Coder .of ());
8489
90+ @ TimerId ("timerId" )
91+ private final TimerSpec timerSpec = TimerSpecs .timer (TimeDomain .EVENT_TIME );
92+
8593 @ ProcessElement
8694 public void processElement (
8795 @ Element KV <Void , String > element ,
8896 @ StateId ("stateId" ) BagState <String > state ,
89- OutputReceiver <KV <Void , String >> r ) {
90- for (String value : state .read ()) {
91- r .output (KV .of (element .getKey (), value ));
92- }
93- state .add (element .getValue ());
94- }
97+ @ TimerId ("timerId" ) Timer timer ,
98+ OutputReceiver <KV <Void , String >> r ) {}
99+
100+ @ OnTimer ("timerId" )
101+ public void onTimer () {}
95102 }))
96103 // Force the output to be materialized
97104 .apply ("gbk" , GroupByKey .create ());
@@ -111,34 +118,46 @@ public void processElement(
111118
112119 // Ensure original key coder is not a LengthPrefixCoder
113120 Map <String , RunnerApi .Coder > stageCoderMap = stage .getComponents ().getCodersMap ();
114- RunnerApi .Coder originalCoder =
121+ RunnerApi .Coder originalMainInputCoder =
115122 stageCoderMap .get (inputPCollection .getPCollection ().getCoderId ());
116- String originalKeyCoderId = ModelCoders . getKvCoderComponents ( originalCoder ). keyCoderId ();
117- assertThat (
118- stageCoderMap .get (originalKeyCoderId ). getSpec (). getUrn (),
119- is (CoderTranslation .JAVA_SERIALIZED_CODER_URN ));
123+ String originalKeyCoderId =
124+ ModelCoders . getKvCoderComponents ( originalMainInputCoder ). keyCoderId ();
125+ RunnerApi . Coder originalKeyCoder = stageCoderMap .get (originalKeyCoderId );
126+ assertThat ( originalKeyCoder . getSpec (). getUrn (), is (CoderTranslation .JAVA_SERIALIZED_CODER_URN ));
120127
121128 // Now create ProcessBundleDescriptor and check for the LengthPrefixCoder around the key coder
122- BeamFnApi .ProcessBundleDescriptor pbDescriptor =
129+ BeamFnApi .ProcessBundleDescriptor pbd =
123130 ProcessBundleDescriptors .fromExecutableStage (
124131 "test_stage" , stage , Endpoints .ApiServiceDescriptor .getDefaultInstance ())
125132 .getProcessBundleDescriptor ();
133+ Map <String , RunnerApi .Coder > pbsCoderMap = pbd .getCodersMap ();
126134
127- String inputPCollectionId = inputPCollection .getId ();
128- String inputCoderId = pbDescriptor .getPcollectionsMap ().get (inputPCollectionId ).getCoderId ();
135+ RunnerApi .Coder pbsMainInputCoder =
136+ pbsCoderMap .get (pbd .getPcollectionsOrThrow (inputPCollection .getId ()).getCoderId ());
137+ String keyCoderId = ModelCoders .getKvCoderComponents (pbsMainInputCoder ).keyCoderId ();
138+ RunnerApi .Coder keyCoder = pbsCoderMap .get (keyCoderId );
139+ ensureLengthPrefixed (keyCoder , originalKeyCoder , pbsCoderMap );
129140
130- Map <String , RunnerApi .Coder > pbCoderMap = pbDescriptor .getCodersMap ();
131- RunnerApi .Coder coder = pbCoderMap .get (inputCoderId );
132- String keyCoderId = ModelCoders .getKvCoderComponents (coder ).keyCoderId ();
133-
134- RunnerApi .Coder keyCoder = pbCoderMap .get (keyCoderId );
135- // Ensure length prefix
136- assertThat (keyCoder .getSpec ().getUrn (), is (ModelCoders .LENGTH_PREFIX_CODER_URN ));
137- String lengthPrefixWrappedCoderId = keyCoder .getComponentCoderIds (0 );
141+ TimerReference timerRef = Iterables .getOnlyElement (stage .getTimers ());
142+ String timerTransformId = timerRef .transform ().getId ();
143+ RunnerApi .ParDoPayload parDoPayload =
144+ RunnerApi .ParDoPayload .parseFrom (
145+ pbd .getTransformsOrThrow (timerTransformId ).getSpec ().getPayload ());
146+ RunnerApi .TimerFamilySpec timerSpec =
147+ parDoPayload .getTimerFamilySpecsOrThrow (timerRef .localName ());
148+ RunnerApi .Coder timerCoder = pbsCoderMap .get (timerSpec .getTimerFamilyCoderId ());
149+ String timerKeyCoderId = timerCoder .getComponentCoderIds (0 );
150+ RunnerApi .Coder timerKeyCoder = pbsCoderMap .get (timerKeyCoderId );
151+ ensureLengthPrefixed (timerKeyCoder , originalKeyCoder , pbsCoderMap );
152+ }
138153
154+ private static void ensureLengthPrefixed (
155+ RunnerApi .Coder coder ,
156+ RunnerApi .Coder originalCoder ,
157+ Map <String , RunnerApi .Coder > pbsCoderMap ) {
158+ assertThat (coder .getSpec ().getUrn (), is (ModelCoders .LENGTH_PREFIX_CODER_URN ));
139159 // Check that the wrapped coder is unchanged
140- assertThat (lengthPrefixWrappedCoderId , is (originalKeyCoderId ));
141- assertThat (
142- pbCoderMap .get (lengthPrefixWrappedCoderId ), is (stageCoderMap .get (originalKeyCoderId )));
160+ String lengthPrefixedWrappedCoderId = coder .getComponentCoderIds (0 );
161+ assertThat (pbsCoderMap .get (lengthPrefixedWrappedCoderId ), is (originalCoder ));
143162 }
144163}
0 commit comments