Skip to content
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
## v1.3.0
* Refactor `RetriableTask` and add new `CompoundTask`, fixing Fan-out/Fan-in stuck when using `RetriableTask` ([#157](https://github.com/microsoft/durabletask-java/pull/157))

## v1.2.0

### Updates
Expand Down
2 changes: 0 additions & 2 deletions client/src/main/java/com/microsoft/durabletask/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

/**
* Represents an asynchronous operation in a durable orchestration.
Expand All @@ -32,7 +31,6 @@
*/
public abstract class Task<V> {
final CompletableFuture<V> future;

Task(CompletableFuture<V> future) {
this.future = future;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,19 @@ public interface TaskOrchestrationContext {
*/
boolean getIsReplaying();

/**
* Returns a new {@code Task} that is completed when all tasks in {@code tasks} completes.
* See {@link #allOf(Task[])} for more detailed information.
*
* @param tasks the list of {@code Task} objects
* @param <V> the return type of the {@code Task} objects
* @return a new {@code Task} that is completed when any of the given {@code Task}s complete
* @see #allOf(Task[])
*/
<V> Task<List<V>> allOf(List<Task<V>> tasks);

// TODO: Update the description of allOf to be more specific about the exception behavior.

// https://github.com/microsoft/durabletask-java/issues/54
/**
* Returns a new {@code Task} that is completed when all the given {@code Task}s complete. If any of the given
Expand All @@ -74,24 +86,26 @@ public interface TaskOrchestrationContext {
* Task<String> t2 = ctx.callActivity("MyActivity", String.class);
* Task<String> t3 = ctx.callActivity("MyActivity", String.class);
*
* List<String> orderedResults = ctx.allOf(List.of(t1, t2, t3)).await();
* List<String> orderedResults = ctx.allOf(t1, t2, t3).await();
* }</pre>
*
* Exceptions in any of the given tasks results in an unchecked {@link CompositeTaskFailedException}.
* This exception can be inspected to obtain failure details of individual {@link Task}s.
* <pre>{@code
* try {
* List<String> orderedResults = ctx.allOf(List.of(t1, t2, t3)).await();
* List<String> orderedResults = ctx.allOf(t1, t2, t3).await();
* } catch (CompositeTaskFailedException e) {
* List<Exception> exceptions = e.getExceptions()
* }
* }</pre>
*
* @param tasks the list of {@code Task} objects
* @param tasks the {@code Task}s
* @param <V> the return type of the {@code Task} objects
* @return the values of the completed {@code Task} objects in the same order as the source list
*/
<V> Task<List<V>> allOf(List<Task<V>> tasks);
default <V> Task<List<V>> allOf(Task<V>... tasks) {
return this.allOf(Arrays.asList(tasks));
}

/**
* Returns a new {@code Task} that is completed when any of the tasks in {@code tasks} completes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,39 +185,43 @@ public <V> Task<List<V>> allOf(List<Task<V>> tasks) {
.map(t -> t.future)
.toArray((IntFunction<CompletableFuture<V>[]>) CompletableFuture[]::new);

return new CompletableTask<>(CompletableFuture.allOf(futures)
.thenApply(x -> {
List<V> results = new ArrayList<>(futures.length);

// All futures are expected to be completed at this point
for (CompletableFuture<V> cf : futures) {
try {
results.add(cf.get());
} catch (Exception ex) {
results.add(null);
}
}
return results;
})
.exceptionally(throwable -> {
ArrayList<Exception> exceptions = new ArrayList<>(futures.length);
for (CompletableFuture<V> cf : futures) {
try {
cf.get();
} catch (ExecutionException ex) {
exceptions.add((Exception) ex.getCause());
} catch (Exception ex){
exceptions.add(ex);
}
}
throw new CompositeTaskFailedException(
String.format(
"%d out of %d tasks failed with an exception. See the exceptions list for details.",
exceptions.size(),
futures.length),
exceptions);
})
);
Function<Void, List<V>> resultPath = x -> {
List<V> results = new ArrayList<>(futures.length);

// All futures are expected to be completed at this point
for (CompletableFuture<V> cf : futures) {
try {
results.add(cf.get());
} catch (Exception ex) {
results.add(null);
}
}
return results;
};

Function<Throwable, ? extends List<V>> exceptionPath = throwable -> {
ArrayList<Exception> exceptions = new ArrayList<>(futures.length);
for (CompletableFuture<V> cf : futures) {
try {
cf.get();
} catch (ExecutionException ex) {
exceptions.add((Exception) ex.getCause());
} catch (Exception ex) {
exceptions.add(ex);
}
}
throw new CompositeTaskFailedException(
String.format(
"%d out of %d tasks failed with an exception. See the exceptions list for details.",
exceptions.size(),
futures.length),
exceptions);
};
CompletableFuture<List<V>> future = CompletableFuture.allOf(futures)
.thenApply(resultPath)
.exceptionally(exceptionPath);

return new CompoundTask<>(tasks, future);
}

@Override
Expand All @@ -228,7 +232,7 @@ public Task<Task<?>> anyOf(List<Task<?>> tasks) {
.map(t -> t.future)
.toArray((IntFunction<CompletableFuture<?>[]>) CompletableFuture[]::new);

return new CompletableTask<>(CompletableFuture.anyOf(futures).thenApply(x -> {
CompletableFuture<Task<?>> future = CompletableFuture.anyOf(futures).thenApply(x -> {
// Return the first completed task in the list. Unlike the implementation in other languages,
// this might not necessarily be the first task that completed, so calling code shouldn't make
// assumptions about this. Note that changing this behavior later could be breaking.
Expand All @@ -240,7 +244,9 @@ public Task<Task<?>> anyOf(List<Task<?>> tasks) {

// Should never get here
return completedTask(null);
}));
});

return new CompoundTask(tasks, future);
}

@Override
Expand Down Expand Up @@ -971,9 +977,12 @@ private class RetriableTask<V> extends CompletableTask<V> {
private final Instant firstAttempt;
private final TaskFactory<V> taskFactory;

private int attemptNumber;
private FailureDetails lastFailure;
private Duration totalRetryTime;
private Instant startTime;
private int attemptNumber;
private Task<V> childTask;


public RetriableTask(TaskOrchestrationContext context, TaskFactory<V> taskFactory, RetryPolicy policy) {
this(context, taskFactory, policy, null);
Expand All @@ -988,45 +997,88 @@ private RetriableTask(
TaskFactory<V> taskFactory,
@Nullable RetryPolicy retryPolicy,
@Nullable RetryHandler retryHandler) {
super(new CompletableFuture<>());
this.context = context;
this.taskFactory = taskFactory;
this.policy = retryPolicy;
this.handler = retryHandler;
this.firstAttempt = context.getCurrentInstant();
this.totalRetryTime = Duration.ZERO;
this.createChildTask(taskFactory);
}

@Override
public V await() {
Instant startTime = this.context.getCurrentInstant();
while (true) {
Task<V> currentTask = this.taskFactory.create();
// Every RetriableTask will have a CompletableTask as a child task.
private void createChildTask(TaskFactory<V> taskFactory) {
CompletableTask<V> childTask = (CompletableTask<V>) taskFactory.create();
this.setChildTask(childTask);
childTask.setParentTask(this);
}

this.attemptNumber++;
public void setChildTask(Task<V> childTask) {
this.childTask = childTask;
}

try {
return currentTask.await();
} catch (TaskFailedException ex) {
this.lastFailure = ex.getErrorDetails();
if (!this.shouldRetry()) {
throw ex;
}
public Task<V> getChildTask() {
return this.childTask;
}

// Overflow/runaway retry protection
if (this.attemptNumber == Integer.MAX_VALUE) {
throw ex;
}
}
void handleChildSuccess(V result) {
this.complete(result);
}

Duration delay = this.getNextDelay();
if (!delay.isZero() && !delay.isNegative()) {
// Use a durable timer to create the delay between retries
this.context.createTimer(delay).await();
}
void handleChildException(Throwable ex) {
tryRetry((TaskFailedException) ex);
}

this.totalRetryTime = Duration.between(startTime, this.context.getCurrentInstant());
void init() {
this.startTime = this.startTime == null ? this.context.getCurrentInstant() : this.startTime;
this.attemptNumber++;
}

public void tryRetry(TaskFailedException ex) {
this.lastFailure = ex.getErrorDetails();
if (!this.shouldRetry()) {
this.completeExceptionally(ex);
return;
}

// Overflow/runaway retry protection
if (this.attemptNumber == Integer.MAX_VALUE) {
this.completeExceptionally(ex);
return;
}

Duration delay = this.getNextDelay();
if (!delay.isZero() && !delay.isNegative()) {
// Use a durable timer to create the delay between retries
this.context.createTimer(delay).await();
}

this.totalRetryTime = Duration.between(this.startTime, this.context.getCurrentInstant());
this.createChildTask(this.taskFactory);
this.await();
}

@Override
public V await() {
this.init();
// when awaiting the first child task, we will continue iterating over the history until a result is found
// for that task. If the result is an exception, the child task will invoke "handleChildException" on this
// object, which awaits a timer, *re-sets the current child task to correspond to a retry of this task*,
// and then awaits that child.
// This logic continues until either the operation succeeds, or are our retry quota is met.
// At that point, we break the `await()` on the child task.
// Therefore, once we return from the following `await`,
// we just need to await again on the *current* child task to obtain the result of this task
try{
this.getChildTask().await();
} catch (OrchestratorBlockedException ex) {
throw ex;
} catch (Exception ignored) {
// ignore the exception from previous child tasks.
// Only needs to return result from the last child task, which is on next line.
}
// Always return the last child task result.
return this.getChildTask().await();
}

private boolean shouldRetry() {
Expand Down Expand Up @@ -1101,7 +1153,30 @@ private Duration getNextDelay() {
}
}

private class CompoundTask<V, U> extends CompletableTask<U> {

List<Task<V>> subTasks;

CompoundTask(List<Task<V>> subtasks, CompletableFuture<U> future) {
super(future);
this.subTasks = subtasks;
}

@Override
public U await() {
this.initSubTasks();
return super.await();
}

private void initSubTasks() {
for (Task<V> subTask : this.subTasks) {
if (subTask instanceof RetriableTask) ((RetriableTask<V>)subTask).init();
}
}
}

private class CompletableTask<V> extends Task<V> {
private Task<V> parentTask;

public CompletableTask() {
this(new CompletableFuture<>());
Expand All @@ -1111,6 +1186,14 @@ public CompletableTask() {
super(future);
}

public void setParentTask(Task<V> parentTask) {
this.parentTask = parentTask;
}

public Task<V> getParentTask() {
return this.parentTask;
}

@Override
public V await() {
do {
Expand Down Expand Up @@ -1168,15 +1251,27 @@ public boolean isDone() {
}

public boolean complete(V value) {
return this.future.complete(value);
Task<V> parentTask = this.getParentTask();
boolean result = this.future.complete(value);
if (parentTask instanceof RetriableTask) {
// notify parent task
((RetriableTask<V>) parentTask).handleChildSuccess(value);
}
return result;
}

private boolean cancel() {
return this.future.cancel(true);
}

public boolean completeExceptionally(Throwable ex) {
return this.future.completeExceptionally(ex);
Task<V> parentTask = this.getParentTask();
boolean result = this.future.completeExceptionally(ex);
if (parentTask instanceof RetriableTask) {
// notify parent task
((RetriableTask<V>) parentTask).handleChildException(ex);
}
return result;
}
}
}
Expand Down
Loading