diff --git a/cadence/_internal/workflow/statemachine/child_workflow_execution_state_machine.py b/cadence/_internal/workflow/statemachine/child_workflow_execution_state_machine.py new file mode 100644 index 0000000..6760d22 --- /dev/null +++ b/cadence/_internal/workflow/statemachine/child_workflow_execution_state_machine.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from cadence._internal.workflow.statemachine.decision_state_machine import ( + BaseDecisionStateMachine, + DecisionFuture, + DecisionId, + DecisionState, + DecisionType, +) +from cadence._internal.workflow.statemachine.event_dispatcher import EventDispatcher +from cadence._internal.workflow.statemachine.nondeterminism import ( + record_immediate_cancel, +) +from cadence.api.v1 import decision, history +from cadence.api.v1.common_pb2 import Payload, WorkflowExecution +from cadence.error import ( + ChildWorkflowExecutionCanceled, + ChildWorkflowExecutionFailed, + ChildWorkflowExecutionTerminated, + ChildWorkflowExecutionTimedOut, + StartChildWorkflowExecutionFailed, +) + +# Default id_attr is "initiated_event_id" because the majority of child workflow events +# reference the state machine by the event ID of the InitiatedEvent. handle_initiated is +# the exception — it uses "workflow_id" directly and then registers the event ID as an alias. +child_workflow_events = EventDispatcher("initiated_event_id") + + +class ChildWorkflowExecutionStateMachine(BaseDecisionStateMachine): + """State machine for StartChildWorkflowExecution and child close events.""" + + request: decision.StartChildWorkflowExecutionDecisionAttributes + execution: DecisionFuture[WorkflowExecution] + result: DecisionFuture[Payload] + _run_id: str | None + + def __init__( + self, + request: decision.StartChildWorkflowExecutionDecisionAttributes, + execution: DecisionFuture[WorkflowExecution], + result: DecisionFuture[Payload], + ) -> None: + super().__init__() + self.request = request + self.execution = execution + self.result = result + self._run_id = None + + def get_id(self) -> DecisionId: + return DecisionId(DecisionType.CHILD_WORKFLOW, self.request.workflow_id) + + def get_decision(self) -> decision.Decision | None: + if self.state is DecisionState.REQUESTED: + return decision.Decision( + start_child_workflow_execution_decision_attributes=self.request + ) + if self.state is DecisionState.CANCELED_AFTER_REQUESTED: + return record_immediate_cancel(self.request) + if self.state in ( + DecisionState.CANCELED_AFTER_RECORDED, + DecisionState.CANCELED_AFTER_STARTED, + ): + return decision.Decision( + request_cancel_external_workflow_execution_decision_attributes=decision.RequestCancelExternalWorkflowExecutionDecisionAttributes( + domain=self.request.domain, + workflow_execution=WorkflowExecution( + workflow_id=self.request.workflow_id, + run_id=self._run_id or "", + ), + child_workflow_only=True, + ) + ) + return None + + def request_cancel(self) -> bool: + if self.state is DecisionState.REQUESTED: + self._transition(DecisionState.CANCELED_AFTER_REQUESTED) + self.execution.force_cancel() + self.result.force_cancel() + return True + + if self.state is DecisionState.RECORDED: + self._transition(DecisionState.CANCELED_AFTER_RECORDED) + self.execution.force_cancel() + return True + + if self.state is DecisionState.STARTED: + # We have a run_id at this point; use CANCELED_AFTER_STARTED so the + # cancel decision includes it, which avoids a potential race where we + # try to cancel before the server has finished processing the start. + self._transition(DecisionState.CANCELED_AFTER_STARTED) + return True + + return False + + @child_workflow_events.event("workflow_id", event_id_is_alias=True) + def handle_initiated( + self, _: history.StartChildWorkflowExecutionInitiatedEventAttributes + ) -> None: + self._transition(DecisionState.RECORDED) + + @child_workflow_events.event() + def handle_initiation_failed( + self, event: history.StartChildWorkflowExecutionFailedEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + exc = StartChildWorkflowExecutionFailed( + f"start child failed: {event.cause}", + cause=event.cause, + workflow_id=event.workflow_id, + ) + self.execution.set_exception(exc) + self.result.set_exception(exc) + + @child_workflow_events.event() + def handle_started( + self, event: history.ChildWorkflowExecutionStartedEventAttributes + ) -> None: + self._transition(DecisionState.STARTED) + self._run_id = event.workflow_execution.run_id + self.execution.set_result(event.workflow_execution) + + @child_workflow_events.event() + def handle_completed( + self, event: history.ChildWorkflowExecutionCompletedEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.result.set_result(event.result) + + @child_workflow_events.event() + def handle_failed( + self, event: history.ChildWorkflowExecutionFailedEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.result.set_exception( + ChildWorkflowExecutionFailed( + event.failure.reason, + failure=event.failure, + ) + ) + + @child_workflow_events.event() + def handle_canceled( + self, event: history.ChildWorkflowExecutionCanceledEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.result.set_exception( + ChildWorkflowExecutionCanceled( + "child workflow canceled", details=event.details + ) + ) + + @child_workflow_events.event() + def handle_timed_out( + self, event: history.ChildWorkflowExecutionTimedOutEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.result.set_exception( + ChildWorkflowExecutionTimedOut( + f"child workflow timed out: {event.timeout_type}", + timeout_type=int(event.timeout_type), + ) + ) + + @child_workflow_events.event() + def handle_terminated( + self, event: history.ChildWorkflowExecutionTerminatedEventAttributes + ) -> None: + self._transition(DecisionState.COMPLETED) + self.result.set_exception(ChildWorkflowExecutionTerminated()) + + # RequestCancelExternalWorkflowExecution events reference the child workflow by + # workflow_execution.workflow_id (a nested field), not by a bare string id. + # The dispatcher resolves dotted paths, so "workflow_execution.workflow_id" extracts + # the correct key for the alias lookup. event_id_is_alias=True registers this event's + # ID so that the subsequent handle_cancel_failed can look it up via initiated_event_id. + @child_workflow_events.event("workflow_execution.workflow_id", event_id_is_alias=True) + def handle_cancel_initiated( + self, _: history.RequestCancelExternalWorkflowExecutionInitiatedEventAttributes + ) -> None: + self._transition(DecisionState.CANCELLATION_RECORDED) + + @child_workflow_events.event() + def handle_cancel_failed( + self, _: history.RequestCancelExternalWorkflowExecutionFailedEventAttributes + ) -> None: + self._transition(DecisionState.STARTED) diff --git a/cadence/_internal/workflow/statemachine/decision_manager.py b/cadence/_internal/workflow/statemachine/decision_manager.py index 0a946c8..2bb1906 100644 --- a/cadence/_internal/workflow/statemachine/decision_manager.py +++ b/cadence/_internal/workflow/statemachine/decision_manager.py @@ -8,6 +8,10 @@ activity_events, ActivityStateMachine, ) +from cadence._internal.workflow.statemachine.child_workflow_execution_state_machine import ( + child_workflow_events, + ChildWorkflowExecutionStateMachine, +) from cadence._internal.workflow.statemachine.completion_state_machine import ( CompletionStateMachine, ) @@ -21,6 +25,7 @@ from cadence._internal.workflow.statemachine.event_dispatcher import ( EventDispatcher, Action, + resolve_id_attr, ) from cadence._internal.workflow.statemachine.nondeterminism import DeterminismTracker from cadence._internal.workflow.statemachine.timer_state_machine import ( @@ -28,7 +33,7 @@ timer_events, ) from cadence.api.v1 import decision, history -from cadence.api.v1.common_pb2 import Payload +from cadence.api.v1.common_pb2 import Payload, WorkflowExecution DecisionAlias = Tuple[DecisionType, str | int] @@ -67,6 +72,7 @@ class DecisionManager: { DecisionType.ACTIVITY: activity_events, DecisionType.TIMER: timer_events, + DecisionType.CHILD_WORKFLOW: child_workflow_events, } ) @@ -110,6 +116,22 @@ def start_timer( return future + # ----- Child Workflow API ----- + + def schedule_child_workflow( + self, attrs: decision.StartChildWorkflowExecutionDecisionAttributes + ) -> tuple[asyncio.Future[WorkflowExecution], asyncio.Future[Payload]]: + if self._replaying: + self._determinism_tracker.validate_action(attrs) + decision_id = DecisionId(DecisionType.CHILD_WORKFLOW, attrs.workflow_id) + execution: DecisionFuture[WorkflowExecution] = self._create_future(decision_id) + result: DecisionFuture[Payload] = DecisionFuture( + self._event_loop, lambda: self._request_cancel(decision_id) + ) + machine = ChildWorkflowExecutionStateMachine(attrs, execution, result) + self._add_state_machine(machine) + return execution, result + # ----- Workflow API ----- def complete_workflow(self, decision: decision.Decision) -> None: if self._replaying: @@ -152,8 +174,9 @@ def handle_history_event(self, event: history.HistoryEvent) -> None: decision_type = event_action.decision_type action = event_action.action # Find what state machine the event references. - # This may be a reference via the user id or a reference to a previous event - id_for_event = getattr(event_attributes, action.id_attr) + # This may be a reference via the user id or a reference to a previous event. + # Supports dotted paths (e.g. "workflow_execution.workflow_id") for nested fields. + id_for_event = resolve_id_attr(event_attributes, action.id_attr) alias = (decision_type, id_for_event) machine = self.aliases.get(alias, None) if machine is None: diff --git a/cadence/_internal/workflow/statemachine/event_dispatcher.py b/cadence/_internal/workflow/statemachine/event_dispatcher.py index 94d2977..7abd542 100644 --- a/cadence/_internal/workflow/statemachine/event_dispatcher.py +++ b/cadence/_internal/workflow/statemachine/event_dispatcher.py @@ -43,6 +43,17 @@ def decorator(func: EventHandler) -> EventHandler: return decorator +def resolve_id_attr(obj: Any, path: str) -> Any: + """Resolve a potentially dotted attribute path from a proto message. + + For example, resolve_id_attr(attrs, "workflow_execution.workflow_id") will + return attrs.workflow_execution.workflow_id. + """ + for part in path.split("."): + obj = getattr(obj, part) + return obj + + def _find_event_type(func: EventHandler) -> Type[Message]: sig = signature(func) type_hints = get_type_hints(func) @@ -69,8 +80,18 @@ def _find_event_type(func: EventHandler) -> Type[Message]: def _validate_field(func: EventHandler, event_type: Type[Message], field: str) -> None: - fields = event_type.DESCRIPTOR.fields_by_name - if field not in fields: - raise ValueError( - f"{func.__qualname__} handles {event_type.__qualname__}, which has no field {field}" - ) + """Validate that all parts of a (potentially dotted) field path exist on the proto type.""" + descriptor = event_type.DESCRIPTOR + parts = field.split(".") + for i, part in enumerate(parts): + fields = descriptor.fields_by_name + if part not in fields: + raise ValueError( + f"{func.__qualname__} handles {event_type.__qualname__}, which has no field {part!r} (in path {field!r})" + ) + if i < len(parts) - 1: + descriptor = fields[part].message_type + if descriptor is None: + raise ValueError( + f"{func.__qualname__}: field {part!r} is not a message type, cannot access sub-field in path {field!r}" + ) diff --git a/cadence/_internal/workflow/statemachine/nondeterminism.py b/cadence/_internal/workflow/statemachine/nondeterminism.py index 563ec11..d0f3046 100644 --- a/cadence/_internal/workflow/statemachine/nondeterminism.py +++ b/cadence/_internal/workflow/statemachine/nondeterminism.py @@ -267,6 +267,58 @@ def _(attrs: history.ActivityTaskCancelRequestedEventAttributes) -> Expectation: return Expectation(DecisionId(DecisionType.ACTIVITY, attrs.activity_id), CANCEL) +@to_expectation.register +def _(attrs: decision.StartChildWorkflowExecutionDecisionAttributes) -> Expectation: + return Expectation( + DecisionId(DecisionType.CHILD_WORKFLOW, attrs.workflow_id), + {"workflow_type": attrs.workflow_type.name}, + ) + + +@to_expectation.register +def _( + attrs: history.StartChildWorkflowExecutionInitiatedEventAttributes, +) -> Expectation: + return Expectation( + DecisionId(DecisionType.CHILD_WORKFLOW, attrs.workflow_id), + {"workflow_type": attrs.workflow_type.name}, + ) + + +@to_expectation.register +def _(attrs: history.StartChildWorkflowExecutionFailedEventAttributes) -> Expectation: + # A start failure is not a cancellation; enforce the same workflow_type as the initiation + # so that a mismatch between replay history and new code is caught. + return Expectation( + DecisionId(DecisionType.CHILD_WORKFLOW, attrs.workflow_id), + {"workflow_type": attrs.workflow_type.name}, + ) + + +@to_expectation.register +def _( + attrs: decision.RequestCancelExternalWorkflowExecutionDecisionAttributes, +) -> Expectation: + return Expectation( + DecisionId( + DecisionType.CHILD_WORKFLOW, attrs.workflow_execution.workflow_id + ), + CANCEL, + ) + + +@to_expectation.register +def _( + attrs: history.RequestCancelExternalWorkflowExecutionInitiatedEventAttributes, +) -> Expectation: + return Expectation( + DecisionId( + DecisionType.CHILD_WORKFLOW, attrs.workflow_execution.workflow_id + ), + CANCEL, + ) + + # Workflow Completion - Enforce complete vs failure. Maybe we should enforce the output data? @to_expectation.register def _(_: decision.CompleteWorkflowExecutionDecisionAttributes) -> Expectation: diff --git a/cadence/error.py b/cadence/error.py index af4c875..1034edb 100644 --- a/cadence/error.py +++ b/cadence/error.py @@ -129,3 +129,42 @@ class ServiceBusyError(CadenceRpcError): def __init__(self, message: str, code: grpc.StatusCode, reason: str) -> None: super().__init__(message, code, reason) self.reason = reason + + +class ChildWorkflowError(Exception): + """Base class for all child workflow lifecycle errors. + + Callers can catch this to handle any child-workflow-related failure uniformly, + or catch subclasses for more specific handling. + """ + + pass + + +class StartChildWorkflowExecutionFailed(ChildWorkflowError): + def __init__(self, message: str, cause: Any, workflow_id: str) -> None: + super().__init__(message) + self.cause = cause + self.workflow_id = workflow_id + + +class ChildWorkflowExecutionFailed(ChildWorkflowError): + def __init__(self, message: str, failure: Any) -> None: + super().__init__(message) + self.failure = failure + + +class ChildWorkflowExecutionCanceled(ChildWorkflowError): + def __init__(self, message: str, details: Any) -> None: + super().__init__(message) + self.details = details + + +class ChildWorkflowExecutionTimedOut(ChildWorkflowError): + def __init__(self, message: str, timeout_type: int) -> None: + super().__init__(message) + self.timeout_type = timeout_type + + +class ChildWorkflowExecutionTerminated(ChildWorkflowError): + pass diff --git a/tests/cadence/_internal/workflow/statemachine/test_child_workflow_execution_state_machine.py b/tests/cadence/_internal/workflow/statemachine/test_child_workflow_execution_state_machine.py new file mode 100644 index 0000000..9924042 --- /dev/null +++ b/tests/cadence/_internal/workflow/statemachine/test_child_workflow_execution_state_machine.py @@ -0,0 +1,307 @@ +import pytest + +from cadence._internal.workflow.statemachine.child_workflow_execution_state_machine import ( + ChildWorkflowExecutionStateMachine, +) +from cadence._internal.workflow.statemachine.decision_state_machine import ( + DecisionFuture, + DecisionState, +) +from cadence._internal.workflow.statemachine.nondeterminism import ( + record_immediate_cancel, +) +from cadence.api.v1 import decision, history +from cadence.api.v1.common_pb2 import Failure, Payload, WorkflowExecution, WorkflowType +from cadence.error import ( + ChildWorkflowExecutionCanceled, + ChildWorkflowExecutionFailed, + ChildWorkflowExecutionTerminated, + ChildWorkflowExecutionTimedOut, + StartChildWorkflowExecutionFailed, +) + +### These tests have to be async because they rely on the presence of an eventloop + +WF_ID = "child-wf-1" +DOMAIN = "test-domain" + + +def make_sm() -> tuple[ + ChildWorkflowExecutionStateMachine, + DecisionFuture[WorkflowExecution], + DecisionFuture[Payload], +]: + attrs = decision.StartChildWorkflowExecutionDecisionAttributes( + domain=DOMAIN, + workflow_id=WF_ID, + workflow_type=WorkflowType(name="MyWorkflow"), + ) + execution: DecisionFuture[WorkflowExecution] = DecisionFuture() + result: DecisionFuture[Payload] = DecisionFuture() + sm = ChildWorkflowExecutionStateMachine(attrs, execution, result) + return sm, execution, result + + +async def test_initial_state(): + sm, execution, result = make_sm() + + assert sm.state is DecisionState.REQUESTED + assert execution.done() is False + assert result.done() is False + assert sm.get_decision() == decision.Decision( + start_child_workflow_execution_decision_attributes=sm.request + ) + + +async def test_cancel_before_initiated(): + sm, execution, result = make_sm() + + assert sm.request_cancel() is True + + assert sm.state is DecisionState.CANCELED_AFTER_REQUESTED + assert execution.done() is True + assert execution.cancelled() is True + assert result.done() is True + assert result.cancelled() is True + assert sm.get_decision() == record_immediate_cancel(sm.request) + + +async def test_initiated_transitions_to_recorded(): + sm, execution, result = make_sm() + + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + + assert sm.state is DecisionState.RECORDED + assert execution.done() is False + assert result.done() is False + assert sm.get_decision() is None + + +async def test_cancel_after_recorded(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + + assert sm.request_cancel() is True + + assert sm.state is DecisionState.CANCELED_AFTER_RECORDED + assert execution.done() is True + assert execution.cancelled() is True + assert result.done() is False + + +async def test_cancel_after_started(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + + assert sm.request_cancel() is True + + # Once the child has started we know the run_id, so use CANCELED_AFTER_STARTED + # (not CANCELED_AFTER_RECORDED) so the cancel decision includes it. + assert sm.state is DecisionState.CANCELED_AFTER_STARTED + assert execution.done() is True + assert result.done() is False + + +async def test_cancel_after_started_includes_run_id(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-42") + ) + ) + sm.request_cancel() + + cancel_decision = sm.get_decision() + assert cancel_decision is not None + attrs = cancel_decision.request_cancel_external_workflow_execution_decision_attributes + assert attrs.workflow_execution.run_id == "run-42" + assert attrs.workflow_execution.workflow_id == WF_ID + + +async def test_cancel_returns_false_when_completed(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + sm.handle_completed( + history.ChildWorkflowExecutionCompletedEventAttributes( + result=Payload(data=b"done") + ) + ) + + assert sm.request_cancel() is False + + +async def test_handle_initiation_failed_reuses_same_exception(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + + sm.handle_initiation_failed( + history.StartChildWorkflowExecutionFailedEventAttributes( + workflow_id=WF_ID, + ) + ) + + assert sm.state is DecisionState.COMPLETED + assert execution.done() is True + assert result.done() is True + with pytest.raises(StartChildWorkflowExecutionFailed, match="start child failed"): + execution.result() + with pytest.raises(StartChildWorkflowExecutionFailed, match="start child failed"): + result.result() + + # Both futures hold the identical exception object (not just equal ones). + exc_execution = execution.exception() + exc_result = result.exception() + assert exc_execution is exc_result + + +async def test_handle_started_resolves_execution_future(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + + wf_exec = WorkflowExecution(workflow_id=WF_ID, run_id="run-42") + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes(workflow_execution=wf_exec) + ) + + assert sm.state is DecisionState.STARTED + assert execution.done() is True + assert execution.result() == wf_exec + assert result.done() is False + + +async def test_handle_completed(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + + payload = Payload(data=b"output") + sm.handle_completed( + history.ChildWorkflowExecutionCompletedEventAttributes(result=payload) + ) + + assert sm.state is DecisionState.COMPLETED + assert result.done() is True + assert result.result() == payload + assert sm.get_decision() is None + + +async def test_handle_failed(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + + sm.handle_failed( + history.ChildWorkflowExecutionFailedEventAttributes( + failure=Failure(reason="boom") + ) + ) + + assert sm.state is DecisionState.COMPLETED + with pytest.raises(ChildWorkflowExecutionFailed, match="boom"): + result.result() + + +async def test_handle_canceled(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + + sm.handle_canceled( + history.ChildWorkflowExecutionCanceledEventAttributes( + details=Payload(data=b"cancel-details") + ) + ) + + assert sm.state is DecisionState.COMPLETED + with pytest.raises(ChildWorkflowExecutionCanceled, match="child workflow canceled"): + result.result() + + +async def test_handle_timed_out(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + + sm.handle_timed_out(history.ChildWorkflowExecutionTimedOutEventAttributes()) + + assert sm.state is DecisionState.COMPLETED + with pytest.raises( + ChildWorkflowExecutionTimedOut, match="child workflow timed out" + ): + result.result() + + +async def test_handle_terminated(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + + sm.handle_terminated(history.ChildWorkflowExecutionTerminatedEventAttributes()) + + assert sm.state is DecisionState.COMPLETED + with pytest.raises(ChildWorkflowExecutionTerminated): + result.result() + + +async def test_cancel_initiated_transitions_to_cancellation_recorded(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.request_cancel() + + sm.handle_cancel_initiated( + history.RequestCancelExternalWorkflowExecutionInitiatedEventAttributes() + ) + + assert sm.state is DecisionState.CANCELLATION_RECORDED + + +async def test_cancel_failed_reverts_to_started(): + sm, execution, result = make_sm() + sm.handle_initiated(history.StartChildWorkflowExecutionInitiatedEventAttributes()) + sm.handle_started( + history.ChildWorkflowExecutionStartedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id=WF_ID, run_id="run-1") + ) + ) + sm.request_cancel() + sm.handle_cancel_initiated( + history.RequestCancelExternalWorkflowExecutionInitiatedEventAttributes() + ) + + sm.handle_cancel_failed( + history.RequestCancelExternalWorkflowExecutionFailedEventAttributes() + ) + + assert sm.state is DecisionState.STARTED + assert result.done() is False diff --git a/tests/cadence/_internal/workflow/statemachine/test_decision_manager.py b/tests/cadence/_internal/workflow/statemachine/test_decision_manager.py index f6bde2a..90b0c46 100644 --- a/tests/cadence/_internal/workflow/statemachine/test_decision_manager.py +++ b/tests/cadence/_internal/workflow/statemachine/test_decision_manager.py @@ -5,7 +5,8 @@ from cadence._internal.workflow.statemachine.decision_manager import DecisionManager from cadence.api.v1 import history, decision -from cadence.api.v1.common_pb2 import Payload +from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType +from cadence.error import ChildWorkflowError, StartChildWorkflowExecutionFailed async def test_activity_dispatch(): @@ -161,6 +162,156 @@ async def test_collection_decisions_reordering(): assert activity2.done() is False +async def test_child_workflow_dispatch(): + decisions = DecisionManager(asyncio.get_event_loop()) + + execution, result = decisions.schedule_child_workflow( + decision.StartChildWorkflowExecutionDecisionAttributes( + workflow_id="child-wf-1", + workflow_type=WorkflowType(name="MyWorkflow"), + ) + ) + + decisions.handle_history_event(child_wf_initiated(1, "child-wf-1", initiated_event_id=1)) + decisions.handle_history_event(child_wf_started(2, started_event_id=1, wf_id="child-wf-1", run_id="run-1")) + decisions.handle_history_event(child_wf_completed(3, started_event_id=1, result=Payload(data=b"done"))) + + assert execution.done() is True + assert execution.result() == WorkflowExecution(workflow_id="child-wf-1", run_id="run-1") + assert result.done() is True + assert result.result() == Payload(data=b"done") + + +async def test_child_workflow_initiation_failed_dispatch(): + decisions = DecisionManager(asyncio.get_event_loop()) + + execution, result = decisions.schedule_child_workflow( + decision.StartChildWorkflowExecutionDecisionAttributes( + workflow_id="child-wf-1", + workflow_type=WorkflowType(name="MyWorkflow"), + ) + ) + + decisions.handle_history_event(child_wf_initiated(1, "child-wf-1", initiated_event_id=1)) + decisions.handle_history_event( + history.HistoryEvent( + event_id=2, + start_child_workflow_execution_failed_event_attributes=history.StartChildWorkflowExecutionFailedEventAttributes( + initiated_event_id=1, + workflow_id="child-wf-1", + workflow_type=WorkflowType(name="MyWorkflow"), + ), + ) + ) + + assert execution.done() is True + assert result.done() is True + with pytest.raises(StartChildWorkflowExecutionFailed): + result.result() + + +async def test_child_workflow_cancel_dispatch(): + decisions = DecisionManager(asyncio.get_event_loop()) + + execution, result = decisions.schedule_child_workflow( + decision.StartChildWorkflowExecutionDecisionAttributes( + workflow_id="child-wf-1", + workflow_type=WorkflowType(name="MyWorkflow"), + ) + ) + + decisions.handle_history_event(child_wf_initiated(1, "child-wf-1", initiated_event_id=1)) + decisions.handle_history_event(child_wf_started(2, started_event_id=1, wf_id="child-wf-1", run_id="run-1")) + + # Cancel the child workflow + result.cancel() + + # cancel_initiated and cancel_failed are dispatched by workflow_execution.workflow_id + decisions.handle_history_event( + history.HistoryEvent( + event_id=3, + request_cancel_external_workflow_execution_initiated_event_attributes=history.RequestCancelExternalWorkflowExecutionInitiatedEventAttributes( + workflow_execution=WorkflowExecution(workflow_id="child-wf-1", run_id="run-1"), + ), + ) + ) + decisions.handle_history_event( + history.HistoryEvent( + event_id=4, + request_cancel_external_workflow_execution_failed_event_attributes=history.RequestCancelExternalWorkflowExecutionFailedEventAttributes( + initiated_event_id=3, + ), + ) + ) + + # Cancel failed — back to STARTED + assert result.done() is False + + +async def test_child_workflow_errors_are_child_workflow_error(): + """All child workflow errors share a common base class.""" + decisions = DecisionManager(asyncio.get_event_loop()) + + _, result = decisions.schedule_child_workflow( + decision.StartChildWorkflowExecutionDecisionAttributes( + workflow_id="child-wf-1", + workflow_type=WorkflowType(name="MyWorkflow"), + ) + ) + + decisions.handle_history_event(child_wf_initiated(1, "child-wf-1", initiated_event_id=1)) + decisions.handle_history_event( + history.HistoryEvent( + event_id=2, + start_child_workflow_execution_failed_event_attributes=history.StartChildWorkflowExecutionFailedEventAttributes( + initiated_event_id=1, + workflow_id="child-wf-1", + workflow_type=WorkflowType(name="MyWorkflow"), + ), + ) + ) + + assert result.done() is True + with pytest.raises(ChildWorkflowError): + result.result() + + +def child_wf_initiated( + event_id: int, workflow_id: str, *, initiated_event_id: int +) -> history.HistoryEvent: + return history.HistoryEvent( + event_id=event_id, + start_child_workflow_execution_initiated_event_attributes=history.StartChildWorkflowExecutionInitiatedEventAttributes( + workflow_id=workflow_id, + workflow_type=WorkflowType(name="MyWorkflow"), + ), + ) + + +def child_wf_started( + event_id: int, *, started_event_id: int, wf_id: str, run_id: str +) -> history.HistoryEvent: + return history.HistoryEvent( + event_id=event_id, + child_workflow_execution_started_event_attributes=history.ChildWorkflowExecutionStartedEventAttributes( + initiated_event_id=started_event_id, + workflow_execution=WorkflowExecution(workflow_id=wf_id, run_id=run_id), + ), + ) + + +def child_wf_completed( + event_id: int, *, started_event_id: int, result: Payload +) -> history.HistoryEvent: + return history.HistoryEvent( + event_id=event_id, + child_workflow_execution_completed_event_attributes=history.ChildWorkflowExecutionCompletedEventAttributes( + initiated_event_id=started_event_id, + result=result, + ), + ) + + def activity_scheduled(event_id: int, activity_id: str) -> history.HistoryEvent: return history.HistoryEvent( event_id=event_id,