Skip to content

Commit

Permalink
Introduce TerminalException (#145)
Browse files Browse the repository at this point in the history
* Introduce TerminalException
* Use TerminalException in Syscalls
* Use TerminalException in blocking interface and related protoc gen
* Use TerminalException in kotlin interface
* Fix the state machine to use TerminalException rather than StatusRuntimeException as error return value
* Fix tests
  • Loading branch information
slinkydeveloper authored Nov 20, 2023
1 parent 78eaa31 commit 4602ef0
Show file tree
Hide file tree
Showing 25 changed files with 259 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ public class {{className}} {
@java.lang.Deprecated
{{/deprecated}}
{{{javadoc}}}
public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}(RestateContext context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) {
throw new io.grpc.StatusRuntimeException(io.grpc.Status.UNIMPLEMENTED);
public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}(RestateContext context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) throws dev.restate.sdk.core.TerminalException {
throw new dev.restate.sdk.core.TerminalException(dev.restate.sdk.core.TerminalException.Code.UNIMPLEMENTED);
}

{{/methods}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ public ReadyResultInternal<R> parseEntryResult(InvokeEntryMessage actual) {
if (actual.hasValue()) {
return valueParser.apply(actual.getValue());
}
return ReadyResults.failure(Util.toGrpcStatus(actual.getFailure()).asRuntimeException());
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
}

@Override
Expand All @@ -301,7 +301,7 @@ public ReadyResultInternal<R> parseCompletionResult(CompletionMessage actual) {
return valueParser.apply(actual.getValue());
}
if (actual.hasFailure()) {
return ReadyResults.failure(Util.toGrpcStatus(actual.getFailure()).asRuntimeException());
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
}
return super.parseCompletionResult(actual);
}
Expand Down Expand Up @@ -352,7 +352,7 @@ public ReadyResultInternal<ByteString> parseEntryResult(AwakeableEntryMessage ac
if (actual.hasValue()) {
return ReadyResults.success(actual.getValue());
}
return ReadyResults.failure(Util.toGrpcStatus(actual.getFailure()).asRuntimeException());
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
}

@Override
Expand All @@ -361,7 +361,7 @@ public ReadyResultInternal<ByteString> parseCompletionResult(CompletionMessage a
return ReadyResults.success(actual.getValue());
}
if (actual.hasFailure()) {
return ReadyResults.failure(Util.toGrpcStatus(actual.getFailure()).asRuntimeException());
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
}
return super.parseCompletionResult(actual);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import dev.restate.sdk.core.TerminalException;
import dev.restate.sdk.core.syscalls.*;
import io.grpc.MethodDescriptor;
import java.time.Duration;
Expand Down Expand Up @@ -98,7 +99,7 @@ public <T extends MessageLite> void writeOutput(T value, SyscallCallback<Void> c
}

@Override
public void writeOutput(Throwable throwable, SyscallCallback<Void> callback) {
public void writeOutput(TerminalException throwable, SyscallCallback<Void> callback) {
syscallsExecutor.execute(() -> syscalls.writeOutput(throwable, callback));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -67,7 +68,7 @@ private void closeWithException(Throwable e) {
serverCall.close(Util.SUSPENDED_STATUS, new Metadata());
} else {
LOG.warn("Error when processing the invocation", e);
serverCall.close(Util.toGrpcStatusWrappingUncaught(e), new Metadata());
serverCall.close(Status.UNKNOWN.withCause(e), new Metadata());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ public void onError(Throwable e) {
void completeSideEffectCallbackWithEntry(
Java.SideEffectEntryMessage sideEffectEntry, ExitSideEffectSyscallCallback callback) {
if (sideEffectEntry.hasFailure()) {
callback.onFailure(Util.toGrpcStatus(sideEffectEntry.getFailure()).asRuntimeException());
callback.onFailure(Util.toRestateException(sideEffectEntry.getFailure()));
} else {
callback.onResult(sideEffectEntry.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dev.restate.sdk.core.impl;

import dev.restate.sdk.core.TerminalException;
import dev.restate.sdk.core.syscalls.ReadyResult;
import io.grpc.StatusRuntimeException;
import java.util.function.Function;
import javax.annotation.Nullable;

Expand All @@ -18,7 +18,7 @@ static <T> ReadyResultInternal<T> success(T value) {
return new Success<>(value);
}

static <T> ReadyResultInternal<T> failure(StatusRuntimeException t) {
static <T> ReadyResultInternal<T> failure(TerminalException t) {
return new Failure<>(t);
}

Expand Down Expand Up @@ -54,7 +54,7 @@ public <U> ReadyResult<U> map(Function<T, U> mapper) {

@Nullable
@Override
public StatusRuntimeException getFailure() {
public TerminalException getFailure() {
return null;
}
}
Expand Down Expand Up @@ -89,15 +89,15 @@ public <U> ReadyResult<U> map(Function<T, U> mapper) {

@Nullable
@Override
public StatusRuntimeException getFailure() {
public TerminalException getFailure() {
return null;
}
}

static class Failure<T> implements ReadyResultInternal<T> {
private final StatusRuntimeException cause;
private final TerminalException cause;

private Failure(StatusRuntimeException cause) {
private Failure(TerminalException cause) {
this.cause = cause;
}

Expand Down Expand Up @@ -125,7 +125,7 @@ public <U> ReadyResult<U> map(Function<T, U> mapper) {

@Nullable
@Override
public StatusRuntimeException getFailure() {
public TerminalException getFailure() {
return cause;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.restate.sdk.core.impl;

import com.google.protobuf.MessageLite;
import dev.restate.sdk.core.TerminalException;
import dev.restate.sdk.core.syscalls.SyscallCallback;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
Expand Down Expand Up @@ -87,18 +88,22 @@ public void close(Status status, Metadata trailers) {
// Let's cancel the listener first
listener.onCancel();

if (status.getCause() instanceof UncaughtException) {
// This is the case where we have uncaught exceptions from GrpcServerCallListenerAdaptor
syscalls.fail(status.getCause().getCause());
} else {
if (Util.isTerminalException(status.getCause())) {
syscalls.writeOutput(
status.asRuntimeException(),
(TerminalException) status.getCause(),
SyscallCallback.ofVoid(
() -> {
LOG.trace("Closed correctly with non ok status {}", status);
LOG.trace("Closed correctly with non ok exception", status.getCause());
syscalls.close();
},
this::onError));
} else {
if (status.getCause() != null) {
syscalls.fail(status.getCause());
} else {
// Just propagate cause
syscalls.fail(status.asRuntimeException());
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.generated.service.protocol.Protocol.PollInputStreamEntryMessage;
import dev.restate.sdk.core.TerminalException;
import dev.restate.sdk.core.impl.DeferredResults.SingleDeferredResultInternal;
import dev.restate.sdk.core.impl.Entries.*;
import dev.restate.sdk.core.impl.ReadyResults.ReadyResultInternal;
Expand All @@ -35,7 +36,7 @@ public final class SyscallsImpl implements SyscallsInternal {

private final InvocationStateMachine stateMachine;

public SyscallsImpl(InvocationStateMachine stateMachine) {
SyscallsImpl(InvocationStateMachine stateMachine) {
this.stateMachine = stateMachine;
}

Expand All @@ -58,7 +59,7 @@ public <T extends MessageLite> void writeOutput(T value, SyscallCallback<Void> c
}

@Override
public void writeOutput(Throwable throwable, SyscallCallback<Void> callback) {
public void writeOutput(TerminalException throwable, SyscallCallback<Void> callback) {
LOG.trace("writeOutput failure");
this.writeOutput(
Protocol.OutputStreamEntryMessage.newBuilder()
Expand Down Expand Up @@ -177,14 +178,7 @@ public void exitSideEffectBlockWithException(
// If it's a non-terminal exception (such as a protocol exception),
// we don't write it but simply throw it
if (!isTerminalException(toWrite)) {
// For safety wrt Syscalls API we do this check and wrapping,
// but with the current APIs the exception should always be RuntimeException
// because that's what can be thrown inside a lambda
if (toWrite instanceof RuntimeException) {
throw (RuntimeException) toWrite;
} else {
throw new RuntimeException(toWrite);
}
Util.sneakyThrow(toWrite);
}

this.stateMachine.exitSideEffectBlock(
Expand Down

This file was deleted.

63 changes: 18 additions & 45 deletions sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.core.SuspendedException;
import dev.restate.sdk.core.TerminalException;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
Expand All @@ -16,6 +15,11 @@ private Util() {}

static Status SUSPENDED_STATUS = Status.INTERNAL.withCause(SuspendedException.INSTANCE);

@SuppressWarnings("unchecked")
static <E extends Throwable> void sneakyThrow(Throwable e) throws E {
throw (E) e;
}

/**
* Finds a throwable fulfilling the condition in the cause chain of the given throwable. If there
* is none, then the method returns an empty optional.
Expand Down Expand Up @@ -44,20 +48,6 @@ static <T extends Throwable> Optional<T> findCause(
return Optional.empty();
}

public static Status toGrpcStatusWrappingUncaught(Throwable t) {
Throwable cause = Objects.requireNonNull(t);
while (cause != null) {
if (cause instanceof StatusException) {
return ((StatusException) cause).getStatus();
} else if (cause instanceof StatusRuntimeException) {
return ((StatusRuntimeException) cause).getStatus();
}
cause = cause.getCause();
}
// Couldn't find a cause with a Status
return Status.UNKNOWN.withCause(new UncaughtException(t));
}

public static Optional<ProtocolException> findProtocolException(Throwable throwable) {
return findCause(throwable, t -> t instanceof ProtocolException);
}
Expand All @@ -66,45 +56,28 @@ public static boolean containsSuspendedException(Throwable throwable) {
return findCause(throwable, t -> t == SuspendedException.INSTANCE).isPresent();
}

static Protocol.Failure toProtocolFailure(Status status) {
Protocol.Failure.Builder builder =
Protocol.Failure.newBuilder().setCode(status.getCode().value());
if (status.getDescription() != null) {
builder.setMessage(status.getDescription());
static Protocol.Failure toProtocolFailure(TerminalException.Code code, String message) {
Protocol.Failure.Builder builder = Protocol.Failure.newBuilder().setCode(code.value());
if (message != null) {
builder.setMessage(message);
}
return builder.build();
}

static Protocol.Failure toProtocolFailure(Throwable throwable) {
return toProtocolFailure(toGrpcStatusErasingCause(throwable));
}

static Status toGrpcStatus(Protocol.Failure failure) {
return Status.fromCodeValue(failure.getCode()).withDescription(failure.getMessage());
}

static Status toGrpcStatusErasingCause(Throwable throwable) {
Status status;
if (throwable instanceof StatusException) {
status = ((StatusException) throwable).getStatus();
} else if (throwable instanceof StatusRuntimeException) {
status = ((StatusRuntimeException) throwable).getStatus();
} else {
return Status.UNKNOWN.withDescription(throwable.getMessage());
if (throwable instanceof TerminalException) {
return toProtocolFailure(((TerminalException) throwable).getCode(), throwable.getMessage());
}
return toProtocolFailure(TerminalException.Code.UNKNOWN, throwable.toString());
}

// We erase the cause as it's not stored in the call result structure
// and can cause non-determinism.
//
// We can still set the error message though.
if (status.getDescription() == null && status.getCause() != null) {
status = status.withDescription(status.getCause().toString());
}
return status.withCause(null);
static TerminalException toRestateException(Protocol.Failure failure) {
return new TerminalException(
TerminalException.Code.fromValue(failure.getCode()), failure.getMessage());
}

static boolean isTerminalException(Throwable throwable) {
return throwable instanceof StatusRuntimeException || throwable instanceof StatusException;
return throwable instanceof TerminalException;
}

static void assertIsEntry(MessageLite msg) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.generated.service.protocol.Protocol.StartMessage.StateEntry;
import dev.restate.sdk.core.TerminalException;
import dev.restate.sdk.core.impl.testservices.GreetingRequest;
import dev.restate.sdk.core.impl.testservices.GreetingResponse;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -74,7 +74,7 @@ public static Protocol.CompletionMessage completionMessage(
public static Protocol.CompletionMessage completionMessage(int index, Throwable e) {
return Protocol.CompletionMessage.newBuilder()
.setEntryIndex(index)
.setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage())))
.setFailure(toProtocolFailure(e))
.build();
}

Expand All @@ -98,16 +98,15 @@ public static Protocol.OutputStreamEntryMessage outputMessage(MessageLiteOrBuild
.build();
}

public static Protocol.OutputStreamEntryMessage outputMessage(Status s) {
public static Protocol.OutputStreamEntryMessage outputMessage(
TerminalException.Code code, String message) {
return Protocol.OutputStreamEntryMessage.newBuilder()
.setFailure(Util.toProtocolFailure(s.asRuntimeException()))
.setFailure(Util.toProtocolFailure(code, message))
.build();
}

public static Protocol.OutputStreamEntryMessage outputMessage(Throwable e) {
return Protocol.OutputStreamEntryMessage.newBuilder()
.setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage())))
.build();
return Protocol.OutputStreamEntryMessage.newBuilder().setFailure(toProtocolFailure(e)).build();
}

public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key) {
Expand Down Expand Up @@ -156,9 +155,7 @@ Protocol.InvokeEntryMessage invokeMessage(
public static <T extends MessageLite, R extends MessageLite>
Protocol.InvokeEntryMessage invokeMessage(
MethodDescriptor<T, R> methodDescriptor, T parameter, Throwable e) {
return invokeMessage(methodDescriptor, parameter)
.setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage())))
.build();
return invokeMessage(methodDescriptor, parameter).setFailure(toProtocolFailure(e)).build();
}

public static <T extends MessageLite>
Expand Down
Loading

0 comments on commit 4602ef0

Please sign in to comment.