diff --git a/src/main/java/com/uber/cadence/internal/sync/SyncWorkflowWorker.java b/src/main/java/com/uber/cadence/internal/sync/SyncWorkflowWorker.java index 07cba4b0e..a66658c75 100644 --- a/src/main/java/com/uber/cadence/internal/sync/SyncWorkflowWorker.java +++ b/src/main/java/com/uber/cadence/internal/sync/SyncWorkflowWorker.java @@ -17,7 +17,6 @@ package com.uber.cadence.internal.sync; -import com.uber.cadence.PollForDecisionTaskResponse; import com.uber.cadence.common.WorkflowExecutionHistory; import com.uber.cadence.converter.DataConverter; import com.uber.cadence.internal.common.InternalUtils; @@ -37,15 +36,12 @@ import com.uber.cadence.workflow.Functions.Func; import com.uber.cadence.workflow.WorkflowInterceptor; import java.lang.reflect.Type; -import java.time.Duration; import java.util.Objects; import java.util.concurrent.*; -import java.util.function.Consumer; import java.util.function.Function; /** Workflow worker that supports POJO workflow implementations. */ -public class SyncWorkflowWorker - implements SuspendableWorker, Consumer { +public class SyncWorkflowWorker implements SuspendableWorker { private final WorkflowWorker workflowWorker; private final LocalActivityWorker laWorker; @@ -68,7 +64,6 @@ public SyncWorkflowWorker( SingleWorkerOptions locallyDispatchedActivityOptions, DeciderCache cache, String stickyTaskListName, - Duration stickyDecisionScheduleToStartTimeout, ThreadPoolExecutor workflowThreadPool) { Objects.requireNonNull(workflowThreadPool); this.dataConverter = workflowOptions.getDataConverter(); @@ -100,7 +95,7 @@ public SyncWorkflowWorker( cache, workflowOptions, stickyTaskListName, - stickyDecisionScheduleToStartTimeout, + workflowOptions.getStickyTaskListScheduleToStartTimeout(), service, laWorker.getLocalActivityTaskPoller()); @@ -241,11 +236,6 @@ public R queryWorkflowExecution( return dataConverter.fromData(result, resultClass, resultType); } - @Override - public void accept(PollForDecisionTaskResponse pollForDecisionTaskResponse) { - workflowWorker.accept(pollForDecisionTaskResponse); - } - public CompletableFuture isHealthy() { return service.isHealthy(); } diff --git a/src/main/java/com/uber/cadence/internal/testservice/TaskQueue.java b/src/main/java/com/uber/cadence/internal/testservice/TaskQueue.java new file mode 100644 index 000000000..9047ba14b --- /dev/null +++ b/src/main/java/com/uber/cadence/internal/testservice/TaskQueue.java @@ -0,0 +1,149 @@ +/* + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.uber.cadence.internal.testservice; + +import com.google.common.base.Preconditions; +import java.util.LinkedList; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import javax.annotation.Nonnull; + +/** + * A specialized unbounded queue that requires blocking poll operations to happen through a Future + * so that they can be cancelled (i.e. cancelling the future breaks out of the poll via a + * j.u.c.CancellationException). + * + * @param + */ +class TaskQueue { + private final LinkedList backlog = new LinkedList<>(); + private final LinkedList waiters = new LinkedList<>(); + + /** + * Adds the provided element to the tail of this queue. + * + * @param element the value to add + */ + synchronized void add(E element) { + for (PollFuture future = waiters.poll(); future != null; future = waiters.poll()) { + if (future.set(element)) { + return; + } + } + backlog.push(element); + } + + /** + * Creates a new j.u.c.Future whose get() method will eventually return a value from the head of + * this queue. Note that failing to call get() on the returned Future can result in missed queue + * updates. + * + * @return a Future providing one-shot access to the head of this queue. + */ + Future poll() { + final PollFuture future = new PollFuture(); + E element; + synchronized (this) { + if (backlog.isEmpty()) { + waiters.push(future); + return future; + } + element = backlog.pop(); + } + future.set(element); + return future; + } + + /** + * A Future implementation specifically for consuming from the enclosing TaskQueue type. The get + * method on this class blocks until a value is available from the queue but unlike + * BlockingQueue#take, a blocked consumer can be "interrupted" without the use of thread + * interruption by calling #cancel() on this Future. + */ + private class PollFuture implements Future { + boolean cancelled = false; + E value; + + private synchronized boolean set(E element) { + Preconditions.checkState(value == null); + if (cancelled) { + return false; + } + value = element; + notifyAll(); + return true; + } + + @Override + public boolean cancel(boolean ignored) { + synchronized (TaskQueue.this) { + TaskQueue.this.waiters.remove(this); + } + synchronized (this) { + if (value != null) { + return false; + } + cancelled = true; + notifyAll(); + return true; + } + } + + @Override + public synchronized boolean isCancelled() { + return cancelled; + } + + @Override + public synchronized boolean isDone() { + return value != null; + } + + @Override + public synchronized E get() throws InterruptedException, ExecutionException { + while (value == null && !cancelled) { + this.wait(); + } + if (cancelled) { + throw new CancellationException(); + } + return value; + } + + @Override + public synchronized E get(long timeout, @Nonnull TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + long waitTimeNanos = unit.toNanos(timeout); + long deadline = System.nanoTime() + waitTimeNanos; + while (value == null && !cancelled) { + long remainingNanos = deadline - System.nanoTime(); + if (remainingNanos <= 0) { + throw new TimeoutException(); + } + TimeUnit.NANOSECONDS.timedWait(this, remainingNanos); + } + if (cancelled) { + throw new CancellationException(); + } + return value; + } + } +} diff --git a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowMutableStateImpl.java b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowMutableStateImpl.java index 1a3f3239e..ccd8284d3 100644 --- a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowMutableStateImpl.java +++ b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowMutableStateImpl.java @@ -101,6 +101,7 @@ void apply(RequestContext ctx) private final Map pendingQueries = new ConcurrentHashMap<>(); private final Optional continuedExecutionRunId; public StickyExecutionAttributes stickyExecutionAttributes; + private StickyExecutionAttributes previousStickyExecutionAttributes; /** * @param retryState present if workflow is a retry @@ -151,18 +152,31 @@ private void completeDecisionUpdate(UpdateProcedure updater, StickyExecutionAttr throws InternalServiceError, EntityNotExistsError, WorkflowExecutionAlreadyCompletedError, BadRequestError { StackTraceElement[] stackTraceElements = Thread.currentThread().getStackTrace(); - stickyExecutionAttributes = attributes; - update(true, updater, stackTraceElements[2].getMethodName()); + update(true, updater, stackTraceElements[2].getMethodName(), attributes); } private void update(boolean completeDecisionUpdate, UpdateProcedure updater, String caller) throws InternalServiceError, EntityNotExistsError, WorkflowExecutionAlreadyCompletedError, BadRequestError { + update(completeDecisionUpdate, updater, caller, null); + } + + private void update( + boolean completeDecisionUpdate, + UpdateProcedure updater, + String caller, + StickyExecutionAttributes attributes) + throws InternalServiceError, EntityNotExistsError, WorkflowExecutionAlreadyCompletedError, + BadRequestError { String callerInfo = "Decision Update from " + caller; lock.lock(); LockHandle lockHandle = selfAdvancingTimer.lockTimeSkipping(callerInfo); try { + if (completeDecisionUpdate) { + previousStickyExecutionAttributes = stickyExecutionAttributes; + stickyExecutionAttributes = attributes; + } checkCompleted(); boolean concurrentDecision = !completeDecisionUpdate @@ -618,6 +632,7 @@ private void timeoutDecisionTask(long scheduledEventId) { || decision.getData().scheduledEventId != scheduledEventId || decision.getState() == State.COMPLETED) { // timeout for a previous decision + this.stickyExecutionAttributes = previousStickyExecutionAttributes; return; } decision.action(StateMachines.Action.TIME_OUT, ctx, TimeoutType.START_TO_CLOSE, 0); diff --git a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowService.java b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowService.java index 7b0b94eac..39d517108 100644 --- a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowService.java +++ b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowService.java @@ -393,10 +393,29 @@ public GetWorkflowExecutionHistoryResponse GetWorkflowExecutionHistoryWithTimeou public PollForDecisionTaskResponse PollForDecisionTask(PollForDecisionTaskRequest pollRequest) throws BadRequestError, InternalServiceError, ServiceBusyError, CadenceError { PollForDecisionTaskResponse task; + java.util.concurrent.Future future = + store.pollForDecisionTask(pollRequest); try { - task = store.pollForDecisionTask(pollRequest); + // Poll with 60 second timeout to match production long poll behavior + task = future.get(60, java.util.concurrent.TimeUnit.SECONDS); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + future.cancel(true); // Cancel the Future to remove it from waiters queue return new PollForDecisionTaskResponse(); + } catch (java.util.concurrent.TimeoutException e) { + // Poll timed out - cancel to remove from waiters queue + future.cancel(true); + return new PollForDecisionTaskResponse(); + } catch (java.util.concurrent.CancellationException e) { + // Poll was cancelled - return empty response + return new PollForDecisionTaskResponse(); + } catch (java.util.concurrent.ExecutionException e) { + future.cancel(true); + throw new InternalServiceError("Error polling for decision task: " + e.getMessage()); + } + // Return empty response if poll timed out + if (task.getWorkflowExecution() == null) { + return task; } ExecutionId executionId = new ExecutionId(pollRequest.getDomain(), task.getWorkflowExecution()); TestWorkflowMutableState mutableState = getMutableState(executionId); @@ -440,10 +459,29 @@ public PollForActivityTaskResponse PollForActivityTask(PollForActivityTaskReques throws BadRequestError, InternalServiceError, ServiceBusyError, CadenceError { PollForActivityTaskResponse task; while (true) { + java.util.concurrent.Future future = + store.pollForActivityTask(pollRequest); try { - task = store.pollForActivityTask(pollRequest); + // Poll with 60 second timeout to match production long poll behavior + task = future.get(60, java.util.concurrent.TimeUnit.SECONDS); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + future.cancel(true); // Cancel the Future to remove it from waiters queue return new PollForActivityTaskResponse(); + } catch (java.util.concurrent.TimeoutException e) { + // Poll timed out - cancel to remove from waiters queue + future.cancel(true); + return new PollForActivityTaskResponse(); + } catch (java.util.concurrent.CancellationException e) { + // Poll was cancelled - return empty response + return new PollForActivityTaskResponse(); + } catch (java.util.concurrent.ExecutionException e) { + future.cancel(true); + throw new InternalServiceError("Error polling for activity task: " + e.getMessage()); + } + // Return empty response if poll timed out + if (task.getWorkflowExecution() == null) { + return task; } ExecutionId executionId = new ExecutionId(pollRequest.getDomain(), task.getWorkflowExecution()); diff --git a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStore.java b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStore.java index d5672617e..af5449f13 100644 --- a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStore.java +++ b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStore.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.Future; interface TestWorkflowStore { @@ -144,11 +145,9 @@ long save(RequestContext requestContext) void registerDelayedCallback(Duration delay, Runnable r); - PollForDecisionTaskResponse pollForDecisionTask(PollForDecisionTaskRequest pollRequest) - throws InterruptedException; + Future pollForDecisionTask(PollForDecisionTaskRequest pollRequest); - PollForActivityTaskResponse pollForActivityTask(PollForActivityTaskRequest pollRequest) - throws InterruptedException; + Future pollForActivityTask(PollForActivityTaskRequest pollRequest); /** @return queryId */ void sendQueryTask(ExecutionId executionId, TaskListId taskList, PollForDecisionTaskResponse task) diff --git a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStoreImpl.java b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStoreImpl.java index 5ebfaa976..f60a231d7 100644 --- a/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStoreImpl.java +++ b/src/main/java/com/uber/cadence/internal/testservice/TestWorkflowStoreImpl.java @@ -42,8 +42,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.Future; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -143,10 +142,10 @@ List waitForNewEvents( private final Map histories = new HashMap<>(); - private final Map> activityTaskLists = + private final Map> activityTaskLists = new HashMap<>(); - private final Map> decisionTaskLists = + private final Map> decisionTaskLists = new HashMap<>(); private final SelfAdvancingTimer timerService = @@ -210,14 +209,14 @@ public long save(RequestContext ctx) ? decisionTask.getTaskListId().getTaskListName() : attributes.getWorkerTaskList().getName()); - BlockingQueue decisionsQueue = getDecisionTaskListQueue(id); + TaskQueue decisionsQueue = getDecisionTaskListQueue(id); decisionsQueue.add(decisionTask.getTask()); } List activityTasks = ctx.getActivityTasks(); if (activityTasks != null) { for (ActivityTask activityTask : activityTasks) { - BlockingQueue activitiesQueue = + TaskQueue activitiesQueue = getActivityTaskListQueue(activityTask.getTaskListId()); activitiesQueue.add(activityTask.getTask()); } @@ -258,31 +257,26 @@ public void registerDelayedCallback(Duration delay, Runnable r) { timerService.schedule(delay, r, "registerDelayedCallback"); } - private BlockingQueue getActivityTaskListQueue( - TaskListId taskListId) { + private TaskQueue getActivityTaskListQueue(TaskListId taskListId) { lock.lock(); try { - { - BlockingQueue activitiesQueue = - activityTaskLists.get(taskListId); - if (activitiesQueue == null) { - activitiesQueue = new LinkedBlockingQueue<>(); - activityTaskLists.put(taskListId, activitiesQueue); - } - return activitiesQueue; + TaskQueue activitiesQueue = activityTaskLists.get(taskListId); + if (activitiesQueue == null) { + activitiesQueue = new TaskQueue<>(); + activityTaskLists.put(taskListId, activitiesQueue); } + return activitiesQueue; } finally { lock.unlock(); } } - private BlockingQueue getDecisionTaskListQueue( - TaskListId taskListId) { + private TaskQueue getDecisionTaskListQueue(TaskListId taskListId) { lock.lock(); try { - BlockingQueue decisionsQueue = decisionTaskLists.get(taskListId); + TaskQueue decisionsQueue = decisionTaskLists.get(taskListId); if (decisionsQueue == null) { - decisionsQueue = new LinkedBlockingQueue<>(); + decisionsQueue = new TaskQueue<>(); decisionTaskLists.put(taskListId, decisionsQueue); } return decisionsQueue; @@ -292,23 +286,19 @@ private BlockingQueue getDecisionTaskListQueue( } @Override - public PollForDecisionTaskResponse pollForDecisionTask(PollForDecisionTaskRequest pollRequest) - throws InterruptedException { + public Future pollForDecisionTask( + PollForDecisionTaskRequest pollRequest) { TaskListId taskListId = new TaskListId(pollRequest.getDomain(), pollRequest.getTaskList().getName()); - BlockingQueue decisionsQueue = - getDecisionTaskListQueue(taskListId); - return decisionsQueue.take(); + return getDecisionTaskListQueue(taskListId).poll(); } @Override - public PollForActivityTaskResponse pollForActivityTask(PollForActivityTaskRequest pollRequest) - throws InterruptedException { + public Future pollForActivityTask( + PollForActivityTaskRequest pollRequest) { TaskListId taskListId = new TaskListId(pollRequest.getDomain(), pollRequest.getTaskList().getName()); - BlockingQueue activityTaskQueue = - getActivityTaskListQueue(taskListId); - return activityTaskQueue.take(); + return getActivityTaskListQueue(taskListId).poll(); } @Override @@ -329,7 +319,7 @@ public void sendQueryTask( } finally { lock.unlock(); } - BlockingQueue decisionsQueue = getDecisionTaskListQueue(taskList); + TaskQueue decisionsQueue = getDecisionTaskListQueue(taskList); decisionsQueue.add(task); } diff --git a/src/main/java/com/uber/cadence/internal/worker/DecisionTask.java b/src/main/java/com/uber/cadence/internal/worker/DecisionTask.java new file mode 100644 index 000000000..ac2e83824 --- /dev/null +++ b/src/main/java/com/uber/cadence/internal/worker/DecisionTask.java @@ -0,0 +1,40 @@ +/* + * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.uber.cadence.internal.worker; + +import com.uber.cadence.PollForDecisionTaskResponse; +import com.uber.cadence.workflow.Functions; +import java.util.Objects; + +public final class DecisionTask { + private final PollForDecisionTaskResponse response; + private final Functions.Proc completionCallback; + + public DecisionTask(PollForDecisionTaskResponse response, Functions.Proc completionCallback) { + this.response = Objects.requireNonNull(response); + this.completionCallback = Objects.requireNonNull(completionCallback); + } + + public PollForDecisionTaskResponse getResponse() { + return response; + } + + public Functions.Proc getCompletionCallback() { + return completionCallback; + } +} diff --git a/src/main/java/com/uber/cadence/internal/worker/PollDecisionTaskDispatcher.java b/src/main/java/com/uber/cadence/internal/worker/PollDecisionTaskDispatcher.java deleted file mode 100644 index 258385876..000000000 --- a/src/main/java/com/uber/cadence/internal/worker/PollDecisionTaskDispatcher.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Modifications copyright (C) 2017 Uber Technologies, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not - * use this file except in compliance with the License. A copy of the License is - * located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package com.uber.cadence.internal.worker; - -import com.uber.cadence.DecisionTaskFailedCause; -import com.uber.cadence.PollForDecisionTaskResponse; -import com.uber.cadence.RespondDecisionTaskFailedRequest; -import com.uber.cadence.serviceclient.IWorkflowService; -import java.nio.charset.Charset; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public final class PollDecisionTaskDispatcher - implements ShutdownableTaskExecutor { - - private static final Logger log = LoggerFactory.getLogger(PollDecisionTaskDispatcher.class); - private final Map> subscribers = - new ConcurrentHashMap<>(); - private IWorkflowService service; - private Thread.UncaughtExceptionHandler uncaughtExceptionHandler = - (t, e) -> log.error("uncaught exception", e); - private AtomicBoolean shutdown = new AtomicBoolean(); - - public PollDecisionTaskDispatcher(IWorkflowService service) { - this.service = Objects.requireNonNull(service); - } - - public PollDecisionTaskDispatcher( - IWorkflowService service, Thread.UncaughtExceptionHandler exceptionHandler) { - this.service = Objects.requireNonNull(service); - if (exceptionHandler != null) { - this.uncaughtExceptionHandler = exceptionHandler; - } - } - - @Override - public void process(PollForDecisionTaskResponse t) { - if (isShutdown()) { - throw new RejectedExecutionException("shutdown"); - } - String taskListName = t.getWorkflowExecutionTaskList().getName(); - if (subscribers.containsKey(taskListName)) { - subscribers.get(taskListName).accept(t); - } else { - RespondDecisionTaskFailedRequest request = new RespondDecisionTaskFailedRequest(); - request.setTaskToken(t.getTaskToken()); - request.setCause(DecisionTaskFailedCause.RESET_STICKY_TASKLIST); - String message = - String.format( - "No handler is subscribed for the PollForDecisionTaskResponse.WorkflowExecutionTaskList %s", - taskListName); - request.setDetails(message.getBytes(Charset.defaultCharset())); - log.warn(message); - - try { - service.RespondDecisionTaskFailed(request); - } catch (Exception e) { - uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), e); - } - } - } - - @Override - public boolean hasCapacity() { - return true; - } - - public void subscribe(String taskList, Consumer consumer) { - subscribers.put(taskList, consumer); - } - - @Override - public boolean isShutdown() { - return shutdown.get(); - } - - @Override - public boolean isTerminated() { - return shutdown.get(); - } - - @Override - public void shutdown() { - shutdown.set(true); - } - - @Override - public void shutdownNow() { - shutdown.set(true); - } - - @Override - public void awaitTermination(long timeout, TimeUnit unit) {} -} diff --git a/src/main/java/com/uber/cadence/internal/worker/PollTaskExecutor.java b/src/main/java/com/uber/cadence/internal/worker/PollTaskExecutor.java index 3370f7ebd..f611644d0 100644 --- a/src/main/java/com/uber/cadence/internal/worker/PollTaskExecutor.java +++ b/src/main/java/com/uber/cadence/internal/worker/PollTaskExecutor.java @@ -55,7 +55,7 @@ public interface TaskHandler { 0, options.getTaskExecutorThreadPoolSize(), 1, - TimeUnit.SECONDS, + TimeUnit.MINUTES, new SynchronousQueue<>())); taskExecutor.setThreadFactory( new ExecutorThreadFactory( diff --git a/src/main/java/com/uber/cadence/internal/worker/SingleWorkerOptions.java b/src/main/java/com/uber/cadence/internal/worker/SingleWorkerOptions.java index c7026ff6e..a767ae925 100644 --- a/src/main/java/com/uber/cadence/internal/worker/SingleWorkerOptions.java +++ b/src/main/java/com/uber/cadence/internal/worker/SingleWorkerOptions.java @@ -48,6 +48,7 @@ public static final class Builder { private List contextPropagators; private Tracer tracer; private ExecutorWrapper executorWrapper; + private Duration stickyTaskListScheduleToStartTimeout; private Builder() {} @@ -62,6 +63,7 @@ public Builder(SingleWorkerOptions options) { this.contextPropagators = options.getContextPropagators(); this.tracer = options.getTracer(); this.executorWrapper = options.getExecutorWrapper(); + this.stickyTaskListScheduleToStartTimeout = options.getStickyTaskListScheduleToStartTimeout(); } public Builder setIdentity(String identity) { @@ -115,6 +117,12 @@ public Builder setExecutorWrapper(ExecutorWrapper executorWrapper) { return this; } + public Builder setStickyTaskListScheduleToStartTimeout( + Duration stickyTaskListScheduleToStartTimeout) { + this.stickyTaskListScheduleToStartTimeout = stickyTaskListScheduleToStartTimeout; + return this; + } + public SingleWorkerOptions build() { if (pollerOptions == null) { pollerOptions = @@ -143,7 +151,8 @@ public SingleWorkerOptions build() { enableLoggingInReplay, contextPropagators, tracer, - executorWrapper); + executorWrapper, + stickyTaskListScheduleToStartTimeout); } } @@ -157,6 +166,7 @@ public SingleWorkerOptions build() { private List contextPropagators; private final Tracer tracer; private final ExecutorWrapper executorWrapper; + private final Duration stickyTaskListScheduleToStartTimeout; private SingleWorkerOptions( String identity, @@ -168,7 +178,8 @@ private SingleWorkerOptions( boolean enableLoggingInReplay, List contextPropagators, Tracer tracer, - ExecutorWrapper executorWrapper) { + ExecutorWrapper executorWrapper, + Duration stickyTaskListScheduleToStartTimeout) { this.identity = identity; this.dataConverter = dataConverter; this.taskExecutorThreadPoolSize = taskExecutorThreadPoolSize; @@ -179,6 +190,7 @@ private SingleWorkerOptions( this.contextPropagators = contextPropagators; this.tracer = tracer; this.executorWrapper = executorWrapper; + this.stickyTaskListScheduleToStartTimeout = stickyTaskListScheduleToStartTimeout; } public String getIdentity() { @@ -220,4 +232,8 @@ public Tracer getTracer() { public ExecutorWrapper getExecutorWrapper() { return executorWrapper; } + + public Duration getStickyTaskListScheduleToStartTimeout() { + return stickyTaskListScheduleToStartTimeout; + } } diff --git a/src/main/java/com/uber/cadence/internal/worker/StickyQueueBalancer.java b/src/main/java/com/uber/cadence/internal/worker/StickyQueueBalancer.java new file mode 100644 index 000000000..d7d726ca0 --- /dev/null +++ b/src/main/java/com/uber/cadence/internal/worker/StickyQueueBalancer.java @@ -0,0 +1,75 @@ +/* + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License is + * located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.uber.cadence.internal.worker; + +import com.uber.cadence.TaskListKind; +import javax.annotation.concurrent.ThreadSafe; + +@ThreadSafe +public class StickyQueueBalancer { + + private final int pollersCount; + private final boolean stickyQueueEnabled; + private int stickyPollers = 0; + private int normalPollers = 0; + private long stickyBacklogSize = 0; + + public StickyQueueBalancer(int pollersCount, boolean stickyQueueEnabled) { + this.pollersCount = pollersCount; + this.stickyQueueEnabled = stickyQueueEnabled; + } + + /** @return task list kind that should be used for the next poll */ + public synchronized TaskListKind makePoll() { + if (stickyQueueEnabled) { + // If pollersCount >= stickyBacklogSize > 0 we want to go back to a normal ratio to avoid a + // situation that too many pollers (all of them in the worst case) will open only sticky queue + // polls observing a stickyBacklogSize == 1 for example (which actually can be 0 already at + // that moment) and get stuck causing dip in worker load. + if (stickyBacklogSize > pollersCount || stickyPollers <= normalPollers) { + stickyPollers++; + return TaskListKind.STICKY; + } + } + normalPollers++; + return TaskListKind.NORMAL; + } + + /** @param taskListKind what kind of task list poll was just finished */ + public synchronized void finishPoll(TaskListKind taskListKind) { + switch (taskListKind) { + case NORMAL: + normalPollers--; + break; + case STICKY: + stickyPollers--; + break; + default: + throw new IllegalArgumentException("Invalid task list kind: " + taskListKind); + } + } + + /** + * @param taskListKind what kind of task list poll was just finished + * @param backlogSize backlog size from the poll response + */ + public synchronized void finishPoll(TaskListKind taskListKind, long backlogSize) { + finishPoll(taskListKind); + if (TaskListKind.STICKY.equals(taskListKind)) { + stickyBacklogSize = backlogSize; + } + } +} diff --git a/src/main/java/com/uber/cadence/internal/worker/WorkflowPollTask.java b/src/main/java/com/uber/cadence/internal/worker/WorkflowPollTask.java index b90921b20..8ff188d03 100644 --- a/src/main/java/com/uber/cadence/internal/worker/WorkflowPollTask.java +++ b/src/main/java/com/uber/cadence/internal/worker/WorkflowPollTask.java @@ -31,104 +31,160 @@ import com.uber.m3.util.Duration; import com.uber.m3.util.ImmutableMap; import java.util.Objects; +import java.util.concurrent.Semaphore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -final class WorkflowPollTask implements Poller.PollTask { +final class WorkflowPollTask implements Poller.PollTask { private static final Logger log = LoggerFactory.getLogger(WorkflowWorker.class); private final Scope metricScope; + private final Scope stickyMetricScope; private final IWorkflowService service; private final String domain; - private final String taskList; - private final TaskListKind taskListKind; private final String identity; + private final String taskList; + private final String stickyTaskListName; + private final Semaphore decisionTaskExecutorSemaphore; + private final StickyQueueBalancer stickyQueueBalancer; WorkflowPollTask( IWorkflowService service, String domain, String taskList, - TaskListKind taskListKind, + String stickyTaskListName, Scope metricScope, - String identity) { - this.identity = Objects.requireNonNull(identity); + String identity, + Semaphore decisionTaskExecutorSemaphore, + StickyQueueBalancer stickyQueueBalancer) { this.service = Objects.requireNonNull(service); this.domain = Objects.requireNonNull(domain); + this.identity = identity; this.taskList = Objects.requireNonNull(taskList); - this.taskListKind = Objects.requireNonNull(taskListKind); + this.stickyTaskListName = stickyTaskListName; this.metricScope = Objects.requireNonNull(metricScope); + this.stickyQueueBalancer = Objects.requireNonNull(stickyQueueBalancer); + this.decisionTaskExecutorSemaphore = Objects.requireNonNull(decisionTaskExecutorSemaphore); + + this.stickyMetricScope = + metricScope.tagged( + new ImmutableMap.Builder(1) + .put(MetricsTag.TASK_LIST, String.format("%s:%s", taskList, "sticky")) + .build()); } @Override - public PollForDecisionTaskResponse poll() throws CadenceError { - metricScope.counter(MetricsType.DECISION_POLL_COUNTER).inc(1); - MetricsEmit.DualStopwatch sw = - MetricsEmit.startLatency( - metricScope, MetricsType.DECISION_POLL_LATENCY, HistogramBuckets.DEFAULT_1MS_100S); - - PollForDecisionTaskRequest pollRequest = new PollForDecisionTaskRequest(); - pollRequest.setDomain(domain); - pollRequest.setIdentity(identity); - pollRequest.setBinaryChecksum(BinaryChecksum.getBinaryChecksum()); + public DecisionTask poll() throws CadenceError { + boolean isSuccessful = false; + try { + decisionTaskExecutorSemaphore.acquire(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } - TaskList tl = new TaskList().setName(taskList).setKind(taskListKind); - pollRequest.setTaskList(tl); + TaskListKind taskListKind = stickyQueueBalancer.makePoll(); + boolean isSticky = TaskListKind.STICKY.equals(taskListKind); + PollForDecisionTaskRequest request = + isSticky + ? new PollForDecisionTaskRequest() + .setDomain(domain) + .setIdentity(identity) + .setBinaryChecksum(BinaryChecksum.getBinaryChecksum()) + .setTaskList( + new TaskList() + .setName(stickyTaskListName) + .setKind(TaskListKind.STICKY) + .setBaseName(taskList)) + : new PollForDecisionTaskRequest() + .setDomain(domain) + .setIdentity(identity) + .setBinaryChecksum(BinaryChecksum.getBinaryChecksum()) + .setTaskList(new TaskList().setName(taskList).setKind(TaskListKind.NORMAL)); + Scope scope = isSticky ? stickyMetricScope : metricScope; - if (log.isDebugEnabled()) { - log.debug("poll request begin: " + pollRequest); + log.trace("poll request begin: {}", request); + try { + PollForDecisionTaskResponse response = doPoll(request, scope); + if (response == null) { + return null; + } + isSuccessful = true; + stickyQueueBalancer.finishPoll(taskListKind, response.getBacklogCountHint()); + return new DecisionTask(response, decisionTaskExecutorSemaphore::release); + } finally { + if (!isSuccessful) { + decisionTaskExecutorSemaphore.release(); + stickyQueueBalancer.finishPoll(taskListKind); + } } + } + + private PollForDecisionTaskResponse doPoll(PollForDecisionTaskRequest request, Scope scope) + throws CadenceError { + scope.counter(MetricsType.DECISION_POLL_COUNTER).inc(1); + MetricsEmit.DualStopwatch sw = + MetricsEmit.startLatency( + scope, MetricsType.DECISION_POLL_LATENCY, HistogramBuckets.DEFAULT_1MS_100S); + PollForDecisionTaskResponse result; try { - result = service.PollForDecisionTask(pollRequest); - } catch (InternalServiceError e) { - metricScope - .tagged(ImmutableMap.of(MetricsTag.CAUSE, INTERNAL_SERVICE_ERROR)) - .counter(MetricsType.DECISION_POLL_TRANSIENT_FAILED_COUNTER) - .inc(1); - throw e; - } catch (ServiceBusyError e) { - metricScope - .tagged(ImmutableMap.of(MetricsTag.CAUSE, SERVICE_BUSY)) - .counter(MetricsType.DECISION_POLL_TRANSIENT_FAILED_COUNTER) - .inc(1); - throw e; - } catch (CadenceError e) { - metricScope.counter(MetricsType.DECISION_POLL_FAILED_COUNTER).inc(1); - throw e; - } + if (log.isDebugEnabled()) { + log.debug("poll request begin: " + request); + } + try { + result = service.PollForDecisionTask(request); + } catch (InternalServiceError e) { + scope + .tagged(ImmutableMap.of(MetricsTag.CAUSE, INTERNAL_SERVICE_ERROR)) + .counter(MetricsType.DECISION_POLL_TRANSIENT_FAILED_COUNTER) + .inc(1); + throw e; + } catch (ServiceBusyError e) { + scope + .tagged(ImmutableMap.of(MetricsTag.CAUSE, SERVICE_BUSY)) + .counter(MetricsType.DECISION_POLL_TRANSIENT_FAILED_COUNTER) + .inc(1); + throw e; + } catch (CadenceError e) { + scope.counter(MetricsType.DECISION_POLL_FAILED_COUNTER).inc(1); + throw e; + } - if (log.isDebugEnabled()) { - log.debug( - "poll request returned decision task: workflowType=" - + result.getWorkflowType() - + ", workflowExecution=" - + result.getWorkflowExecution() - + ", startedEventId=" - + result.getStartedEventId() - + ", previousStartedEventId=" - + result.getPreviousStartedEventId() - + (result.getQuery() != null - ? ", queryType=" + result.getQuery().getQueryType() - : "")); - } + if (log.isDebugEnabled()) { + log.debug( + "poll request returned decision task: workflowType=" + + result.getWorkflowType() + + ", workflowExecution=" + + result.getWorkflowExecution() + + ", startedEventId=" + + result.getStartedEventId() + + ", previousStartedEventId=" + + result.getPreviousStartedEventId() + + (result.getQuery() != null + ? ", queryType=" + result.getQuery().getQueryType() + : "")); + } - if (result == null || result.getTaskToken() == null) { - metricScope.counter(MetricsType.DECISION_POLL_NO_TASK_COUNTER).inc(1); - return null; - } + if (result == null || result.getTaskToken() == null) { + scope.counter(MetricsType.DECISION_POLL_NO_TASK_COUNTER).inc(1); + return null; + } - Scope metricsScope = - metricScope.tagged( - ImmutableMap.of(MetricsTag.WORKFLOW_TYPE, result.getWorkflowType().getName())); - metricsScope.counter(MetricsType.DECISION_POLL_SUCCEED_COUNTER).inc(1); - Duration scheduledToStartLatency = - Duration.ofNanos(result.getStartedTimestamp() - result.getScheduledTimestamp()); - MetricsEmit.emitLatency( - metricsScope, - MetricsType.DECISION_SCHEDULED_TO_START_LATENCY, - scheduledToStartLatency, - HistogramBuckets.DEFAULT_1MS_100S); - sw.stop(); + Scope metricsScope = + scope.tagged( + ImmutableMap.of(MetricsTag.WORKFLOW_TYPE, result.getWorkflowType().getName())); + metricsScope.counter(MetricsType.DECISION_POLL_SUCCEED_COUNTER).inc(1); + Duration scheduledToStartLatency = + Duration.ofNanos(result.getStartedTimestamp() - result.getScheduledTimestamp()); + MetricsEmit.emitLatency( + metricsScope, + MetricsType.DECISION_SCHEDULED_TO_START_LATENCY, + scheduledToStartLatency, + HistogramBuckets.DEFAULT_1MS_100S); + } finally { + sw.stop(); + } return result; } } diff --git a/src/main/java/com/uber/cadence/internal/worker/WorkflowPollTaskFactory.java b/src/main/java/com/uber/cadence/internal/worker/WorkflowPollTaskFactory.java deleted file mode 100644 index 19380e7d0..000000000 --- a/src/main/java/com/uber/cadence/internal/worker/WorkflowPollTaskFactory.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Modifications copyright (C) 2017 Uber Technologies, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not - * use this file except in compliance with the License. A copy of the License is - * located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package com.uber.cadence.internal.worker; - -import com.uber.cadence.PollForDecisionTaskResponse; -import com.uber.cadence.TaskListKind; -import com.uber.cadence.serviceclient.IWorkflowService; -import com.uber.m3.tally.Scope; -import java.util.Objects; -import java.util.function.Supplier; - -public class WorkflowPollTaskFactory - implements Supplier> { - - private final IWorkflowService service; - private final String domain; - private final String taskList; - private final TaskListKind taskListKind; - private final Scope metricScope; - private final String identity; - - public WorkflowPollTaskFactory( - IWorkflowService service, - String domain, - String taskList, - TaskListKind taskListKind, - Scope metricScope, - String identity) { - this.service = Objects.requireNonNull(service, "service should not be null"); - this.domain = Objects.requireNonNull(domain, "domain should not be null"); - this.taskList = Objects.requireNonNull(taskList, "taskList should not be null"); - this.taskListKind = Objects.requireNonNull(taskListKind, "taskList should not be null"); - this.metricScope = Objects.requireNonNull(metricScope, "metricScope should not be null"); - this.identity = Objects.requireNonNull(identity, "identity should not be null"); - } - - @Override - public Poller.PollTask get() { - return new WorkflowPollTask(service, domain, taskList, taskListKind, metricScope, identity); - } -} diff --git a/src/main/java/com/uber/cadence/internal/worker/WorkflowWorker.java b/src/main/java/com/uber/cadence/internal/worker/WorkflowWorker.java index b4b7b2383..01b20bd83 100644 --- a/src/main/java/com/uber/cadence/internal/worker/WorkflowWorker.java +++ b/src/main/java/com/uber/cadence/internal/worker/WorkflowWorker.java @@ -37,13 +37,16 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.concurrent.Semaphore; import java.util.concurrent.locks.Lock; -import java.util.function.Consumer; import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.slf4j.MDC; -public final class WorkflowWorker extends SuspendableWorkerBase - implements Consumer { +public final class WorkflowWorker extends SuspendableWorkerBase { + + private static final Logger log = LoggerFactory.getLogger(WorkflowWorker.class); private static final String POLL_THREAD_NAME_PREFIX = "Workflow Poller taskList="; private final DecisionTaskHandler handler; @@ -54,7 +57,7 @@ public final class WorkflowWorker extends SuspendableWorkerBase private final String stickyTaskListName; private final WorkflowRunLockManager runLocks = new WorkflowRunLockManager(); private final Function ldaTaskPoller; - private PollTaskExecutor pollTaskExecutor; + private PollTaskExecutor pollTaskExecutor; public WorkflowWorker( IWorkflowService service, @@ -79,14 +82,27 @@ public WorkflowWorker( POLL_THREAD_NAME_PREFIX + "\"" + taskList + "\", domain=\"" + domain + "\"") .build(); } + // ensure at least 2 poll threads + if (pollerOptions.getPollThreadCount() < 2 && stickyTaskListName != null) { + pollerOptions = PollerOptions.newBuilder(pollerOptions).setPollThreadCount(2).build(); + } this.options = SingleWorkerOptions.newBuilder(options).setPollerOptions(pollerOptions).build(); } @Override public void start() { if (handler.isAnyTypeSupported()) { + Semaphore decisionTaskExecutorSemaphore = + new Semaphore(options.getTaskExecutorThreadPoolSize()); + pollTaskExecutor = new PollTaskExecutor<>(domain, taskList, options, new TaskHandlerImpl(handler)); + + // Create sticky queue balancer + StickyQueueBalancer stickyQueueBalancer = + new StickyQueueBalancer( + options.getPollerOptions().getPollThreadCount(), stickyTaskListName != null); + SuspendableWorker poller = new Poller<>( options.getIdentity(), @@ -94,15 +110,18 @@ public void start() { service, domain, taskList, - TaskListKind.NORMAL, + stickyTaskListName, options.getMetricsScope(), - options.getIdentity()), + options.getIdentity(), + decisionTaskExecutorSemaphore, + stickyQueueBalancer), pollTaskExecutor, options.getPollerOptions(), options.getMetricsScope(), options.getExecutorWrapper()); poller.start(); setPoller(poller); + options.getMetricsScope().counter(MetricsType.WORKER_START_COUNTER).inc(1); } } @@ -170,13 +189,7 @@ private byte[] queryWorkflowExecution( throw new RuntimeException("Query returned wrong response: " + result); } - @Override - public void accept(PollForDecisionTaskResponse pollForDecisionTaskResponse) { - pollTaskExecutor.process(pollForDecisionTaskResponse); - } - - private class TaskHandlerImpl - implements PollTaskExecutor.TaskHandler { + private class TaskHandlerImpl implements PollTaskExecutor.TaskHandler { final DecisionTaskHandler handler; @@ -185,19 +198,21 @@ private TaskHandlerImpl(DecisionTaskHandler handler) { } @Override - public void handle(PollForDecisionTaskResponse task) throws Exception { + public void handle(DecisionTask task) throws Exception { + PollForDecisionTaskResponse response = task.getResponse(); Scope metricsScope = options .getMetricsScope() - .tagged(ImmutableMap.of(MetricsTag.WORKFLOW_TYPE, task.getWorkflowType().getName())); + .tagged( + ImmutableMap.of(MetricsTag.WORKFLOW_TYPE, response.getWorkflowType().getName())); - MDC.put(LoggerTag.WORKFLOW_ID, task.getWorkflowExecution().getWorkflowId()); - MDC.put(LoggerTag.WORKFLOW_TYPE, task.getWorkflowType().getName()); - MDC.put(LoggerTag.RUN_ID, task.getWorkflowExecution().getRunId()); + MDC.put(LoggerTag.WORKFLOW_ID, response.getWorkflowExecution().getWorkflowId()); + MDC.put(LoggerTag.WORKFLOW_TYPE, response.getWorkflowType().getName()); + MDC.put(LoggerTag.RUN_ID, response.getWorkflowExecution().getRunId()); Lock runLock = null; if (!Strings.isNullOrEmpty(stickyTaskListName)) { - runLock = runLocks.getLockForLocking(task.getWorkflowExecution().getRunId()); + runLock = runLocks.getLockForLocking(response.getWorkflowExecution().getRunId()); runLock.lock(); } @@ -207,7 +222,7 @@ public void handle(PollForDecisionTaskResponse task) throws Exception { metricsScope, MetricsType.DECISION_EXECUTION_LATENCY, HistogramBuckets.DEFAULT_1MS_100S); - DecisionTaskHandler.Result response = handler.handleDecisionTask(task); + DecisionTaskHandler.Result handlerResponse = handler.handleDecisionTask(response); sw.stop(); sw = @@ -215,7 +230,7 @@ public void handle(PollForDecisionTaskResponse task) throws Exception { metricsScope, MetricsType.DECISION_RESPONSE_LATENCY, HistogramBuckets.DEFAULT_1MS_100S); - sendReply(service, task, response); + sendReply(service, response, handlerResponse); sw.stop(); metricsScope.counter(MetricsType.DECISION_TASK_COMPLETED_COUNTER).inc(1); @@ -224,15 +239,17 @@ public void handle(PollForDecisionTaskResponse task) throws Exception { MDC.remove(LoggerTag.WORKFLOW_TYPE); MDC.remove(LoggerTag.RUN_ID); + task.getCompletionCallback().apply(); + if (runLock != null) { - runLocks.unlock(task.getWorkflowExecution().getRunId()); + runLocks.unlock(response.getWorkflowExecution().getRunId()); } } } @Override - public Throwable wrapFailure(PollForDecisionTaskResponse task, Throwable failure) { - WorkflowExecution execution = task.getWorkflowExecution(); + public Throwable wrapFailure(DecisionTask task, Throwable failure) { + WorkflowExecution execution = task.getResponse().getWorkflowExecution(); return new RuntimeException( "Failure processing decision task. WorkflowID=" + execution.getWorkflowId() diff --git a/src/main/java/com/uber/cadence/worker/Worker.java b/src/main/java/com/uber/cadence/worker/Worker.java index 9d4e2d626..63e102b84 100644 --- a/src/main/java/com/uber/cadence/worker/Worker.java +++ b/src/main/java/com/uber/cadence/worker/Worker.java @@ -23,7 +23,6 @@ import com.uber.cadence.client.WorkflowClient; import com.uber.cadence.common.WorkflowExecutionHistory; import com.uber.cadence.context.ContextPropagator; -import com.uber.cadence.converter.DataConverter; import com.uber.cadence.internal.common.InternalUtils; import com.uber.cadence.internal.metrics.MetricsTag; import com.uber.cadence.internal.replay.DeciderCache; @@ -36,9 +35,9 @@ import com.uber.m3.tally.Scope; import com.uber.m3.util.ImmutableMap; import io.opentracing.noop.NoopTracer; -import java.time.Duration; import java.util.List; import java.util.Objects; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -62,7 +61,12 @@ public final class Worker implements Suspendable { * @param client client to the Cadence Service endpoint. * @param taskList task list name worker uses to poll. It uses this name for both decision and * activity task list polls. - * @param options Options (like {@link DataConverter} override) for configuring worker. + * @param factoryOptions Options for configuring worker factory. + * @param options Options for configuring worker. + * @param cache Decider cache to use for sticky execution. + * @param enableStickyExecution Whether to enable sticky execution. + * @param threadPoolExecutor Thread pool executor to use for workflow and activity tasks. + * @param contextPropagators Context propagators to use for workflow and activity tasks. */ Worker( WorkflowClient client, @@ -70,8 +74,7 @@ public final class Worker implements Suspendable { WorkerFactoryOptions factoryOptions, WorkerOptions options, DeciderCache cache, - String stickyTaskListName, - Duration stickyDecisionScheduleToStartTimeout, + boolean enableStickyExecution, ThreadPoolExecutor threadPoolExecutor, List contextPropagators) { this.taskList = Objects.requireNonNull(taskList); @@ -119,6 +122,8 @@ public final class Worker implements Suspendable { .setContextPropagators(contextPropagators) .setTracer(options.getTracer()) .setExecutorWrapper(factoryOptions.getExecutorWrapper()) + .setStickyTaskListScheduleToStartTimeout( + options.getStickyTaskListScheduleToStartTimeout()) .build(); SingleWorkerOptions localActivityOptions = SingleWorkerOptions.newBuilder() @@ -142,11 +147,17 @@ public final class Worker implements Suspendable { localActivityOptions, activityOptions, cache, - stickyTaskListName, - stickyDecisionScheduleToStartTimeout, + enableStickyExecution ? getStickyTaskListName(client.getOptions().getIdentity()) : null, threadPoolExecutor); } + private static String getStickyTaskListName(String workerIdentity) { + // Unique id is needed to avoid collisions with other workers that may be created for the same + // task list and with the same identity. + UUID uniqueId = UUID.randomUUID(); + return String.format("%s:%s", workerIdentity, uniqueId); + } + SyncWorkflowWorker getWorkflowWorker() { return workflowWorker; } diff --git a/src/main/java/com/uber/cadence/worker/WorkerFactory.java b/src/main/java/com/uber/cadence/worker/WorkerFactory.java index 19e716082..764f7e08c 100644 --- a/src/main/java/com/uber/cadence/worker/WorkerFactory.java +++ b/src/main/java/com/uber/cadence/worker/WorkerFactory.java @@ -21,8 +21,6 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.base.Strings; -import com.uber.cadence.PollForDecisionTaskResponse; -import com.uber.cadence.TaskListKind; import com.uber.cadence.client.WorkflowClient; import com.uber.cadence.converter.DataConverter; import com.uber.cadence.converter.JsonDataConverter; @@ -32,12 +30,9 @@ import com.uber.cadence.internal.worker.*; import com.uber.m3.tally.Scope; import com.uber.m3.util.ImmutableMap; -import java.net.InetAddress; -import java.net.UnknownHostException; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; @@ -64,14 +59,10 @@ public static WorkerFactory newInstance( private final List workers = new ArrayList<>(); private final WorkflowClient workflowClient; - // Guarantee uniqueness for stickyTaskListName when multiple factories - private final UUID stickyTasklistRandomId = UUID.randomUUID(); private final ThreadPoolExecutor workflowThreadPool; private final AtomicInteger workflowThreadCounter = new AtomicInteger(); private final WorkerFactoryOptions factoryOptions; - private Poller stickyPoller; - private PollDecisionTaskDispatcher dispatcher; private DeciderCache cache; private State state = State.Initial; @@ -79,9 +70,7 @@ public static WorkerFactory newInstance( private final String statusErrorMessage = "attempted to %s while in %s state. Acceptable States: %s"; private static final Logger log = LoggerFactory.getLogger(WorkerFactory.class); - private static final String STICKY_TASK_LIST_PREFIX = "sticky"; - private static final String STICKY_TASK_LIST_METRIC_TAG = "__" + STICKY_TASK_LIST_PREFIX + "__"; - private static final String POLL_THREAD_NAME = "Sticky Task Poller"; + private static final String STICKY_TASK_LIST_METRIC_TAG = "__sticky__"; /** * Creates a factory. Workers will be connect to the cadence-server using the workflowService @@ -115,7 +104,8 @@ public WorkerFactory(WorkflowClient workflowClient, WorkerFactoryOptions factory // initialize the JsonDataConverter with the metrics scope JsonDataConverter.setMetricsScope(workflowClient.getOptions().getMetricsScope()); - Scope stickyScope = + // Tag the cache metrics scope with a constant sticky tag since cache is shared across workers + Scope metricsScope = workflowClient .getOptions() .getMetricsScope() @@ -125,27 +115,7 @@ public WorkerFactory(WorkflowClient workflowClient, WorkerFactoryOptions factory workflowClient.getOptions().getDomain(), MetricsTag.TASK_LIST, STICKY_TASK_LIST_METRIC_TAG)); - - this.cache = new DeciderCache(this.factoryOptions.getCacheMaximumSize(), stickyScope); - dispatcher = new PollDecisionTaskDispatcher(workflowClient.getService()); - stickyPoller = - new Poller<>( - workflowClient.getOptions().getIdentity(), - new WorkflowPollTaskFactory( - workflowClient.getService(), - workflowClient.getOptions().getDomain(), - getStickyTaskListName(), - TaskListKind.STICKY, - stickyScope, - workflowClient.getOptions().getIdentity()) - .get(), - dispatcher, - PollerOptions.newBuilder() - .setPollThreadNamePrefix(POLL_THREAD_NAME) - .setPollThreadCount(this.factoryOptions.getStickyPollerCount()) - .build(), - stickyScope, - factoryOptions.getExecutorWrapper()); + this.cache = new DeciderCache(this.factoryOptions.getCacheMaximumSize(), metricsScope); } /** @@ -185,15 +155,11 @@ public synchronized Worker newWorker(String taskList, WorkerOptions options) { factoryOptions, options, cache, - getStickyTaskListName(), - factoryOptions.getStickyTaskScheduleToStartTimeout(), + !factoryOptions.isDisableStickyExecution(), workflowThreadPool, workflowClient.getOptions().getContextPropagators()); workers.add(worker); - if (!this.factoryOptions.isDisableStickyExecution()) { - dispatcher.subscribe(taskList, worker.getWorkflowWorker()); - } return worker; } @@ -214,10 +180,6 @@ public synchronized void start() { for (Worker worker : workers) { worker.start(); } - - if (stickyPoller != null) { - stickyPoller.start(); - } } /** Was {@link #start()} called. */ @@ -238,11 +200,6 @@ public synchronized boolean isTerminated() { if (state != State.Shutdown) { return false; } - if (stickyPoller != null) { - if (!stickyPoller.isTerminated()) { - return false; - } - } for (Worker worker : workers) { if (!worker.isTerminated()) { return false; @@ -267,11 +224,6 @@ public WorkflowClient getWorkflowClient() { public synchronized void shutdown() { log.info("shutdown"); state = State.Shutdown; - if (stickyPoller != null) { - stickyPoller.shutdown(); - // To ensure that it doesn't get new tasks before workers are shutdown. - stickyPoller.awaitTermination(1, TimeUnit.SECONDS); - } for (Worker worker : workers) { worker.shutdown(); } @@ -289,11 +241,6 @@ public synchronized void shutdown() { public synchronized void shutdownNow() { log.info("shutdownNow"); state = State.Shutdown; - if (stickyPoller != null) { - stickyPoller.shutdownNow(); - // To ensure that it doesn't get new tasks before workers are shutdown. - stickyPoller.awaitTermination(1, TimeUnit.SECONDS); - } for (Worker worker : workers) { worker.shutdownNow(); } @@ -320,7 +267,6 @@ public CompletableFuture isHealthy() { public void awaitTermination(long timeout, TimeUnit unit) { log.debug("awaitTermination begin"); long timeoutMillis = unit.toMillis(timeout); - timeoutMillis = InternalUtils.awaitTermination(stickyPoller, timeoutMillis); for (Worker worker : workers) { long t = timeoutMillis; // closure needs immutable value timeoutMillis = @@ -335,21 +281,6 @@ DeciderCache getCache() { return this.cache; } - private String getHostName() { - try { - return InetAddress.getLocalHost().getHostName(); - } catch (UnknownHostException e) { - return "UnknownHost"; - } - } - - @VisibleForTesting - String getStickyTaskListName() { - return this.factoryOptions.isDisableStickyExecution() - ? null - : String.format("%s:%s:%s", STICKY_TASK_LIST_PREFIX, getHostName(), stickyTasklistRandomId); - } - public synchronized void suspendPolling() { if (state != State.Started) { return; @@ -357,9 +288,6 @@ public synchronized void suspendPolling() { log.info("suspendPolling"); state = State.Suspended; - if (stickyPoller != null) { - stickyPoller.suspendPolling(); - } for (Worker worker : workers) { worker.suspendPolling(); } @@ -372,9 +300,6 @@ public synchronized void resumePolling() { log.info("resumePolling"); state = State.Started; - if (stickyPoller != null) { - stickyPoller.resumePolling(); - } for (Worker worker : workers) { worker.resumePolling(); } diff --git a/src/main/java/com/uber/cadence/worker/WorkerFactoryOptions.java b/src/main/java/com/uber/cadence/worker/WorkerFactoryOptions.java index a1bb5dd6e..89b170b6a 100644 --- a/src/main/java/com/uber/cadence/worker/WorkerFactoryOptions.java +++ b/src/main/java/com/uber/cadence/worker/WorkerFactoryOptions.java @@ -18,7 +18,6 @@ package com.uber.cadence.worker; import com.google.common.base.Preconditions; -import java.time.Duration; public class WorkerFactoryOptions { public static Builder newBuilder() { @@ -26,10 +25,7 @@ public static Builder newBuilder() { } private static final WorkerFactoryOptions DEFAULT_INSTANCE; - private static final int DEFAULT_STICKY_POLLER_COUNT = 5; private static final int DEFAULT_STICKY_CACHE_SIZE = 600; - private static final Duration DEFAULT_STICKY_TASK_SCHEDULE_TO_START_TIMEOUT = - Duration.ofSeconds(5); private static final int DEFAULT_MAX_WORKFLOW_THREAD_COUNT = 600; static { @@ -42,12 +38,9 @@ static WorkerFactoryOptions defaultInstance() { public static class Builder { private boolean disableStickyExecution; - private Duration stickyTaskScheduleToStartTimeout = - DEFAULT_STICKY_TASK_SCHEDULE_TO_START_TIMEOUT; private int stickyCacheSize = DEFAULT_STICKY_CACHE_SIZE; private int maxWorkflowThreadCount = DEFAULT_MAX_WORKFLOW_THREAD_COUNT; private boolean enableLoggingInReplay; - private int stickyPollerCount = DEFAULT_STICKY_POLLER_COUNT; private ExecutorWrapper executorWrapper = ExecutorWrapper.newDefaultInstance(); private Builder() {} @@ -84,24 +77,6 @@ public Builder setMaxWorkflowThreadCount(int maxWorkflowThreadCount) { return this; } - /** - * Timeout for sticky workflow decision to be picked up by the host assigned to it. Once it - * times out then it can be picked up by any worker. Default value is 5 seconds. - */ - public Builder setStickyTaskScheduleToStartTimeout(Duration stickyTaskScheduleToStartTimeout) { - this.stickyTaskScheduleToStartTimeout = stickyTaskScheduleToStartTimeout; - return this; - } - - /** - * PollerOptions for poller responsible for polling for decisions for workflows cached by all - * workers created by this factory. - */ - public Builder setStickyPollerCount(int stickyPollerCount) { - this.stickyPollerCount = stickyPollerCount; - return this; - } - public Builder setEnableLoggingInReplay(boolean enableLoggingInReplay) { this.enableLoggingInReplay = enableLoggingInReplay; return this; @@ -117,8 +92,6 @@ public WorkerFactoryOptions build() { disableStickyExecution, stickyCacheSize, maxWorkflowThreadCount, - stickyTaskScheduleToStartTimeout, - stickyPollerCount, enableLoggingInReplay, executorWrapper); } @@ -127,7 +100,6 @@ public WorkerFactoryOptions build() { private final boolean disableStickyExecution; private final int cacheMaximumSize; private final int maxWorkflowThreadCount; - private Duration stickyTaskScheduleToStartTimeout; private boolean enableLoggingInReplay; private int stickyPollerCount; private ExecutorWrapper executorWrapper; @@ -136,8 +108,6 @@ private WorkerFactoryOptions( boolean disableStickyExecution, int cacheMaximumSize, int maxWorkflowThreadCount, - Duration stickyTaskScheduleToStartTimeout, - int stickyPollerCount, boolean enableLoggingInReplay, ExecutorWrapper executorWrapper) { Preconditions.checkArgument(cacheMaximumSize > 0, "cacheMaximumSize should be greater than 0"); @@ -147,9 +117,7 @@ private WorkerFactoryOptions( this.disableStickyExecution = disableStickyExecution; this.cacheMaximumSize = cacheMaximumSize; this.maxWorkflowThreadCount = maxWorkflowThreadCount; - this.stickyPollerCount = stickyPollerCount; this.enableLoggingInReplay = enableLoggingInReplay; - this.stickyTaskScheduleToStartTimeout = stickyTaskScheduleToStartTimeout; this.executorWrapper = executorWrapper; } @@ -169,14 +137,6 @@ public boolean isEnableLoggingInReplay() { return enableLoggingInReplay; } - public int getStickyPollerCount() { - return stickyPollerCount; - } - - public Duration getStickyTaskScheduleToStartTimeout() { - return stickyTaskScheduleToStartTimeout; - } - public ExecutorWrapper getExecutorWrapper() { return executorWrapper; } diff --git a/src/main/java/com/uber/cadence/worker/WorkerOptions.java b/src/main/java/com/uber/cadence/worker/WorkerOptions.java index 5ff073520..1650048fc 100644 --- a/src/main/java/com/uber/cadence/worker/WorkerOptions.java +++ b/src/main/java/com/uber/cadence/worker/WorkerOptions.java @@ -21,6 +21,7 @@ import com.uber.cadence.workflow.WorkflowInterceptor; import io.opentracing.Tracer; import io.opentracing.noop.NoopTracerFactory; +import java.time.Duration; import java.util.Objects; import java.util.function.Function; @@ -39,6 +40,8 @@ public static WorkerOptions defaultInstance() { private static final WorkerOptions DEFAULT_INSTANCE; + static final Duration DEFAULT_STICKY_TASK_SCHEDULE_TO_START_TIMEOUT = Duration.ofSeconds(5); + static { DEFAULT_INSTANCE = WorkerOptions.newBuilder().build(); } @@ -53,6 +56,8 @@ public static final class Builder { private PollerOptions activityPollerOptions; private PollerOptions workflowPollerOptions; private Function interceptorFactory = (n) -> n; + private Duration stickyTaskListScheduleToStartTimeout = + DEFAULT_STICKY_TASK_SCHEDULE_TO_START_TIMEOUT; // by default NoopTracer private Tracer tracer = NoopTracerFactory.create(); @@ -68,6 +73,7 @@ private Builder(WorkerOptions options) { this.activityPollerOptions = options.activityPollerOptions; this.workflowPollerOptions = options.workflowPollerOptions; this.interceptorFactory = options.interceptorFactory; + this.stickyTaskListScheduleToStartTimeout = options.stickyTaskListScheduleToStartTimeout; this.tracer = options.tracer; } @@ -151,6 +157,19 @@ public Builder setTracer(Tracer tracer) { return this; } + /** + * Timeout for sticky decision task to be picked up by the host assigned to it. Once it times + * out then it can be picked up by any worker. Default value is 5 seconds. + */ + public Builder setStickyTaskListScheduleToStartTimeout( + Duration stickyTaskListScheduleToStartTimeout) { + if (stickyTaskListScheduleToStartTimeout == null) { + stickyTaskListScheduleToStartTimeout = DEFAULT_STICKY_TASK_SCHEDULE_TO_START_TIMEOUT; + } + this.stickyTaskListScheduleToStartTimeout = stickyTaskListScheduleToStartTimeout; + return this; + } + public WorkerOptions build() { return new WorkerOptions( workerActivitiesPerSecond, @@ -161,7 +180,8 @@ public WorkerOptions build() { activityPollerOptions, workflowPollerOptions, interceptorFactory, - tracer); + tracer, + stickyTaskListScheduleToStartTimeout); } } @@ -174,6 +194,7 @@ public WorkerOptions build() { private final PollerOptions workflowPollerOptions; private final Function interceptorFactory; private final Tracer tracer; + private final Duration stickyTaskListScheduleToStartTimeout; private WorkerOptions( double workerActivitiesPerSecond, @@ -184,7 +205,8 @@ private WorkerOptions( PollerOptions activityPollerOptions, PollerOptions workflowPollerOptions, Function interceptorFactory, - Tracer tracer) { + Tracer tracer, + Duration stickyTaskListScheduleToStartTimeout) { this.workerActivitiesPerSecond = workerActivitiesPerSecond; this.maxConcurrentActivityExecutionSize = maxConcurrentActivityExecutionSize; this.maxConcurrentWorkflowExecutionSize = maxConcurrentWorkflowExecutionSize; @@ -194,6 +216,7 @@ private WorkerOptions( this.workflowPollerOptions = workflowPollerOptions; this.interceptorFactory = interceptorFactory; this.tracer = tracer; + this.stickyTaskListScheduleToStartTimeout = stickyTaskListScheduleToStartTimeout; } public double getWorkerActivitiesPerSecond() { @@ -232,6 +255,10 @@ public Tracer getTracer() { return tracer; } + public Duration getStickyTaskListScheduleToStartTimeout() { + return stickyTaskListScheduleToStartTimeout; + } + @Override public String toString() { return "WorkerOptions{" @@ -249,6 +276,8 @@ public String toString() { + activityPollerOptions + ", workflowPollerOptions=" + workflowPollerOptions + + ", stickyTaskListScheduleToStartTimeout=" + + stickyTaskListScheduleToStartTimeout + '}'; } } diff --git a/src/test/java/com/uber/cadence/internal/tracing/StartWorkflowTest.java b/src/test/java/com/uber/cadence/internal/tracing/StartWorkflowTest.java index a0b1cb904..17c78e2ae 100644 --- a/src/test/java/com/uber/cadence/internal/tracing/StartWorkflowTest.java +++ b/src/test/java/com/uber/cadence/internal/tracing/StartWorkflowTest.java @@ -177,7 +177,8 @@ public void testStartMultipleWorkflowGRPC() { Worker worker; worker = workerFactory.newWorker( - TASK_LIST, WorkerOptions.newBuilder().setMaxConcurrentWorkflowExecutionSize(2).build()); + TASK_LIST, + WorkerOptions.newBuilder().setMaxConcurrentWorkflowExecutionSize(20).build()); worker.registerActivitiesImplementations(new TestActivityImpl(mockTracer, true)); worker.registerWorkflowImplementationTypes(TestWorkflowImpl.class, DoubleWorkflowImpl.class); workerFactory.start(); diff --git a/src/test/java/com/uber/cadence/internal/worker/PollDecisionTaskDispatcherTests.java b/src/test/java/com/uber/cadence/internal/worker/PollDecisionTaskDispatcherTests.java deleted file mode 100644 index cff5c3a1a..000000000 --- a/src/test/java/com/uber/cadence/internal/worker/PollDecisionTaskDispatcherTests.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Modifications copyright (C) 2017 Uber Technologies, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not - * use this file except in compliance with the License. A copy of the License is - * located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package com.uber.cadence.internal.worker; - -import static junit.framework.TestCase.*; -import static org.mockito.Mockito.*; - -import ch.qos.logback.classic.Level; -import ch.qos.logback.classic.LoggerContext; -import ch.qos.logback.classic.spi.ILoggingEvent; -import ch.qos.logback.core.read.ListAppender; -import com.uber.cadence.PollForDecisionTaskResponse; -import com.uber.cadence.TaskList; -import com.uber.cadence.internal.testservice.TestWorkflowService; -import com.uber.cadence.serviceclient.IWorkflowService; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class PollDecisionTaskDispatcherTests { - - LoggerContext context = (LoggerContext) LoggerFactory.getILoggerFactory(); - ch.qos.logback.classic.Logger logger = context.getLogger(Logger.ROOT_LOGGER_NAME); - - @Test - public void pollDecisionTasksAreDispatchedBasedOnTaskListName() { - - // Arrange - AtomicBoolean handled = new AtomicBoolean(false); - Consumer handler = r -> handled.set(true); - - PollDecisionTaskDispatcher dispatcher = - new PollDecisionTaskDispatcher(new TestWorkflowService()); - dispatcher.subscribe("tasklist1", handler); - - // Act - PollForDecisionTaskResponse response = CreatePollForDecisionTaskResponse("tasklist1"); - dispatcher.process(response); - - // Assert - assertTrue(handled.get()); - } - - @Test - public void pollDecisionTasksAreDispatchedToTheCorrectHandler() { - - // Arrange - AtomicBoolean handled = new AtomicBoolean(false); - AtomicBoolean handled2 = new AtomicBoolean(false); - - Consumer handler = r -> handled.set(true); - Consumer handler2 = r -> handled2.set(true); - - PollDecisionTaskDispatcher dispatcher = - new PollDecisionTaskDispatcher(new TestWorkflowService()); - dispatcher.subscribe("tasklist1", handler); - dispatcher.subscribe("tasklist2", handler2); - - // Act - PollForDecisionTaskResponse response = CreatePollForDecisionTaskResponse("tasklist1"); - dispatcher.process(response); - - // Assert - assertTrue(handled.get()); - assertFalse(handled2.get()); - } - - @Test - public void handlersGetOverwrittenWhenRegisteredForTheSameTaskList() { - - // Arrange - AtomicBoolean handled = new AtomicBoolean(false); - AtomicBoolean handled2 = new AtomicBoolean(false); - - Consumer handler = r -> handled.set(true); - Consumer handler2 = r -> handled2.set(true); - - PollDecisionTaskDispatcher dispatcher = - new PollDecisionTaskDispatcher(new TestWorkflowService()); - dispatcher.subscribe("tasklist1", handler); - dispatcher.subscribe("tasklist1", handler2); - - // Act - PollForDecisionTaskResponse response = CreatePollForDecisionTaskResponse("tasklist1"); - dispatcher.process(response); - - // Assert - assertTrue(handled2.get()); - assertFalse(handled.get()); - } - - @Test - public void aWarningIsLoggedAndDecisionTaskIsFailedWhenNoHandlerIsRegisteredForTheTaskList() - throws Exception { - - // Arrange - ListAppender appender = new ListAppender<>(); - appender.setContext(context); - appender.start(); - logger.addAppender(appender); - - AtomicBoolean handled = new AtomicBoolean(false); - Consumer handler = r -> handled.set(true); - - IWorkflowService mockService = mock(IWorkflowService.class); - - PollDecisionTaskDispatcher dispatcher = new PollDecisionTaskDispatcher(mockService); - dispatcher.subscribe("tasklist1", handler); - - // Act - PollForDecisionTaskResponse response = - CreatePollForDecisionTaskResponse("I Don't Exist TaskList"); - dispatcher.process(response); - - // Assert - verify(mockService, times(1)).RespondDecisionTaskFailed(any()); - assertFalse(handled.get()); - assertEquals(1, appender.list.size()); - ILoggingEvent event = appender.list.get(0); - assertEquals(Level.WARN, event.getLevel()); - assertEquals( - String.format( - "No handler is subscribed for the PollForDecisionTaskResponse.WorkflowExecutionTaskList %s", - "I Don't Exist TaskList"), - event.getFormattedMessage()); - } - - private PollForDecisionTaskResponse CreatePollForDecisionTaskResponse(String taskListName) { - PollForDecisionTaskResponse response = new PollForDecisionTaskResponse(); - TaskList tl = new TaskList(); - tl.setName(taskListName); - response.setWorkflowExecutionTaskList(tl); - return response; - } -} diff --git a/src/test/java/com/uber/cadence/internal/worker/WorkflowPollTaskTest.java b/src/test/java/com/uber/cadence/internal/worker/WorkflowPollTaskTest.java index 2df3fae2d..55106ff12 100644 --- a/src/test/java/com/uber/cadence/internal/worker/WorkflowPollTaskTest.java +++ b/src/test/java/com/uber/cadence/internal/worker/WorkflowPollTaskTest.java @@ -31,6 +31,7 @@ import com.uber.m3.tally.Stopwatch; import com.uber.m3.tally.Timer; import com.uber.m3.util.Duration; +import java.util.concurrent.Semaphore; import org.junit.Before; import org.junit.Test; @@ -39,11 +40,13 @@ public class WorkflowPollTaskTest { private IWorkflowService mockService; private Scope mockMetricScope; private WorkflowPollTask pollTask; + private Semaphore semaphore; @Before public void setup() { mockService = mock(IWorkflowService.class); mockMetricScope = mock(Scope.class); + semaphore = new Semaphore(100); // Mock the Timer and Histogram for poll latency (dual-emit) Timer pollLatencyTimer = mock(Timer.class); @@ -77,15 +80,20 @@ public void setup() { when(mockMetricScope.counter(MetricsType.DECISION_POLL_SUCCEED_COUNTER)) .thenReturn(succeedCounter); + // Create a sticky queue balancer (disabled for test) + StickyQueueBalancer stickyQueueBalancer = new StickyQueueBalancer(1, false); + // Initialize pollTask with the mocked dependencies pollTask = new WorkflowPollTask( mockService, "test-domain", "test-taskList", - TaskListKind.NORMAL, + "test-taskList-sticky", mockMetricScope, - "test-identity"); + "test-identity", + semaphore, + stickyQueueBalancer); } @Test @@ -138,11 +146,11 @@ public void testPollSuccess() throws CadenceError { when(mockMetricScope.counter(MetricsType.DECISION_POLL_COUNTER)).thenReturn(pollCounter); when(taggedScope.counter(MetricsType.DECISION_POLL_SUCCEED_COUNTER)).thenReturn(succeedCounter); - PollForDecisionTaskResponse result = pollTask.poll(); + DecisionTask result = pollTask.poll(); // Verify that the result is not null and task token is as expected assertNotNull(result); - assertArrayEquals("testToken".getBytes(), result.getTaskToken()); + assertArrayEquals("testToken".getBytes(), result.getResponse().getTaskToken()); // Verify counter and timer/histogram behavior (dual-emit) verify(pollCounter, times(1)).inc(1); @@ -157,6 +165,11 @@ public void testPollSuccess() throws CadenceError { Duration.ofNanos(response.getStartedTimestamp() - response.getScheduledTimestamp()); verify(scheduledToStartLatencyTimer, times(1)).record(eq(expectedDuration)); verify(scheduledToStartLatencyHistogram, times(1)).recordDuration(eq(expectedDuration)); + + // Verify the completion callback releases the semaphore + int permitsBeforeCallback = semaphore.availablePermits(); + result.getCompletionCallback().apply(); + assertEquals(permitsBeforeCallback + 1, semaphore.availablePermits()); } @Test(expected = InternalServiceError.class) @@ -222,7 +235,7 @@ public void testPollNoTask() throws CadenceError { when(mockMetricScope.counter(MetricsType.DECISION_POLL_NO_TASK_COUNTER)) .thenReturn(noTaskCounter); - PollForDecisionTaskResponse result = pollTask.poll(); + DecisionTask result = pollTask.poll(); assertNull(result); verify(noTaskCounter, times(1)).inc(1);