From 8223cd663d775a590e56560c97d38e9872be1a35 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Tue, 17 Oct 2023 16:43:46 +0200 Subject: [PATCH] Add Kotlin tests (#117) * Split test suite code and test services. With this change we can now reuse the same tests for the kotlin interface. * Add tests to kotlin interface * A bunch of fixes for issues that came up while testing the kotlin sdk. --- sdk-core-impl/build.gradle.kts | 1 - .../restate/sdk/core/impl/SyscallsImpl.java | 4 +- ...eIdTest.java => AwakeableIdTestSuite.java} | 23 +- .../restate/sdk/core/impl/CoreTestRunner.java | 109 +++++---- ...ferredTest.java => DeferredTestSuite.java} | 220 +++--------------- ...tateTest.java => EagerStateTestSuite.java} | 91 ++------ .../dev/restate/sdk/core/impl/FlowUtils.java | 39 ---- ...Test.java => GetAndSetStateTestSuite.java} | 55 +---- ...tStateTest.java => GetStateTestSuite.java} | 33 +-- ...IdTest.java => InvocationIdTestSuite.java} | 21 +- ....java => OnlyInputAndOutputTestSuite.java} | 17 +- .../dev/restate/sdk/core/impl/ProtoUtils.java | 42 ++-- .../restate/sdk/core/impl/SideEffectTest.java | 185 --------------- .../sdk/core/impl/SideEffectTestSuite.java | 95 ++++++++ .../{SleepTest.java => SleepTestSuite.java} | 57 +---- .../core/impl/StateMachineFailuresTest.java | 122 ---------- .../impl/StateMachineFailuresTestSuite.java | 78 +++++++ .../sdk/core/impl/UserFailuresTest.java | 122 ---------- .../sdk/core/impl/UserFailuresTestSuite.java | 77 ++++++ sdk-java-blocking/build.gradle.kts | 28 ++- .../restate/sdk/blocking/RestateContext.java | 2 +- .../restate/sdk/blocking/AwakeableIdTest.java | 29 +++ .../restate/sdk/blocking/DeferredTest.java | 212 +++++++++++++++++ .../restate/sdk/blocking/EagerStateTest.java | 93 ++++++++ .../sdk/blocking/GetAndSetStateTest.java | 62 +++++ .../restate/sdk/blocking/GetStateTest.java | 30 +++ .../sdk/blocking/InvocationIdTest.java | 29 +++ .../sdk/blocking/OnlyInputAndOutputTest.java | 25 ++ .../restate/sdk/blocking/SideEffectTest.java | 136 +++++++++++ .../dev/restate/sdk/blocking/SleepTest.java | 61 +++++ .../blocking/StateMachineFailuresTest.java | 61 +++++ .../sdk/blocking/UserFailuresTest.java | 126 ++++++++++ sdk-kotlin/build.gradle.kts | 42 +++- .../restate/sdk/kotlin/RestateContextImpl.kt | 31 ++- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 10 +- .../dev/restate/sdk/kotlin/AwakeableIdTest.kt | 25 ++ .../dev/restate/sdk/kotlin/DeferredTest.kt | 136 +++++++++++ .../dev/restate/sdk/kotlin/EagerStateTest.kt | 70 ++++++ .../restate/sdk/kotlin/GetAndSetStateTest.kt | 33 +++ .../dev/restate/sdk/kotlin/GetStateTest.kt | 26 +++ .../restate/sdk/kotlin/InvocationIdTest.kt | 23 ++ .../sdk/kotlin/OnlyInputAndOutputTest.kt | 22 ++ .../dev/restate/sdk/kotlin/SideEffectTest.kt | 87 +++++++ .../dev/restate/sdk/kotlin/SleepTest.kt | 42 ++++ .../sdk/kotlin/StateMachineFailuresTest.kt | 48 ++++ .../restate/sdk/kotlin/UserFailuresTest.kt | 59 +++++ 46 files changed, 1968 insertions(+), 971 deletions(-) rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{AwakeableIdTest.java => AwakeableIdTestSuite.java} (78%) rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{DeferredTest.java => DeferredTestSuite.java} (69%) rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{EagerStateTest.java => EagerStateTestSuite.java} (55%) rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{GetAndSetStateTest.java => GetAndSetStateTestSuite.java} (53%) rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{GetStateTest.java => GetStateTestSuite.java} (72%) rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{InvocationIdTest.java => InvocationIdTestSuite.java} (56%) rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{OnlyInputAndOutputTest.java => OnlyInputAndOutputTestSuite.java} (57%) delete mode 100644 sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java create mode 100644 sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTestSuite.java rename sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/{SleepTest.java => SleepTestSuite.java} (66%) delete mode 100644 sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java create mode 100644 sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTestSuite.java delete mode 100644 sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java create mode 100644 sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTestSuite.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/AwakeableIdTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/DeferredTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/EagerStateTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetAndSetStateTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetStateTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/InvocationIdTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/OnlyInputAndOutputTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SideEffectTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SleepTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/StateMachineFailuresTest.java create mode 100644 sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/UserFailuresTest.java create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetAndSetStateTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetStateTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/InvocationIdTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/OnlyInputAndOutputTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt create mode 100644 sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt diff --git a/sdk-core-impl/build.gradle.kts b/sdk-core-impl/build.gradle.kts index 6e499ca8..889a7e05 100644 --- a/sdk-core-impl/build.gradle.kts +++ b/sdk-core-impl/build.gradle.kts @@ -30,7 +30,6 @@ dependencies { testCompileOnly(coreLibs.javax.annotation.api) - testImplementation(project(":sdk-java-blocking")) testImplementation(testingLibs.junit.jupiter) testImplementation(testingLibs.assertj) testImplementation(coreLibs.grpc.stub) diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java index a11217ab..c89332c0 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java @@ -339,7 +339,9 @@ private T deserialize(TypeTag ty, ByteString bytes) { return (T) bytes.toByteArray(); } else if (ByteString.class.equals(typeTag)) { return (T) bytes; - } else if (Void.class.equals(typeTag)) { + } else if (Void.class.equals(typeTag) || Void.TYPE.equals(typeTag)) { + // Amazing JVM foot-gun here: Void.TYPE is the primitive type, Void.class is the boxed type. + // For us, they're the same but for the equality they aren't, so we check both return null; } return serde.deserialize(typeTag, bytes.toByteArray()); diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AwakeableIdTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AwakeableIdTestSuite.java similarity index 78% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AwakeableIdTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AwakeableIdTestSuite.java index dea2e2ae..a94db206 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AwakeableIdTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/AwakeableIdTestSuite.java @@ -1,39 +1,28 @@ package dev.restate.sdk.core.impl; import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; -import static dev.restate.sdk.core.impl.ProtoUtils.*; +import static dev.restate.sdk.core.impl.ProtoUtils.inputMessage; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.InstanceOfAssertFactories.type; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.core.TypeTag; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; +import io.grpc.BindableService; import java.nio.ByteBuffer; import java.util.Base64; import java.util.UUID; import java.util.stream.Stream; -class AwakeableIdTest extends CoreTestRunner { +public abstract class AwakeableIdTestSuite extends CoreTestRunner { - private static class ReturnAwakeableId extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - String id = restateContext().awakeable(TypeTag.STRING_UTF8).id(); - responseObserver.onNext(greetingResponse(id)); - responseObserver.onCompleted(); - } - } + protected abstract BindableService returnAwakeableId(); @Override - Stream definitions() { + protected Stream definitions() { UUID id = UUID.randomUUID(); String debugId = id.toString(); byte[] serializedId = serializeUUID(id); @@ -46,7 +35,7 @@ Stream definitions() { Base64.getUrlEncoder().encodeToString(expectedAwakeableId.array()); return Stream.of( - testInvocation(new ReturnAwakeableId(), GreeterGrpc.getGreetMethod()) + testInvocation(this::returnAwakeableId, GreeterGrpc.getGreetMethod()) .withInput( Protocol.StartMessage.newBuilder() .setDebugId(debugId) diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java index 735cad6b..68ebbba9 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java @@ -1,6 +1,5 @@ package dev.restate.sdk.core.impl; -import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; import static dev.restate.sdk.core.impl.ProtoUtils.headerFromMessage; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.params.provider.Arguments.arguments; @@ -17,29 +16,29 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerServiceDefinition; import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; +import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; +import javax.annotation.Nullable; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -abstract class CoreTestRunner { +public abstract class CoreTestRunner { - abstract Stream definitions(); + protected abstract Stream definitions(); Stream source() { return definitions() + .filter(TestDefinition::isValid) .flatMap( c -> c.getThreadingModels().stream() @@ -125,12 +124,12 @@ void executeTest( } } - enum ThreadingModel { + public enum ThreadingModel { BUFFERED_SINGLE_THREAD, UNBUFFERED_MULTI_THREAD } - interface TestDefinition { + public interface TestDefinition { ServerServiceDefinition getService(); String getMethod(); @@ -142,30 +141,41 @@ interface TestDefinition { BiConsumer, Duration> getOutputAssert(); String testCaseName(); + + boolean isValid(); } /** Builder for the test cases */ - static class TestCaseBuilder { + public static class TestCaseBuilder { - static TestInvocationBuilder testInvocation(BindableService svc, String method) { + public static TestInvocationBuilder testInvocation(BindableService svc, String method) { return new TestInvocationBuilder(svc, method); } - static TestInvocationBuilder testInvocation( + public static TestInvocationBuilder testInvocation( BindableService svc, MethodDescriptor method) { return testInvocation(svc, method.getBareMethodName()); } - static class TestInvocationBuilder { - protected final BindableService svc; + public static TestInvocationBuilder testInvocation( + Supplier svc, MethodDescriptor method) { + try { + return testInvocation(svc.get(), method.getBareMethodName()); + } catch (UnsupportedOperationException e) { + return testInvocation(null, method.getBareMethodName()); + } + } + + public static class TestInvocationBuilder { + protected final @Nullable BindableService svc; protected final String method; - TestInvocationBuilder(BindableService svc, String method) { + TestInvocationBuilder(@Nullable BindableService svc, String method) { this.svc = svc; this.method = method; } - WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) { + public WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) { MessageLite msg = ProtoUtils.build(msgOrBuilder); return new WithInputBuilder( svc, @@ -173,7 +183,7 @@ WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) { List.of(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg))); } - WithInputBuilder withInput(MessageLiteOrBuilder... messages) { + public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { return new WithInputBuilder( svc, method, @@ -187,7 +197,7 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) { } } - static class WithInputBuilder extends TestInvocationBuilder { + public static class WithInputBuilder extends TestInvocationBuilder { private final List input; WithInputBuilder(BindableService svc, String method, List input) { @@ -195,13 +205,13 @@ static class WithInputBuilder extends TestInvocationBuilder { this.input = new ArrayList<>(input); } - WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) { + public WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) { MessageLite msg = ProtoUtils.build(msgOrBuilder); this.input.add(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg)); return this; } - WithInputBuilder withInput(MessageLiteOrBuilder... messages) { + public WithInputBuilder withInput(MessageLiteOrBuilder... messages) { this.input.addAll( Arrays.stream(messages) .map( @@ -213,18 +223,18 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) { return this; } - UsingThreadingModelsBuilder usingThreadingModels(ThreadingModel... threadingModels) { + public UsingThreadingModelsBuilder usingThreadingModels(ThreadingModel... threadingModels) { return new UsingThreadingModelsBuilder( this.svc, this.method, input, new HashSet<>(Arrays.asList(threadingModels))); } - UsingThreadingModelsBuilder usingAllThreadingModels() { + public UsingThreadingModelsBuilder usingAllThreadingModels() { return usingThreadingModels(ThreadingModel.values()); } } - static class UsingThreadingModelsBuilder { - private final BindableService svc; + public static class UsingThreadingModelsBuilder { + private final @Nullable BindableService svc; private final String method; private final List input; private final HashSet threadingModels; @@ -240,42 +250,47 @@ static class UsingThreadingModelsBuilder { this.threadingModels = threadingModels; } - ExpectingOutputMessages expectingOutput(MessageLiteOrBuilder... messages) { + public ExpectingOutputMessages expectingOutput(MessageLiteOrBuilder... messages) { List builtMessages = Arrays.stream(messages).map(ProtoUtils::build).collect(Collectors.toList()); return assertingOutput(actual -> assertThat(actual).asList().isEqualTo(builtMessages)); } - ExpectingOutputMessages assertingOutput(Consumer> messages) { + public ExpectingOutputMessages assertingOutput(Consumer> messages) { return new ExpectingOutputMessages(svc, method, input, threadingModels, messages); } - ExpectingFailure assertingFailure(Class tClass) { + public ExpectingFailure assertingFailure(Class tClass) { return assertingFailure(t -> assertThat(t).isInstanceOf(tClass)); } - ExpectingFailure assertingFailure(Consumer assertFailure) { + public ExpectingFailure assertingFailure(Consumer assertFailure) { return new ExpectingFailure(svc, method, input, threadingModels, assertFailure); } } public abstract static class BaseTestDefinition implements TestDefinition { - protected final BindableService svc; + protected final @Nullable BindableService svc; protected final String method; protected final List input; protected final HashSet threadingModels; protected final String named; public BaseTestDefinition( - BindableService svc, + @Nullable BindableService svc, String method, List input, HashSet threadingModels) { - this(svc, method, input, threadingModels, svc.getClass().getSimpleName()); + this( + svc, + method, + input, + threadingModels, + svc != null ? svc.getClass().getSimpleName() : "invalid"); } public BaseTestDefinition( - BindableService svc, + @Nullable BindableService svc, String method, List input, HashSet threadingModels, @@ -289,7 +304,7 @@ public BaseTestDefinition( @Override public ServerServiceDefinition getService() { - return svc.bindService(); + return Objects.requireNonNull(svc).bindService(); } @Override @@ -313,11 +328,11 @@ public String testCaseName() { } } - static class ExpectingOutputMessages extends BaseTestDefinition { + public static class ExpectingOutputMessages extends BaseTestDefinition { private final Consumer> messagesAssert; ExpectingOutputMessages( - BindableService svc, + @Nullable BindableService svc, String method, List input, HashSet threadingModels, @@ -327,7 +342,7 @@ static class ExpectingOutputMessages extends BaseTestDefinition { } ExpectingOutputMessages( - BindableService svc, + @Nullable BindableService svc, String method, List input, HashSet threadingModels, @@ -337,14 +352,14 @@ static class ExpectingOutputMessages extends BaseTestDefinition { this.messagesAssert = messagesAssert; } - ExpectingOutputMessages named(String name) { + public ExpectingOutputMessages named(String name) { return new TestCaseBuilder.ExpectingOutputMessages( svc, method, input, threadingModels, messagesAssert, - svc.getClass().getSimpleName() + ": " + name); + svc != null ? (svc.getClass().getSimpleName() + ": " + name) : "invalid"); } @Override @@ -366,13 +381,18 @@ public BiConsumer, Duration> getOutputAssert() { Protocol.ErrorMessage.class); }; } + + @Override + public boolean isValid() { + return this.svc != null; + } } - static class ExpectingFailure extends BaseTestDefinition { + public static class ExpectingFailure extends BaseTestDefinition { private final Consumer throwableAssert; ExpectingFailure( - BindableService svc, + @Nullable BindableService svc, String method, List input, HashSet threadingModels, @@ -392,14 +412,14 @@ static class ExpectingFailure extends BaseTestDefinition { this.throwableAssert = throwableAssert; } - ExpectingFailure named(String name) { + public ExpectingFailure named(String name) { return new ExpectingFailure( svc, method, input, threadingModels, throwableAssert, - svc.getClass().getSimpleName() + ": " + name); + svc != null ? (svc.getClass().getSimpleName() + ": " + name) : "invalid"); } @Override @@ -416,6 +436,11 @@ public BiConsumer, Duration> getOutputAssert() { Protocol.OutputStreamEntryMessage.class, Protocol.SuspensionMessage.class); }; } + + @Override + public boolean isValid() { + return this.svc != null; + } } } } diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/DeferredTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/DeferredTestSuite.java similarity index 69% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/DeferredTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/DeferredTestSuite.java index 637aefa8..b2925f0c 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/DeferredTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/DeferredTestSuite.java @@ -9,180 +9,34 @@ import com.google.protobuf.Empty; import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.blocking.Awaitable; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.blocking.RestateContext; -import dev.restate.sdk.core.StateKey; -import dev.restate.sdk.core.TypeTag; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; -import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; -import java.time.Duration; -import java.util.concurrent.TimeoutException; +import io.grpc.BindableService; import java.util.stream.Stream; -public class DeferredTest extends CoreTestRunner { +public abstract class DeferredTestSuite extends CoreTestRunner { - private static class ReverseAwaitOrder extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + protected abstract BindableService reverseAwaitOrder(); - Awaitable a1 = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); - Awaitable a2 = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Till")); + protected abstract BindableService awaitTwiceTheSameAwaitable(); - String a2Res = a2.await().getMessage(); - ctx.set(StateKey.of("A2", TypeTag.STRING_UTF8), a2Res); + protected abstract BindableService awaitAll(); - String a1Res = a1.await().getMessage(); + protected abstract BindableService awaitAny(); - responseObserver.onNext(greetingResponse(a1Res + "-" + a2Res)); - responseObserver.onCompleted(); - } - } - - private static class AwaitTwiceTheSameAwaitable extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - Awaitable a = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); - - responseObserver.onNext( - greetingResponse(a.await().getMessage() + "-" + a.await().getMessage())); - responseObserver.onCompleted(); - } - } - - private static class AwaitAll extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - Awaitable a1 = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); - Awaitable a2 = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Till")); - - Awaitable.all(a1, a2).await(); - - responseObserver.onNext( - greetingResponse(a1.await().getMessage() + "-" + a2.await().getMessage())); - responseObserver.onCompleted(); - } - } - - private static class AwaitAny extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - Awaitable a1 = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); - Awaitable a2 = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Till")); - - GreetingResponse res = (GreetingResponse) Awaitable.any(a1, a2).await(); - - responseObserver.onNext(res); - responseObserver.onCompleted(); - } - } - - private static class CombineAnyWithAll extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + protected abstract BindableService combineAnyWithAll(); - Awaitable a1 = ctx.awakeable(TypeTag.STRING_UTF8); - Awaitable a2 = ctx.awakeable(TypeTag.STRING_UTF8); - Awaitable a3 = ctx.awakeable(TypeTag.STRING_UTF8); - Awaitable a4 = ctx.awakeable(TypeTag.STRING_UTF8); + protected abstract BindableService awaitAnyIndex(); - Awaitable a12 = Awaitable.any(a1, a2); - Awaitable a23 = Awaitable.any(a2, a3); - Awaitable a34 = Awaitable.any(a3, a4); - Awaitable.all(a12, a23, a34).await(); + protected abstract BindableService awaitOnAlreadyResolvedAwaitables(); - responseObserver.onNext(greetingResponse(a12.await() + (String) a23.await() + a34.await())); - responseObserver.onCompleted(); - } - } - - private static class AnyAwaitIndex extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - Awaitable a1 = ctx.awakeable(TypeTag.STRING_UTF8); - Awaitable a2 = ctx.awakeable(TypeTag.STRING_UTF8); - Awaitable a3 = ctx.awakeable(TypeTag.STRING_UTF8); - Awaitable a4 = ctx.awakeable(TypeTag.STRING_UTF8); - - responseObserver.onNext( - greetingResponse( - String.valueOf(Awaitable.any(a1, Awaitable.all(a2, a3), a4).awaitIndex()))); - responseObserver.onCompleted(); - } - } - - private static class AwaitOnAlreadyResolvedAwaitables extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - Awaitable a1 = ctx.awakeable(TypeTag.STRING_UTF8); - Awaitable a2 = ctx.awakeable(TypeTag.STRING_UTF8); - - Awaitable a12 = Awaitable.all(a1, a2); - Awaitable a12and1 = Awaitable.all(a12, a1); - Awaitable a121and12 = Awaitable.all(a12and1, a12); - - a12and1.await(); - a121and12.await(); - - responseObserver.onNext(greetingResponse(a1.await() + a2.await())); - responseObserver.onCompleted(); - } - } - - private static class AwaitWithTimeout extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - Awaitable call = - ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); - - String result; - try { - result = call.await(Duration.ofDays(1)).getMessage(); - } catch (TimeoutException e) { - result = "timeout"; - } - - responseObserver.onNext(greetingResponse(result)); - responseObserver.onCompleted(); - } - } + protected abstract BindableService awaitWithTimeout(); @Override - Stream definitions() { + protected Stream definitions() { return Stream.of( // --- Reverse await order - testInvocation(new ReverseAwaitOrder(), GreeterGrpc.getGreetMethod()) + testInvocation(this::reverseAwaitOrder, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder())) .usingAllThreadingModels() .expectingOutput( @@ -190,7 +44,7 @@ Stream definitions() { invokeMessage(GreeterGrpc.getGreetMethod(), greetingRequest("Till")), suspensionMessage(2)) .named("None completed"), - testInvocation(new ReverseAwaitOrder(), GreeterGrpc.getGreetMethod()) + testInvocation(this::reverseAwaitOrder, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -203,7 +57,7 @@ Stream definitions() { setStateMessage("A2", "TILL"), outputMessage(greetingResponse("FRANCESCO-TILL"))) .named("A1 and A2 completed later"), - testInvocation(new ReverseAwaitOrder(), GreeterGrpc.getGreetMethod()) + testInvocation(this::reverseAwaitOrder, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -216,7 +70,7 @@ Stream definitions() { setStateMessage("A2", "TILL"), outputMessage(greetingResponse("FRANCESCO-TILL"))) .named("A2 and A1 completed later"), - testInvocation(new ReverseAwaitOrder(), GreeterGrpc.getGreetMethod()) + testInvocation(this::reverseAwaitOrder, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -228,7 +82,7 @@ Stream definitions() { setStateMessage("A2", "TILL"), suspensionMessage(1)) .named("Only A2 completed"), - testInvocation(new ReverseAwaitOrder(), GreeterGrpc.getGreetMethod()) + testInvocation(this::reverseAwaitOrder, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -241,7 +95,7 @@ Stream definitions() { .named("Only A1 completed"), // --- Await twice the same executable - testInvocation(new AwaitTwiceTheSameAwaitable(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitTwiceTheSameAwaitable, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -252,7 +106,7 @@ Stream definitions() { outputMessage(greetingResponse("FRANCESCO-FRANCESCO"))), // --- All combinator - testInvocation(new AwaitAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAll, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder())) .usingAllThreadingModels() .expectingOutput( @@ -260,7 +114,7 @@ Stream definitions() { invokeMessage(GreeterGrpc.getGreetMethod(), greetingRequest("Till")), suspensionMessage(1, 2)) .named("No completions will suspend"), - testInvocation(new AwaitAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(3), inputMessage(GreetingRequest.newBuilder()), @@ -272,7 +126,7 @@ Stream definitions() { .usingAllThreadingModels() .expectingOutput(suspensionMessage(1)) .named("Only one completion will suspend"), - testInvocation(new AwaitAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(3), inputMessage(GreetingRequest.newBuilder()), @@ -301,7 +155,7 @@ Stream definitions() { .isEqualTo(outputMessage(greetingResponse("FRANCESCO-TILL"))); }) .named("Everything completed will generate the combinators message"), - testInvocation(new AwaitAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(4), inputMessage(GreetingRequest.newBuilder()), @@ -317,7 +171,7 @@ Stream definitions() { .usingAllThreadingModels() .expectingOutput(outputMessage(greetingResponse("FRANCESCO-TILL"))) .named("Replay the combinator"), - testInvocation(new AwaitAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -330,7 +184,7 @@ Stream definitions() { combinatorsMessage(1, 2), outputMessage(greetingResponse("FRANCESCO-TILL"))) .named("Complete all asynchronously"), - testInvocation(new AwaitAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -342,7 +196,7 @@ Stream definitions() { combinatorsMessage(1), outputMessage(new IllegalStateException("My error"))) .named("All fails on first failure"), - testInvocation(new AwaitAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -357,7 +211,7 @@ Stream definitions() { .named("All fails on second failure"), // --- Any combinator - testInvocation(new AwaitAny(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAny, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder())) .usingAllThreadingModels() .expectingOutput( @@ -365,7 +219,7 @@ Stream definitions() { invokeMessage(GreeterGrpc.getGreetMethod(), greetingRequest("Till")), suspensionMessage(1, 2)) .named("No completions will suspend"), - testInvocation(new AwaitAny(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAny, GreeterGrpc.getGreetMethod()) .withInput( startMessage(3), inputMessage(GreetingRequest.newBuilder()), @@ -377,7 +231,7 @@ Stream definitions() { .usingAllThreadingModels() .expectingOutput(combinatorsMessage(2), outputMessage(greetingResponse("TILL"))) .named("Only one completion will generate the combinators message"), - testInvocation(new AwaitAny(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAny, GreeterGrpc.getGreetMethod()) .withInput( startMessage(3), inputMessage(GreetingRequest.newBuilder()), @@ -390,7 +244,7 @@ Stream definitions() { .expectingOutput( combinatorsMessage(2), outputMessage(new IllegalStateException("My error"))) .named("Only one failure will generate the combinators message"), - testInvocation(new AwaitAny(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAny, GreeterGrpc.getGreetMethod()) .withInput( startMessage(3), inputMessage(GreetingRequest.newBuilder()), @@ -423,7 +277,7 @@ Stream definitions() { outputMessage(greetingResponse("TILL"))); }) .named("Everything completed will generate the combinators message"), - testInvocation(new AwaitAny(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAny, GreeterGrpc.getGreetMethod()) .withInput( startMessage(4), inputMessage(GreetingRequest.newBuilder()), @@ -439,7 +293,7 @@ Stream definitions() { .usingAllThreadingModels() .expectingOutput(outputMessage(greetingResponse("TILL"))) .named("Replay the combinator"), - testInvocation(new AwaitAny(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAny, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -453,7 +307,7 @@ Stream definitions() { .named("Complete any asynchronously"), // --- Compose any with all - testInvocation(new CombineAnyWithAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::combineAnyWithAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(6), inputMessage(GreetingRequest.newBuilder()), @@ -464,7 +318,7 @@ Stream definitions() { combinatorsMessage(2, 3)) .usingAllThreadingModels() .expectingOutput(outputMessage(greetingResponse("223"))), - testInvocation(new CombineAnyWithAll(), GreeterGrpc.getGreetMethod()) + testInvocation(this::combineAnyWithAll, GreeterGrpc.getGreetMethod()) .withInput( startMessage(6), inputMessage(GreetingRequest.newBuilder()), @@ -478,7 +332,7 @@ Stream definitions() { .named("Inverted order"), // --- Await Any with index - testInvocation(new AnyAwaitIndex(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAnyIndex, GreeterGrpc.getGreetMethod()) .withInput( startMessage(6), inputMessage(GreetingRequest.newBuilder()), @@ -489,7 +343,7 @@ Stream definitions() { combinatorsMessage(1)) .usingAllThreadingModels() .expectingOutput(outputMessage(greetingResponse("0"))), - testInvocation(new AnyAwaitIndex(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitAnyIndex, GreeterGrpc.getGreetMethod()) .withInput( startMessage(6), inputMessage(GreetingRequest.newBuilder()), @@ -503,7 +357,7 @@ Stream definitions() { .named("Complete all"), // --- Compose nested and resolved all should work - testInvocation(new AwaitOnAlreadyResolvedAwaitables(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitOnAlreadyResolvedAwaitables, GreeterGrpc.getGreetMethod()) .withInput( startMessage(3), inputMessage(GreetingRequest.newBuilder()), @@ -526,7 +380,7 @@ Stream definitions() { }), // --- Await with timeout - testInvocation(new AwaitWithTimeout(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitWithTimeout, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), @@ -546,7 +400,7 @@ Stream definitions() { .element(3) .isEqualTo(outputMessage(greetingResponse("FRANCESCO"))); }), - testInvocation(new AwaitWithTimeout(), GreeterGrpc.getGreetMethod()) + testInvocation(this::awaitWithTimeout, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder()), diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/EagerStateTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/EagerStateTestSuite.java similarity index 55% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/EagerStateTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/EagerStateTestSuite.java index f72b2929..6c3285f0 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/EagerStateTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/EagerStateTestSuite.java @@ -1,81 +1,24 @@ package dev.restate.sdk.core.impl; import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; -import static dev.restate.sdk.core.impl.MessageHeader.*; import static dev.restate.sdk.core.impl.ProtoUtils.*; import static org.assertj.core.api.AssertionsForClassTypes.entry; import com.google.protobuf.MessageLite; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.blocking.RestateContext; -import dev.restate.sdk.core.StateKey; -import dev.restate.sdk.core.TypeTag; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; -import dev.restate.sdk.core.impl.testservices.GreetingRequest; -import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; +import io.grpc.BindableService; import java.util.Map; import java.util.stream.Stream; -class EagerStateTest extends CoreTestRunner { +public abstract class EagerStateTestSuite extends CoreTestRunner { - private static class GetEmpty extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + protected abstract BindableService getEmpty(); - boolean stateIsEmpty = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).isEmpty(); + protected abstract BindableService get(); - responseObserver.onNext( - GreetingResponse.newBuilder().setMessage(String.valueOf(stateIsEmpty)).build()); - responseObserver.onCompleted(); - } - } - - private static class Get extends GreeterGrpc.GreeterImplBase implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - String state = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage(state).build()); - responseObserver.onCompleted(); - } - } - - private static class GetAppendAndGet extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - String oldState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); - ctx.set(StateKey.of("STATE", TypeTag.STRING_UTF8), oldState + request.getName()); + protected abstract BindableService getAppendAndGet(); - String newState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage(newState).build()); - responseObserver.onCompleted(); - } - } - - private static class GetClearAndGet extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - String oldState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); - - ctx.clear(StateKey.of("STATE", TypeTag.STRING_UTF8)); - assert ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).isEmpty(); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage(oldState).build()); - responseObserver.onCompleted(); - } - } + protected abstract BindableService getClearAndGet(); private static final Map.Entry STATE_FRANCESCO = entry("STATE", "Francesco"); private static final MessageLite INPUT_TILL = inputMessage(greetingRequest("Till")); @@ -89,40 +32,40 @@ public void greet(GreetingRequest request, StreamObserver resp outputMessage(greetingResponse("FrancescoTill")); @Override - Stream definitions() { + protected Stream definitions() { return Stream.of( - testInvocation(new GetEmpty(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getEmpty, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1).setPartialState(false), INPUT_TILL) .usingAllThreadingModels() .expectingOutput(getStateEmptyMessage("STATE"), outputMessage(greetingResponse("true"))) .named("With complete state"), - testInvocation(new GetEmpty(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getEmpty, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1).setPartialState(true), INPUT_TILL) .usingAllThreadingModels() .expectingOutput(getStateMessage("STATE"), suspensionMessage(1)) .named("With partial state"), - testInvocation(new GetEmpty(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getEmpty, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2).setPartialState(true), INPUT_TILL, getStateEmptyMessage("STATE")) .usingAllThreadingModels() .expectingOutput(outputMessage(greetingResponse("true"))) .named("Resume with partial state"), - testInvocation(new Get(), GreeterGrpc.getGreetMethod()) + testInvocation(this::get, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1, STATE_FRANCESCO).setPartialState(false), INPUT_TILL) .usingAllThreadingModels() .expectingOutput(GET_STATE_FRANCESCO, OUTPUT_FRANCESCO) .named("With complete state"), - testInvocation(new Get(), GreeterGrpc.getGreetMethod()) + testInvocation(this::get, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1, STATE_FRANCESCO).setPartialState(true), INPUT_TILL) .usingAllThreadingModels() .expectingOutput(GET_STATE_FRANCESCO, OUTPUT_FRANCESCO) .named("With partial state"), - testInvocation(new Get(), GreeterGrpc.getGreetMethod()) + testInvocation(this::get, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1).setPartialState(true), INPUT_TILL) .usingAllThreadingModels() .expectingOutput(getStateMessage("STATE"), suspensionMessage(1)) .named("With partial state without the state entry"), - testInvocation(new GetAppendAndGet(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getAppendAndGet, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1, STATE_FRANCESCO), INPUT_TILL) .usingAllThreadingModels() .expectingOutput( @@ -131,7 +74,7 @@ Stream definitions() { GET_STATE_FRANCESCO_TILL, OUTPUT_FRANCESCO_TILL) .named("With state in the state_map"), - testInvocation(new GetAppendAndGet(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getAppendAndGet, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1).setPartialState(true), INPUT_TILL, @@ -143,7 +86,7 @@ Stream definitions() { GET_STATE_FRANCESCO_TILL, OUTPUT_FRANCESCO_TILL) .named("With partial state on the first get"), - testInvocation(new GetClearAndGet(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getClearAndGet, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1, STATE_FRANCESCO), INPUT_TILL) .usingAllThreadingModels() .expectingOutput( @@ -152,7 +95,7 @@ Stream definitions() { getStateEmptyMessage("STATE"), OUTPUT_FRANCESCO) .named("With state in the state_map"), - testInvocation(new GetClearAndGet(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getClearAndGet, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1).setPartialState(true), INPUT_TILL, diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/FlowUtils.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/FlowUtils.java index d96a5ab4..2b601e61 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/FlowUtils.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/FlowUtils.java @@ -7,45 +7,6 @@ public class FlowUtils { - public static class CollectorSubscriber implements Flow.Subscriber { - - private final List msgs = new ArrayList<>(); - private Throwable error = null; - private boolean completed = false; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(T t) { - this.msgs.add(t); - } - - @Override - public void onError(Throwable throwable) { - this.error = throwable; - } - - @Override - public void onComplete() { - this.completed = true; - } - - public List getMessages() { - return msgs; - } - - public Throwable getError() { - return error; - } - - public boolean isCompleted() { - return completed; - } - } - public static class FutureSubscriber implements Flow.Subscriber { private final List messages = new ArrayList<>(); diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTestSuite.java similarity index 53% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTestSuite.java index e4af6cda..2813edc9 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetAndSetStateTestSuite.java @@ -4,59 +4,22 @@ import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; import static dev.restate.sdk.core.impl.ProtoUtils.*; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.blocking.RestateContext; -import dev.restate.sdk.core.StateKey; -import dev.restate.sdk.core.TypeTag; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; +import io.grpc.BindableService; import java.util.stream.Stream; -class GetAndSetStateTest extends CoreTestRunner { +public abstract class GetAndSetStateTestSuite extends CoreTestRunner { - private static class GetAndSetGreeter extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + protected abstract BindableService getAndSetGreeter(); - String state = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); - - ctx.set(StateKey.of("STATE", TypeTag.STRING_UTF8), request.getName()); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello " + state).build()); - responseObserver.onCompleted(); - } - } - - private static class SetNullState extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext() - .set( - StateKey.of( - "STATE", - TypeTag.using( - l -> { - throw new IllegalStateException("Unexpected call to serde fn"); - }, - l -> { - throw new IllegalStateException("Unexpected call to serde fn"); - })), - null); - - responseObserver.onNext(greetingResponse("")); - responseObserver.onCompleted(); - } - } + protected abstract BindableService setNullState(); @Override - Stream definitions() { + protected Stream definitions() { return Stream.of( - testInvocation(new GetAndSetGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getAndSetGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(3), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -66,7 +29,7 @@ Stream definitions() { .expectingOutput( outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco"))) .named("With GetState and SetState"), - testInvocation(new GetAndSetGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getAndSetGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -76,7 +39,7 @@ Stream definitions() { setStateMessage("STATE", "Till"), outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco"))) .named("With GetState already completed"), - testInvocation(new GetAndSetGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getAndSetGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -87,7 +50,7 @@ Stream definitions() { setStateMessage("STATE", "Till"), outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco"))) .named("With GetState completed later"), - testInvocation(new SetNullState(), GreeterGrpc.getGreetMethod()) + testInvocation(this::setNullState, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) .usingAllThreadingModels() .assertingOutput(containsOnlyExactErrorMessage(new NullPointerException()))); diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetStateTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetStateTestSuite.java similarity index 72% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetStateTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetStateTestSuite.java index 74f8df48..d7580fe2 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetStateTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/GetStateTestSuite.java @@ -4,33 +4,20 @@ import static dev.restate.sdk.core.impl.ProtoUtils.*; import com.google.protobuf.Empty; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.core.StateKey; -import dev.restate.sdk.core.TypeTag; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; +import io.grpc.BindableService; import java.util.stream.Stream; -class GetStateTest extends CoreTestRunner { +public abstract class GetStateTestSuite extends CoreTestRunner { - private static class GetStateGreeter extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - String state = - restateContext().get(StateKey.of("STATE", TypeTag.STRING_UTF8)).orElse("Unknown"); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello " + state).build()); - responseObserver.onCompleted(); - } - } + protected abstract BindableService getStateGreeter(); @Override - Stream definitions() { + protected Stream definitions() { return Stream.of( - testInvocation(new GetStateGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getStateGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -39,7 +26,7 @@ Stream definitions() { .expectingOutput( outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco"))) .named("With GetStateEntry already completed"), - testInvocation(new GetStateGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getStateGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -48,12 +35,12 @@ Stream definitions() { .expectingOutput( outputMessage(GreetingResponse.newBuilder().setMessage("Hello Unknown"))) .named("With GetStateEntry already completed empty"), - testInvocation(new GetStateGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getStateGreeter, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) .usingAllThreadingModels() .expectingOutput(getStateMessage("STATE"), suspensionMessage(1)) .named("Without GetStateEntry"), - testInvocation(new GetStateGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getStateGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till").build()), @@ -61,7 +48,7 @@ Stream definitions() { .usingAllThreadingModels() .expectingOutput(suspensionMessage(1)) .named("With GetStateEntry not completed"), - testInvocation(new GetStateGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getStateGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -71,7 +58,7 @@ Stream definitions() { .expectingOutput( outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco"))) .named("With GetStateEntry and completed with later CompletionFrame"), - testInvocation(new GetStateGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::getStateGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till")), diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/InvocationIdTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/InvocationIdTestSuite.java similarity index 56% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/InvocationIdTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/InvocationIdTestSuite.java index f6b47dfa..ba320337 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/InvocationIdTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/InvocationIdTestSuite.java @@ -5,33 +5,22 @@ import com.google.protobuf.ByteString; import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.core.InvocationId; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; -import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; +import io.grpc.BindableService; import java.util.stream.Stream; -class InvocationIdTest extends CoreTestRunner { +public abstract class InvocationIdTestSuite extends CoreTestRunner { - private static class ReturnInvocationId extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - responseObserver.onNext(greetingResponse(InvocationId.current().toString())); - responseObserver.onCompleted(); - } - } + protected abstract BindableService returnInvocationId(); @Override - Stream definitions() { + protected Stream definitions() { String debugId = "my-debug-id"; ByteString id = ByteString.copyFromUtf8(debugId); return Stream.of( - testInvocation(new ReturnInvocationId(), GreeterGrpc.getGreetMethod()) + testInvocation(this::returnInvocationId, GreeterGrpc.getGreetMethod()) .withInput( Protocol.StartMessage.newBuilder().setDebugId(debugId).setId(id).setKnownEntries(1), inputMessage(GreetingRequest.getDefaultInstance())) diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/OnlyInputAndOutputTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/OnlyInputAndOutputTestSuite.java similarity index 57% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/OnlyInputAndOutputTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/OnlyInputAndOutputTestSuite.java index 503a0fed..e7e4ae48 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/OnlyInputAndOutputTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/OnlyInputAndOutputTestSuite.java @@ -6,24 +6,17 @@ import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; +import io.grpc.BindableService; import java.util.stream.Stream; -class OnlyInputAndOutputTest extends CoreTestRunner { +public abstract class OnlyInputAndOutputTestSuite extends CoreTestRunner { - private static class NoSyscallsGreeter extends GreeterGrpc.GreeterImplBase { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - responseObserver.onNext( - GreetingResponse.newBuilder().setMessage("Hello " + request.getName()).build()); - responseObserver.onCompleted(); - } - } + protected abstract BindableService noSyscallsGreeter(); @Override - Stream definitions() { + protected Stream definitions() { return Stream.of( - testInvocation(new NoSyscallsGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::noSyscallsGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Francesco"))) .usingAllThreadingModels() diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java index cef70224..5fed1814 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java @@ -71,14 +71,14 @@ public static Protocol.CompletionMessage completionMessage( .build(); } - static Protocol.CompletionMessage completionMessage(int index, Throwable e) { + public static Protocol.CompletionMessage completionMessage(int index, Throwable e) { return Protocol.CompletionMessage.newBuilder() .setEntryIndex(index) .setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage()))) .build(); } - static Protocol.SuspensionMessage suspensionMessage(Integer... indexes) { + public static Protocol.SuspensionMessage suspensionMessage(Integer... indexes) { return Protocol.SuspensionMessage.newBuilder().addAllEntryIndexes(List.of(indexes)).build(); } @@ -88,53 +88,53 @@ public static Protocol.PollInputStreamEntryMessage inputMessage(MessageLiteOrBui .build(); } - static Protocol.OutputStreamEntryMessage outputMessage(MessageLiteOrBuilder value) { + public static Protocol.OutputStreamEntryMessage outputMessage(MessageLiteOrBuilder value) { return Protocol.OutputStreamEntryMessage.newBuilder() .setValue(build(value).toByteString()) .build(); } - static Protocol.OutputStreamEntryMessage outputMessage(Status s) { + public static Protocol.OutputStreamEntryMessage outputMessage(Status s) { return Protocol.OutputStreamEntryMessage.newBuilder() .setFailure(Util.toProtocolFailure(s.asRuntimeException())) .build(); } - static Protocol.OutputStreamEntryMessage outputMessage(Throwable e) { + public static Protocol.OutputStreamEntryMessage outputMessage(Throwable e) { return Protocol.OutputStreamEntryMessage.newBuilder() .setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage()))) .build(); } - static Protocol.GetStateEntryMessage.Builder getStateMessage(String key) { + public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key) { return Protocol.GetStateEntryMessage.newBuilder().setKey(ByteString.copyFromUtf8(key)); } - static Protocol.GetStateEntryMessage getStateEmptyMessage(String key) { + public static Protocol.GetStateEntryMessage getStateEmptyMessage(String key) { return Protocol.GetStateEntryMessage.newBuilder() .setKey(ByteString.copyFromUtf8(key)) .setEmpty(Empty.getDefaultInstance()) .build(); } - static Protocol.GetStateEntryMessage getStateMessage(String key, String value) { + public static Protocol.GetStateEntryMessage getStateMessage(String key, String value) { return getStateMessage(key).setValue(ByteString.copyFromUtf8(value)).build(); } - static Protocol.SetStateEntryMessage setStateMessage(String key, String value) { + public static Protocol.SetStateEntryMessage setStateMessage(String key, String value) { return Protocol.SetStateEntryMessage.newBuilder() .setKey(ByteString.copyFromUtf8(key)) .setValue(ByteString.copyFromUtf8(value)) .build(); } - static Protocol.ClearStateEntryMessage clearStateMessage(String key) { + public static Protocol.ClearStateEntryMessage clearStateMessage(String key) { return Protocol.ClearStateEntryMessage.newBuilder() .setKey(ByteString.copyFromUtf8(key)) .build(); } - static + public static Protocol.InvokeEntryMessage.Builder invokeMessage( MethodDescriptor methodDescriptor, T parameter) { return Protocol.InvokeEntryMessage.newBuilder() @@ -143,35 +143,37 @@ Protocol.InvokeEntryMessage.Builder invokeMessage( .setParameter(parameter.toByteString()); } - static Protocol.InvokeEntryMessage invokeMessage( - MethodDescriptor methodDescriptor, T parameter, R result) { + public static + Protocol.InvokeEntryMessage invokeMessage( + MethodDescriptor methodDescriptor, T parameter, R result) { return invokeMessage(methodDescriptor, parameter).setValue(result.toByteString()).build(); } - static Protocol.InvokeEntryMessage invokeMessage( - MethodDescriptor methodDescriptor, T parameter, Throwable e) { + public static + Protocol.InvokeEntryMessage invokeMessage( + MethodDescriptor methodDescriptor, T parameter, Throwable e) { return invokeMessage(methodDescriptor, parameter) .setFailure(toProtocolFailure(Status.INTERNAL.withDescription(e.getMessage()))) .build(); } - static Protocol.AwakeableEntryMessage.Builder awakeable() { + public static Protocol.AwakeableEntryMessage.Builder awakeable() { return Protocol.AwakeableEntryMessage.newBuilder(); } - static Protocol.AwakeableEntryMessage awakeable(String value) { + public static Protocol.AwakeableEntryMessage awakeable(String value) { return awakeable().setValue(ByteString.copyFromUtf8(value)).build(); } - static GreetingRequest greetingRequest(String name) { + public static GreetingRequest greetingRequest(String name) { return GreetingRequest.newBuilder().setName(name).build(); } - static GreetingResponse greetingResponse(String message) { + public static GreetingResponse greetingResponse(String message) { return GreetingResponse.newBuilder().setMessage(message).build(); } - static Java.CombinatorAwaitableEntryMessage combinatorsMessage(Integer... order) { + public static Java.CombinatorAwaitableEntryMessage combinatorsMessage(Integer... order) { return Java.CombinatorAwaitableEntryMessage.newBuilder() .addAllEntryIndex(Arrays.asList(order)) .build(); diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java deleted file mode 100644 index e14faa42..00000000 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTest.java +++ /dev/null @@ -1,185 +0,0 @@ -package dev.restate.sdk.core.impl; - -import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage; -import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; -import static dev.restate.sdk.core.impl.ProtoUtils.*; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.InstanceOfAssertFactories.type; - -import com.google.protobuf.ByteString; -import dev.restate.generated.sdk.java.Java; -import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.blocking.RestateContext; -import dev.restate.sdk.core.TypeTag; -import dev.restate.sdk.core.impl.testservices.GreeterGrpc; -import dev.restate.sdk.core.impl.testservices.GreetingRequest; -import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; -import java.util.Objects; -import java.util.stream.Stream; - -class SideEffectTest extends CoreTestRunner { - - private static class SideEffect extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - private final String sideEffectOutput; - - SideEffect(String sideEffectOutput) { - this.sideEffectOutput = sideEffectOutput; - } - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - String result = ctx.sideEffect(TypeTag.STRING_UTF8, () -> this.sideEffectOutput); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello " + result).build()); - responseObserver.onCompleted(); - } - } - - private static class ConsecutiveSideEffect extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - private final String sideEffectOutput; - - ConsecutiveSideEffect(String sideEffectOutput) { - this.sideEffectOutput = sideEffectOutput; - } - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - String firstResult = ctx.sideEffect(TypeTag.STRING_UTF8, () -> this.sideEffectOutput); - String secondResult = ctx.sideEffect(TypeTag.STRING_UTF8, firstResult::toUpperCase); - - responseObserver.onNext( - GreetingResponse.newBuilder().setMessage("Hello " + secondResult).build()); - responseObserver.onCompleted(); - } - } - - private static class CheckContextSwitching extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - String currentThread = Thread.currentThread().getName(); - - String sideEffectThread = - restateContext().sideEffect(TypeTag.STRING_UTF8, () -> Thread.currentThread().getName()); - - if (!Objects.equals(currentThread, sideEffectThread)) { - throw new IllegalStateException( - "Current thread and side effect thread do not match: " - + currentThread - + " != " - + sideEffectThread); - } - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello").build()); - responseObserver.onCompleted(); - } - } - - private static class SideEffectGuard extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - ctx.sideEffect( - () -> ctx.oneWayCall(GreeterGrpc.getGreetMethod(), greetingRequest("something"))); - - throw new IllegalStateException("This point should not be reached"); - } - } - - private static class SideEffectThenAwakeable extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - ctx.sideEffect( - () -> { - throw new IllegalStateException("This should be replayed"); - }); - ctx.awakeable(TypeTag.BYTES).await(); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello").build()); - responseObserver.onCompleted(); - } - } - - @Override - Stream definitions() { - return Stream.of( - testInvocation(new SideEffect("Francesco"), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) - .usingAllThreadingModels() - .expectingOutput( - Java.SideEffectEntryMessage.newBuilder() - .setValue(ByteString.copyFromUtf8("Francesco")), - outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco"))), - testInvocation(new ConsecutiveSideEffect("Francesco"), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) - .usingAllThreadingModels() - .expectingOutput( - Java.SideEffectEntryMessage.newBuilder() - .setValue(ByteString.copyFromUtf8("Francesco")), - suspensionMessage(1)) - .named("Without ack"), - testInvocation(new ConsecutiveSideEffect("Francesco"), GreeterGrpc.getGreetMethod()) - .withInput( - startMessage(1), - inputMessage(GreetingRequest.newBuilder().setName("Till")), - Protocol.CompletionMessage.newBuilder().setEntryIndex(1)) - .usingThreadingModels(ThreadingModel.UNBUFFERED_MULTI_THREAD) - .expectingOutput( - Java.SideEffectEntryMessage.newBuilder() - .setValue(ByteString.copyFromUtf8("Francesco")), - Java.SideEffectEntryMessage.newBuilder() - .setValue(ByteString.copyFromUtf8("FRANCESCO")), - outputMessage(GreetingResponse.newBuilder().setMessage("Hello FRANCESCO"))) - .named("With ack"), - testInvocation(new CheckContextSwitching(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingThreadingModels(ThreadingModel.UNBUFFERED_MULTI_THREAD) - .assertingOutput( - actualOutputMessages -> { - assertThat(actualOutputMessages).hasSize(2); - assertThat(actualOutputMessages) - .element(0) - .asInstanceOf(type(Java.SideEffectEntryMessage.class)) - .returns(true, Java.SideEffectEntryMessage::hasValue); - assertThat(actualOutputMessages) - .element(1) - .isEqualTo( - Protocol.OutputStreamEntryMessage.newBuilder() - .setValue( - GreetingResponse.newBuilder() - .setMessage("Hello") - .build() - .toByteString()) - .build()); - }), - testInvocation(new SideEffectGuard(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) - .usingAllThreadingModels() - .assertingOutput( - containsOnlyExactErrorMessage(ProtocolException.invalidSideEffectCall())), - testInvocation(new SideEffectThenAwakeable(), GreeterGrpc.getGreetMethod()) - .withInput( - startMessage(2), - inputMessage(GreetingRequest.newBuilder().setName("Till")), - Java.SideEffectEntryMessage.newBuilder().setValue(ByteString.copyFromUtf8(""))) - .usingAllThreadingModels() - .expectingOutput(awakeable(), suspensionMessage(2))); - } -} diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTestSuite.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTestSuite.java new file mode 100644 index 00000000..e73ecc4b --- /dev/null +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SideEffectTestSuite.java @@ -0,0 +1,95 @@ +package dev.restate.sdk.core.impl; + +import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage; +import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; +import static dev.restate.sdk.core.impl.ProtoUtils.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +import com.google.protobuf.ByteString; +import dev.restate.generated.sdk.java.Java; +import dev.restate.generated.service.protocol.Protocol; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import java.util.stream.Stream; + +public abstract class SideEffectTestSuite extends CoreTestRunner { + + protected abstract BindableService sideEffect(String sideEffectOutput); + + protected abstract BindableService consecutiveSideEffect(String sideEffectOutput); + + protected abstract BindableService checkContextSwitching(); + + protected abstract BindableService sideEffectGuard(); + + protected abstract BindableService sideEffectThenAwakeable(); + + @Override + protected Stream definitions() { + return Stream.of( + testInvocation(() -> this.sideEffect("Francesco"), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) + .usingAllThreadingModels() + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setValue(ByteString.copyFromUtf8("Francesco")), + outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco"))), + testInvocation(() -> this.consecutiveSideEffect("Francesco"), GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) + .usingAllThreadingModels() + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setValue(ByteString.copyFromUtf8("Francesco")), + suspensionMessage(1)) + .named("Without ack"), + testInvocation(() -> this.consecutiveSideEffect("Francesco"), GreeterGrpc.getGreetMethod()) + .withInput( + startMessage(1), + inputMessage(GreetingRequest.newBuilder().setName("Till")), + Protocol.CompletionMessage.newBuilder().setEntryIndex(1)) + .usingThreadingModels(ThreadingModel.UNBUFFERED_MULTI_THREAD) + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setValue(ByteString.copyFromUtf8("Francesco")), + Java.SideEffectEntryMessage.newBuilder() + .setValue(ByteString.copyFromUtf8("FRANCESCO")), + outputMessage(GreetingResponse.newBuilder().setMessage("Hello FRANCESCO"))) + .named("With ack"), + testInvocation(this::checkContextSwitching, GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingThreadingModels(ThreadingModel.UNBUFFERED_MULTI_THREAD) + .assertingOutput( + actualOutputMessages -> { + assertThat(actualOutputMessages).hasSize(2); + assertThat(actualOutputMessages) + .element(0) + .asInstanceOf(type(Java.SideEffectEntryMessage.class)) + .returns(true, Java.SideEffectEntryMessage::hasValue); + assertThat(actualOutputMessages) + .element(1) + .isEqualTo( + Protocol.OutputStreamEntryMessage.newBuilder() + .setValue( + GreetingResponse.newBuilder() + .setMessage("Hello") + .build() + .toByteString()) + .build()); + }), + testInvocation(this::sideEffectGuard, GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) + .usingAllThreadingModels() + .assertingOutput( + containsOnlyExactErrorMessage(ProtocolException.invalidSideEffectCall())), + testInvocation(this::sideEffectThenAwakeable, GreeterGrpc.getGreetMethod()) + .withInput( + startMessage(2), + inputMessage(GreetingRequest.newBuilder().setName("Till")), + Java.SideEffectEntryMessage.newBuilder().setValue(ByteString.copyFromUtf8(""))) + .usingAllThreadingModels() + .expectingOutput(awakeable(), suspensionMessage(2))); + } +} diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SleepTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SleepTestSuite.java similarity index 66% rename from sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SleepTest.java rename to sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SleepTestSuite.java index b4afa19e..f8872def 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SleepTest.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/SleepTestSuite.java @@ -7,65 +7,26 @@ import com.google.protobuf.Empty; import com.google.protobuf.MessageLiteOrBuilder; import dev.restate.generated.service.protocol.Protocol; -import dev.restate.sdk.blocking.Awaitable; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.blocking.RestateContext; import dev.restate.sdk.core.impl.testservices.GreeterGrpc; import dev.restate.sdk.core.impl.testservices.GreetingRequest; import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; -import java.time.Duration; +import io.grpc.BindableService; import java.time.Instant; -import java.util.ArrayList; -import java.util.List; import java.util.stream.IntStream; import java.util.stream.Stream; -public class SleepTest extends CoreTestRunner { +public abstract class SleepTestSuite extends CoreTestRunner { Long startTime = System.currentTimeMillis(); - private static class SleepGreeter extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { + protected abstract BindableService sleepGreeter(); - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - - ctx.sleep(Duration.ofMillis(1000)); - - responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello").build()); - responseObserver.onCompleted(); - } - } - - private static class ManySleeps extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); - List> collectedAwaitables = new ArrayList<>(); - - for (int i = 0; i < 10; i++) { - collectedAwaitables.add(ctx.timer(Duration.ofMillis(1000))); - } - - Awaitable.all( - collectedAwaitables.get(0), - collectedAwaitables.get(1), - collectedAwaitables.subList(2, collectedAwaitables.size()).toArray(Awaitable[]::new)) - .await(); - - responseObserver.onNext(GreetingResponse.newBuilder().build()); - responseObserver.onCompleted(); - } - } + protected abstract BindableService manySleeps(); @Override - Stream definitions() { + protected Stream definitions() { return Stream.of( - testInvocation(new SleepGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::sleepGreeter, GreeterGrpc.getGreetMethod()) .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) .usingAllThreadingModels() .assertingOutput( @@ -79,7 +40,7 @@ Stream definitions() { assertThat(messageLites.get(1)).isInstanceOf(Protocol.SuspensionMessage.class); }) .named("Sleep 1000 ms not completed"), - testInvocation(new SleepGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::sleepGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -91,7 +52,7 @@ Stream definitions() { .expectingOutput( outputMessage(GreetingResponse.newBuilder().setMessage("Hello").build())) .named("Sleep 1000 ms sleep completed"), - testInvocation(new SleepGreeter(), GreeterGrpc.getGreetMethod()) + testInvocation(this::sleepGreeter, GreeterGrpc.getGreetMethod()) .withInput( startMessage(2), inputMessage(GreetingRequest.newBuilder().setName("Till")), @@ -101,7 +62,7 @@ Stream definitions() { .usingAllThreadingModels() .expectingOutput(suspensionMessage(1)) .named("Sleep 1000 ms still sleeping"), - testInvocation(new ManySleeps(), GreeterGrpc.getGreetMethod()) + testInvocation(this::manySleeps, GreeterGrpc.getGreetMethod()) .withInput( Stream.concat( Stream.of( diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java deleted file mode 100644 index 1770da45..00000000 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTest.java +++ /dev/null @@ -1,122 +0,0 @@ -package dev.restate.sdk.core.impl; - -import static dev.restate.sdk.core.impl.AssertUtils.containsOnly; -import static dev.restate.sdk.core.impl.AssertUtils.errorMessageStartingWith; -import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; -import static dev.restate.sdk.core.impl.ProtoUtils.*; - -import dev.restate.generated.sdk.java.Java; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.core.StateKey; -import dev.restate.sdk.core.TypeTag; -import dev.restate.sdk.core.impl.testservices.GreeterGrpc; -import dev.restate.sdk.core.impl.testservices.GreetingRequest; -import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.stub.StreamObserver; -import java.nio.charset.StandardCharsets; -import java.util.stream.Stream; - -class StateMachineFailuresTest extends CoreTestRunner { - - private static class GetState extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - - private static final StateKey STATE = - StateKey.of( - "STATE", - TypeTag.using( - i -> Integer.toString(i).getBytes(StandardCharsets.UTF_8), - b -> Integer.parseInt(new String(b, StandardCharsets.UTF_8)))); - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext().get(STATE); - responseObserver.onNext(greetingResponse("Francesco")); - responseObserver.onCompleted(); - } - } - - private abstract static class SideEffectFailure extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - private final TypeTag typeTag; - - private SideEffectFailure(TypeTag typeTag) { - this.typeTag = typeTag; - } - - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext().sideEffect(typeTag, () -> 0); - - responseObserver.onNext(greetingResponse("Francesco")); - responseObserver.onCompleted(); - } - } - - private static class EndSideEffectSerializationFailure extends SideEffectFailure { - - private static final TypeTag INTEGER_TYPE_TAG = - TypeTag.using( - i -> { - throw new IllegalStateException("Cannot serialize integer"); - }, - b -> Integer.parseInt(new String(b, StandardCharsets.UTF_8))); - - private EndSideEffectSerializationFailure() { - super(INTEGER_TYPE_TAG); - } - } - - private static class EndSideEffectDeserializationFailure extends SideEffectFailure { - - private static final TypeTag INTEGER_TYPE_TAG = - TypeTag.using( - i -> Integer.toString(i).getBytes(StandardCharsets.UTF_8), - b -> { - throw new IllegalStateException("Cannot deserialize integer"); - }); - - private EndSideEffectDeserializationFailure() { - super(INTEGER_TYPE_TAG); - } - } - - @Override - Stream definitions() { - return Stream.of( - testInvocation(new GetState(), GreeterGrpc.getGreetMethod()) - .withInput( - startMessage(2), - inputMessage(GreetingRequest.newBuilder().setName("Till")), - getStateMessage("Something")) - .usingAllThreadingModels() - .assertingOutput( - containsOnly( - AssertUtils.protocolExceptionErrorMessage( - ProtocolException.JOURNAL_MISMATCH_CODE))), - testInvocation(new GetState(), GreeterGrpc.getGreetMethod()) - .withInput( - startMessage(2), - inputMessage(GreetingRequest.newBuilder().setName("Till")), - getStateMessage("STATE", "This is not an integer")) - .usingAllThreadingModels() - .assertingOutput( - containsOnly( - errorMessageStartingWith(NumberFormatException.class.getCanonicalName()))), - testInvocation(new EndSideEffectSerializationFailure(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) - .usingAllThreadingModels() - .assertingOutput( - containsOnly( - errorMessageStartingWith(IllegalStateException.class.getCanonicalName()))), - testInvocation(new EndSideEffectDeserializationFailure(), GreeterGrpc.getGreetMethod()) - .withInput( - startMessage(2), - inputMessage(GreetingRequest.newBuilder().setName("Till")), - Java.SideEffectEntryMessage.newBuilder()) - .usingAllThreadingModels() - .assertingOutput( - containsOnly( - errorMessageStartingWith(IllegalStateException.class.getCanonicalName())))); - } -} diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTestSuite.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTestSuite.java new file mode 100644 index 00000000..76cf1b87 --- /dev/null +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/StateMachineFailuresTestSuite.java @@ -0,0 +1,78 @@ +package dev.restate.sdk.core.impl; + +import static dev.restate.sdk.core.impl.AssertUtils.containsOnly; +import static dev.restate.sdk.core.impl.AssertUtils.errorMessageStartingWith; +import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; +import static dev.restate.sdk.core.impl.ProtoUtils.*; + +import dev.restate.generated.sdk.java.Java; +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import io.grpc.BindableService; +import java.nio.charset.StandardCharsets; +import java.util.stream.Stream; + +public abstract class StateMachineFailuresTestSuite extends CoreTestRunner { + + protected abstract BindableService getState(); + + protected abstract BindableService sideEffectFailure(TypeTag typeTag); + + private static final TypeTag FAILING_SERIALIZATION_INTEGER_TYPE_TAG = + TypeTag.using( + i -> { + throw new IllegalStateException("Cannot serialize integer"); + }, + b -> Integer.parseInt(new String(b, StandardCharsets.UTF_8))); + + private static final TypeTag FAILING_DESERIALIZATION_INTEGER_TYPE_TAG = + TypeTag.using( + i -> Integer.toString(i).getBytes(StandardCharsets.UTF_8), + b -> { + throw new IllegalStateException("Cannot deserialize integer"); + }); + + @Override + protected Stream definitions() { + return Stream.of( + testInvocation(this::getState, GreeterGrpc.getGreetMethod()) + .withInput( + startMessage(2), + inputMessage(GreetingRequest.newBuilder().setName("Till")), + getStateMessage("Something")) + .usingAllThreadingModels() + .assertingOutput( + containsOnly( + AssertUtils.protocolExceptionErrorMessage( + ProtocolException.JOURNAL_MISMATCH_CODE))), + testInvocation(this::getState, GreeterGrpc.getGreetMethod()) + .withInput( + startMessage(2), + inputMessage(GreetingRequest.newBuilder().setName("Till")), + getStateMessage("STATE", "This is not an integer")) + .usingAllThreadingModels() + .assertingOutput( + containsOnly( + errorMessageStartingWith(NumberFormatException.class.getCanonicalName()))), + testInvocation( + () -> this.sideEffectFailure(FAILING_SERIALIZATION_INTEGER_TYPE_TAG), + GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Till"))) + .usingAllThreadingModels() + .assertingOutput( + containsOnly( + errorMessageStartingWith(IllegalStateException.class.getCanonicalName()))), + testInvocation( + () -> this.sideEffectFailure(FAILING_DESERIALIZATION_INTEGER_TYPE_TAG), + GreeterGrpc.getGreetMethod()) + .withInput( + startMessage(2), + inputMessage(GreetingRequest.newBuilder().setName("Till")), + Java.SideEffectEntryMessage.newBuilder()) + .usingAllThreadingModels() + .assertingOutput( + containsOnly( + errorMessageStartingWith(IllegalStateException.class.getCanonicalName())))); + } +} diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java deleted file mode 100644 index 93824ba4..00000000 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTest.java +++ /dev/null @@ -1,122 +0,0 @@ -package dev.restate.sdk.core.impl; - -import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage; -import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; -import static dev.restate.sdk.core.impl.ProtoUtils.*; - -import dev.restate.generated.sdk.java.Java; -import dev.restate.sdk.blocking.RestateBlockingService; -import dev.restate.sdk.core.impl.testservices.GreeterGrpc; -import dev.restate.sdk.core.impl.testservices.GreetingRequest; -import dev.restate.sdk.core.impl.testservices.GreetingResponse; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; -import io.grpc.stub.StreamObserver; -import java.util.stream.Stream; - -class UserFailuresTest extends CoreTestRunner { - - private static final Status MY_ERROR = Status.INTERNAL.withDescription("my error"); - - private static class ThrowIllegalStateException extends GreeterGrpc.GreeterImplBase { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - throw new IllegalStateException("Whatever"); - } - } - - private static class ResponseObserverOnErrorIllegalStateException - extends GreeterGrpc.GreeterImplBase { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - responseObserver.onError(new IllegalStateException("Whatever")); - } - } - - private static class SideEffectThrowIllegalStateException extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext() - .sideEffect( - () -> { - throw new IllegalStateException("Whatever"); - }); - } - } - - private static class ThrowUnknownStatusRuntimeException extends GreeterGrpc.GreeterImplBase { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - throw new StatusRuntimeException(Status.UNKNOWN.withDescription("Whatever")); - } - } - - private static class ThrowStatusRuntimeException extends GreeterGrpc.GreeterImplBase { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - throw new StatusRuntimeException(MY_ERROR); - } - } - - private static class ResponseObserverOnErrorStatusRuntimeException - extends GreeterGrpc.GreeterImplBase { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - responseObserver.onError(new StatusRuntimeException(MY_ERROR)); - } - } - - private static class SideEffectThrowStatusRuntimeException extends GreeterGrpc.GreeterImplBase - implements RestateBlockingService { - @Override - public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext() - .sideEffect( - () -> { - throw new StatusRuntimeException(MY_ERROR); - }); - } - } - - @Override - Stream definitions() { - return Stream.of( - // Cases returning ErrorMessage - testInvocation(new ThrowIllegalStateException(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingAllThreadingModels() - .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), - testInvocation(new SideEffectThrowIllegalStateException(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingAllThreadingModels() - .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), - - // Cases completing the invocation with OutputStreamEntry.failure - testInvocation(new ThrowStatusRuntimeException(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingAllThreadingModels() - .expectingOutput(outputMessage(MY_ERROR)), - testInvocation(new ThrowUnknownStatusRuntimeException(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingAllThreadingModels() - .expectingOutput(outputMessage(Status.UNKNOWN.withDescription("Whatever"))), - testInvocation( - new ResponseObserverOnErrorStatusRuntimeException(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingAllThreadingModels() - .expectingOutput(outputMessage(MY_ERROR)), - testInvocation( - new ResponseObserverOnErrorIllegalStateException(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingAllThreadingModels() - .expectingOutput(outputMessage(Status.UNKNOWN)), - testInvocation(new SideEffectThrowStatusRuntimeException(), GreeterGrpc.getGreetMethod()) - .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) - .usingAllThreadingModels() - .expectingOutput( - Java.SideEffectEntryMessage.newBuilder() - .setFailure(Util.toProtocolFailure(MY_ERROR)), - outputMessage(MY_ERROR))); - } -} diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTestSuite.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTestSuite.java new file mode 100644 index 00000000..645b905d --- /dev/null +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/UserFailuresTestSuite.java @@ -0,0 +1,77 @@ +package dev.restate.sdk.core.impl; + +import static dev.restate.sdk.core.impl.AssertUtils.containsOnlyExactErrorMessage; +import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; +import static dev.restate.sdk.core.impl.ProtoUtils.*; + +import dev.restate.generated.sdk.java.Java; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import io.grpc.BindableService; +import io.grpc.Status; +import java.util.stream.Stream; + +public abstract class UserFailuresTestSuite extends CoreTestRunner { + + public static final Status INTERNAL_MY_ERROR = Status.INTERNAL.withDescription("my error"); + + public static final Status UNKNOWN_MY_ERROR = Status.UNKNOWN.withDescription("Whatever"); + + protected abstract BindableService throwIllegalStateException(); + + protected abstract BindableService sideEffectThrowIllegalStateException(); + + protected abstract BindableService throwStatusRuntimeException(Status status); + + protected abstract BindableService sideEffectThrowStatusRuntimeException(Status status); + + @Override + protected Stream definitions() { + return Stream.of( + // Cases returning ErrorMessage + testInvocation(this::throwIllegalStateException, GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), + testInvocation(this::sideEffectThrowIllegalStateException, GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .assertingOutput(containsOnlyExactErrorMessage(new IllegalStateException("Whatever"))), + + // Cases completing the invocation with OutputStreamEntry.failure + testInvocation( + () -> this.throwStatusRuntimeException(INTERNAL_MY_ERROR), + GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput(outputMessage(INTERNAL_MY_ERROR)) + .named("With internal error"), + testInvocation( + () -> this.throwStatusRuntimeException(UNKNOWN_MY_ERROR), + GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput(outputMessage(UNKNOWN_MY_ERROR)) + .named("With unknown error"), + testInvocation( + () -> this.sideEffectThrowStatusRuntimeException(INTERNAL_MY_ERROR), + GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setFailure(Util.toProtocolFailure(INTERNAL_MY_ERROR)), + outputMessage(INTERNAL_MY_ERROR)) + .named("With internal error"), + testInvocation( + () -> this.sideEffectThrowStatusRuntimeException(UNKNOWN_MY_ERROR), + GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setFailure(Util.toProtocolFailure(UNKNOWN_MY_ERROR)), + outputMessage(UNKNOWN_MY_ERROR)) + .named("With unknown error")); + } +} diff --git a/sdk-java-blocking/build.gradle.kts b/sdk-java-blocking/build.gradle.kts index db5a789f..f6b95ef7 100644 --- a/sdk-java-blocking/build.gradle.kts +++ b/sdk-java-blocking/build.gradle.kts @@ -1,3 +1,5 @@ +import com.google.protobuf.gradle.id + // Without these suppressions version catalog usage here and in other build // files is marked red by IntelliJ: // https://youtrack.jetbrains.com/issue/KTIJ-19369. @@ -12,7 +14,31 @@ plugins { `maven-publish` } -dependencies { api(project(":sdk-core")) } +dependencies { + api(project(":sdk-core")) + + testCompileOnly(coreLibs.javax.annotation.api) + + testImplementation(project(":sdk-core-impl")) + testImplementation(testingLibs.junit.jupiter) + testImplementation(testingLibs.assertj) + testImplementation(coreLibs.protobuf.java) + testImplementation(coreLibs.grpc.stub) + testImplementation(coreLibs.grpc.protobuf) + testImplementation(coreLibs.log4j.core) + + // Import test suites from sdk-core-impl + testImplementation(project(":sdk-core-impl", "testArchive")) + testProtobuf(project(":sdk-core-impl", "testArchive")) +} + +protobuf { + plugins { + id("grpc") { artifact = "io.grpc:protoc-gen-grpc-java:${coreLibs.versions.grpc.get()}" } + } + + generateProtoTasks { ofSourceSet("test").forEach { it.plugins { id("grpc") } } } +} publishing { publications { diff --git a/sdk-java-blocking/src/main/java/dev/restate/sdk/blocking/RestateContext.java b/sdk-java-blocking/src/main/java/dev/restate/sdk/blocking/RestateContext.java index 1d555146..8e27f694 100644 --- a/sdk-java-blocking/src/main/java/dev/restate/sdk/blocking/RestateContext.java +++ b/sdk-java-blocking/src/main/java/dev/restate/sdk/blocking/RestateContext.java @@ -125,7 +125,7 @@ default T sideEffect(Class clazz, Supplier action) { /** Like {@link #sideEffect(TypeTag, Supplier)}, but without returning a value. */ default void sideEffect(Runnable runnable) { sideEffect( - Void.class, + TypeTag.VOID, () -> { runnable.run(); return null; diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/AwakeableIdTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/AwakeableIdTest.java new file mode 100644 index 00000000..a3c03f5f --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/AwakeableIdTest.java @@ -0,0 +1,29 @@ +package dev.restate.sdk.blocking; + +import static dev.restate.sdk.core.impl.ProtoUtils.greetingResponse; + +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.AwakeableIdTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; + +class AwakeableIdTest extends AwakeableIdTestSuite { + + private static class ReturnAwakeableId extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + String id = restateContext().awakeable(TypeTag.STRING_UTF8).id(); + responseObserver.onNext(greetingResponse(id)); + responseObserver.onCompleted(); + } + } + + protected BindableService returnAwakeableId() { + return new ReturnAwakeableId(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/DeferredTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/DeferredTest.java new file mode 100644 index 00000000..0c9ef1e9 --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/DeferredTest.java @@ -0,0 +1,212 @@ +package dev.restate.sdk.blocking; + +import static dev.restate.sdk.core.impl.ProtoUtils.greetingRequest; +import static dev.restate.sdk.core.impl.ProtoUtils.greetingResponse; + +import dev.restate.sdk.core.StateKey; +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.DeferredTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; +import java.time.Duration; +import java.util.concurrent.TimeoutException; + +class DeferredTest extends DeferredTestSuite { + + private static class ReverseAwaitOrder extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable a1 = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); + Awaitable a2 = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Till")); + + String a2Res = a2.await().getMessage(); + ctx.set(StateKey.of("A2", TypeTag.STRING_UTF8), a2Res); + + String a1Res = a1.await().getMessage(); + + responseObserver.onNext(greetingResponse(a1Res + "-" + a2Res)); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService reverseAwaitOrder() { + return new ReverseAwaitOrder(); + } + + private static class AwaitTwiceTheSameAwaitable extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable a = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); + + responseObserver.onNext( + greetingResponse(a.await().getMessage() + "-" + a.await().getMessage())); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService awaitTwiceTheSameAwaitable() { + return new AwaitTwiceTheSameAwaitable(); + } + + private static class AwaitAll extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable a1 = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); + Awaitable a2 = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Till")); + + Awaitable.all(a1, a2).await(); + + responseObserver.onNext( + greetingResponse(a1.await().getMessage() + "-" + a2.await().getMessage())); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService awaitAll() { + return new AwaitAll(); + } + + private static class AwaitAny extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable a1 = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); + Awaitable a2 = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Till")); + + GreetingResponse res = (GreetingResponse) Awaitable.any(a1, a2).await(); + + responseObserver.onNext(res); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService awaitAny() { + return new AwaitAny(); + } + + private static class CombineAnyWithAll extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable a1 = ctx.awakeable(TypeTag.STRING_UTF8); + Awaitable a2 = ctx.awakeable(TypeTag.STRING_UTF8); + Awaitable a3 = ctx.awakeable(TypeTag.STRING_UTF8); + Awaitable a4 = ctx.awakeable(TypeTag.STRING_UTF8); + + Awaitable a12 = Awaitable.any(a1, a2); + Awaitable a23 = Awaitable.any(a2, a3); + Awaitable a34 = Awaitable.any(a3, a4); + Awaitable.all(a12, a23, a34).await(); + + responseObserver.onNext(greetingResponse(a12.await() + (String) a23.await() + a34.await())); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService combineAnyWithAll() { + return new CombineAnyWithAll(); + } + + private static class AwaitAnyIndex extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable a1 = ctx.awakeable(TypeTag.STRING_UTF8); + Awaitable a2 = ctx.awakeable(TypeTag.STRING_UTF8); + Awaitable a3 = ctx.awakeable(TypeTag.STRING_UTF8); + Awaitable a4 = ctx.awakeable(TypeTag.STRING_UTF8); + + responseObserver.onNext( + greetingResponse( + String.valueOf(Awaitable.any(a1, Awaitable.all(a2, a3), a4).awaitIndex()))); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService awaitAnyIndex() { + return new AwaitAnyIndex(); + } + + private static class AwaitOnAlreadyResolvedAwaitables extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable a1 = ctx.awakeable(TypeTag.STRING_UTF8); + Awaitable a2 = ctx.awakeable(TypeTag.STRING_UTF8); + + Awaitable a12 = Awaitable.all(a1, a2); + Awaitable a12and1 = Awaitable.all(a12, a1); + Awaitable a121and12 = Awaitable.all(a12and1, a12); + + a12and1.await(); + a121and12.await(); + + responseObserver.onNext(greetingResponse(a1.await() + a2.await())); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService awaitOnAlreadyResolvedAwaitables() { + return new AwaitOnAlreadyResolvedAwaitables(); + } + + private static class AwaitWithTimeout extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + Awaitable call = + ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); + + String result; + try { + result = call.await(Duration.ofDays(1)).getMessage(); + } catch (TimeoutException e) { + result = "timeout"; + } + + responseObserver.onNext(greetingResponse(result)); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService awaitWithTimeout() { + return new AwaitWithTimeout(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/EagerStateTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/EagerStateTest.java new file mode 100644 index 00000000..7a2aa37f --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/EagerStateTest.java @@ -0,0 +1,93 @@ +package dev.restate.sdk.blocking; + +import static org.assertj.core.api.Assertions.assertThat; + +import dev.restate.sdk.core.StateKey; +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.EagerStateTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; + +class EagerStateTest extends EagerStateTestSuite { + + private static class GetEmpty extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + boolean stateIsEmpty = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).isEmpty(); + + responseObserver.onNext( + GreetingResponse.newBuilder().setMessage(String.valueOf(stateIsEmpty)).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService getEmpty() { + return new GetEmpty(); + } + + private static class Get extends GreeterGrpc.GreeterImplBase implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + String state = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage(state).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService get() { + return new Get(); + } + + private static class GetAppendAndGet extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + String oldState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); + ctx.set(StateKey.of("STATE", TypeTag.STRING_UTF8), oldState + request.getName()); + + String newState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage(newState).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService getAppendAndGet() { + return new GetAppendAndGet(); + } + + private static class GetClearAndGet extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + String oldState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); + + ctx.clear(StateKey.of("STATE", TypeTag.STRING_UTF8)); + assertThat(ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8))).isEmpty(); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage(oldState).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService getClearAndGet() { + return new GetClearAndGet(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetAndSetStateTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetAndSetStateTest.java new file mode 100644 index 00000000..840cfa3e --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetAndSetStateTest.java @@ -0,0 +1,62 @@ +package dev.restate.sdk.blocking; + +import static dev.restate.sdk.core.impl.ProtoUtils.greetingResponse; + +import dev.restate.sdk.core.StateKey; +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.GetAndSetStateTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; + +class GetAndSetStateTest extends GetAndSetStateTestSuite { + + private static class GetAndSetGreeter extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + String state = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)).get(); + + ctx.set(StateKey.of("STATE", TypeTag.STRING_UTF8), request.getName()); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello " + state).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService getAndSetGreeter() { + return new GetAndSetGreeter(); + } + + private static class SetNullState extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + restateContext() + .set( + StateKey.of( + "STATE", + TypeTag.using( + l -> { + throw new IllegalStateException("Unexpected call to serde fn"); + }, + l -> { + throw new IllegalStateException("Unexpected call to serde fn"); + })), + null); + + responseObserver.onNext(greetingResponse("")); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService setNullState() { + return new SetNullState(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetStateTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetStateTest.java new file mode 100644 index 00000000..d0ada38e --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/GetStateTest.java @@ -0,0 +1,30 @@ +package dev.restate.sdk.blocking; + +import dev.restate.sdk.core.StateKey; +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.GetStateTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; + +class GetStateTest extends GetStateTestSuite { + + private static class GetStateGreeter extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + String state = + restateContext().get(StateKey.of("STATE", TypeTag.STRING_UTF8)).orElse("Unknown"); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello " + state).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService getStateGreeter() { + return new GetStateGreeter(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/InvocationIdTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/InvocationIdTest.java new file mode 100644 index 00000000..58b1fbef --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/InvocationIdTest.java @@ -0,0 +1,29 @@ +package dev.restate.sdk.blocking; + +import static dev.restate.sdk.core.impl.ProtoUtils.greetingResponse; + +import dev.restate.sdk.core.InvocationId; +import dev.restate.sdk.core.impl.InvocationIdTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; + +class InvocationIdTest extends InvocationIdTestSuite { + + private static class ReturnInvocationId extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + responseObserver.onNext(greetingResponse(InvocationId.current().toString())); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService returnInvocationId() { + return new ReturnInvocationId(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/OnlyInputAndOutputTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/OnlyInputAndOutputTest.java new file mode 100644 index 00000000..38847332 --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/OnlyInputAndOutputTest.java @@ -0,0 +1,25 @@ +package dev.restate.sdk.blocking; + +import dev.restate.sdk.core.impl.OnlyInputAndOutputTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; + +class OnlyInputAndOutputTest extends OnlyInputAndOutputTestSuite { + + private static class NoSyscallsGreeter extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + responseObserver.onNext( + GreetingResponse.newBuilder().setMessage("Hello " + request.getName()).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService noSyscallsGreeter() { + return new NoSyscallsGreeter(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SideEffectTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SideEffectTest.java new file mode 100644 index 00000000..e4e4cdf3 --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SideEffectTest.java @@ -0,0 +1,136 @@ +package dev.restate.sdk.blocking; + +import static dev.restate.sdk.core.impl.ProtoUtils.greetingRequest; + +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.SideEffectTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; +import java.util.Objects; + +class SideEffectTest extends SideEffectTestSuite { + + private static class SideEffect extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + private final String sideEffectOutput; + + SideEffect(String sideEffectOutput) { + this.sideEffectOutput = sideEffectOutput; + } + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + String result = ctx.sideEffect(TypeTag.STRING_UTF8, () -> this.sideEffectOutput); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello " + result).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService sideEffect(String sideEffectOutput) { + return new SideEffect(sideEffectOutput); + } + + private static class ConsecutiveSideEffect extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + private final String sideEffectOutput; + + ConsecutiveSideEffect(String sideEffectOutput) { + this.sideEffectOutput = sideEffectOutput; + } + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + String firstResult = ctx.sideEffect(TypeTag.STRING_UTF8, () -> this.sideEffectOutput); + String secondResult = ctx.sideEffect(TypeTag.STRING_UTF8, firstResult::toUpperCase); + + responseObserver.onNext( + GreetingResponse.newBuilder().setMessage("Hello " + secondResult).build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService consecutiveSideEffect(String sideEffectOutput) { + return new ConsecutiveSideEffect(sideEffectOutput); + } + + private static class CheckContextSwitching extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + String currentThread = Thread.currentThread().getName(); + + String sideEffectThread = + restateContext().sideEffect(TypeTag.STRING_UTF8, () -> Thread.currentThread().getName()); + + if (!Objects.equals(currentThread, sideEffectThread)) { + throw new IllegalStateException( + "Current thread and side effect thread do not match: " + + currentThread + + " != " + + sideEffectThread); + } + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello").build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService checkContextSwitching() { + return new CheckContextSwitching(); + } + + private static class SideEffectGuard extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + ctx.sideEffect( + () -> ctx.oneWayCall(GreeterGrpc.getGreetMethod(), greetingRequest("something"))); + + throw new IllegalStateException("This point should not be reached"); + } + } + + @Override + protected BindableService sideEffectGuard() { + return new SideEffectGuard(); + } + + private static class SideEffectThenAwakeable extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + ctx.sideEffect( + () -> { + throw new IllegalStateException("This should be replayed"); + }); + ctx.awakeable(TypeTag.BYTES).await(); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello").build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService sideEffectThenAwakeable() { + return new SideEffectThenAwakeable(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SleepTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SleepTest.java new file mode 100644 index 00000000..9230a608 --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/SleepTest.java @@ -0,0 +1,61 @@ +package dev.restate.sdk.blocking; + +import dev.restate.sdk.core.impl.SleepTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +public class SleepTest extends SleepTestSuite { + + private static class SleepGreeter extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + + ctx.sleep(Duration.ofMillis(1000)); + + responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello").build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService sleepGreeter() { + return new SleepGreeter(); + } + + private static class ManySleeps extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + RestateContext ctx = restateContext(); + List> collectedAwaitables = new ArrayList<>(); + + for (int i = 0; i < 10; i++) { + collectedAwaitables.add(ctx.timer(Duration.ofMillis(1000))); + } + + Awaitable.all( + collectedAwaitables.get(0), + collectedAwaitables.get(1), + collectedAwaitables.subList(2, collectedAwaitables.size()).toArray(Awaitable[]::new)) + .await(); + + responseObserver.onNext(GreetingResponse.newBuilder().build()); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService manySleeps() { + return new ManySleeps(); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/StateMachineFailuresTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/StateMachineFailuresTest.java new file mode 100644 index 00000000..7ccd1344 --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/StateMachineFailuresTest.java @@ -0,0 +1,61 @@ +package dev.restate.sdk.blocking; + +import static dev.restate.sdk.core.impl.ProtoUtils.greetingResponse; + +import dev.restate.sdk.core.StateKey; +import dev.restate.sdk.core.TypeTag; +import dev.restate.sdk.core.impl.StateMachineFailuresTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.stub.StreamObserver; +import java.nio.charset.StandardCharsets; + +class StateMachineFailuresTest extends StateMachineFailuresTestSuite { + + private static class GetState extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + private static final StateKey STATE = + StateKey.of( + "STATE", + TypeTag.using( + i -> Integer.toString(i).getBytes(StandardCharsets.UTF_8), + b -> Integer.parseInt(new String(b, StandardCharsets.UTF_8)))); + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + restateContext().get(STATE); + responseObserver.onNext(greetingResponse("Francesco")); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService getState() { + return new GetState(); + } + + private static class SideEffectFailure extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + private final TypeTag typeTag; + + private SideEffectFailure(TypeTag typeTag) { + this.typeTag = typeTag; + } + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + restateContext().sideEffect(typeTag, () -> 0); + + responseObserver.onNext(greetingResponse("Francesco")); + responseObserver.onCompleted(); + } + } + + @Override + protected BindableService sideEffectFailure(TypeTag typeTag) { + return new SideEffectFailure(typeTag); + } +} diff --git a/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/UserFailuresTest.java b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/UserFailuresTest.java new file mode 100644 index 00000000..fde09983 --- /dev/null +++ b/sdk-java-blocking/src/test/java/dev/restate/sdk/blocking/UserFailuresTest.java @@ -0,0 +1,126 @@ +package dev.restate.sdk.blocking; + +import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; +import static dev.restate.sdk.core.impl.ProtoUtils.*; + +import dev.restate.sdk.core.impl.UserFailuresTestSuite; +import dev.restate.sdk.core.impl.testservices.GreeterGrpc; +import dev.restate.sdk.core.impl.testservices.GreetingRequest; +import dev.restate.sdk.core.impl.testservices.GreetingResponse; +import io.grpc.BindableService; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; +import java.util.stream.Stream; + +class UserFailuresTest extends UserFailuresTestSuite { + + private static class ThrowIllegalStateException extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + throw new IllegalStateException("Whatever"); + } + } + + @Override + protected BindableService throwIllegalStateException() { + return new ThrowIllegalStateException(); + } + + private static class SideEffectThrowIllegalStateException extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + restateContext() + .sideEffect( + () -> { + throw new IllegalStateException("Whatever"); + }); + } + } + + @Override + protected BindableService sideEffectThrowIllegalStateException() { + return new SideEffectThrowIllegalStateException(); + } + + private static class ThrowStatusRuntimeException extends GreeterGrpc.GreeterImplBase { + + private final Status status; + + private ThrowStatusRuntimeException(Status status) { + this.status = status; + } + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + throw new StatusRuntimeException(status); + } + } + + @Override + protected BindableService throwStatusRuntimeException(Status status) { + return new ThrowStatusRuntimeException(status); + } + + private static class SideEffectThrowStatusRuntimeException extends GreeterGrpc.GreeterImplBase + implements RestateBlockingService { + + private final Status status; + + private SideEffectThrowStatusRuntimeException(Status status) { + this.status = status; + } + + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + restateContext() + .sideEffect( + () -> { + throw new StatusRuntimeException(status); + }); + } + } + + @Override + protected BindableService sideEffectThrowStatusRuntimeException(Status status) { + return new SideEffectThrowStatusRuntimeException(status); + } + + // -- Response observer is something specific to the sdk-java-blocking interface + + private static class ResponseObserverOnErrorStatusRuntimeException + extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + responseObserver.onError(new StatusRuntimeException(INTERNAL_MY_ERROR)); + } + } + + private static class ResponseObserverOnErrorIllegalStateException + extends GreeterGrpc.GreeterImplBase { + @Override + public void greet(GreetingRequest request, StreamObserver responseObserver) { + responseObserver.onError(new IllegalStateException("Whatever")); + } + } + + @Override + protected Stream definitions() { + return Stream.concat( + super.definitions(), + Stream.of( + testInvocation( + new ResponseObserverOnErrorStatusRuntimeException(), + GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput(outputMessage(INTERNAL_MY_ERROR)), + testInvocation( + new ResponseObserverOnErrorIllegalStateException(), + GreeterGrpc.getGreetMethod()) + .withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance())) + .usingAllThreadingModels() + .expectingOutput(outputMessage(Status.UNKNOWN)))); + } +} diff --git a/sdk-kotlin/build.gradle.kts b/sdk-kotlin/build.gradle.kts index 1497816e..225cebf2 100644 --- a/sdk-kotlin/build.gradle.kts +++ b/sdk-kotlin/build.gradle.kts @@ -1,3 +1,5 @@ +import com.google.protobuf.gradle.id + // Without these suppressions version catalog usage here and in other build // files is marked red by IntelliJ: // https://youtrack.jetbrains.com/issue/KTIJ-19369. @@ -7,6 +9,7 @@ "UNRESOLVED_REFERENCE_WRONG_RECEIVER", "FUNCTION_CALL_EXPECTED") plugins { + java kotlin("jvm") idea `maven-publish` @@ -16,9 +19,46 @@ dependencies { api(project(":sdk-core")) implementation(kotlinLibs.kotlinx.coroutines) + + testImplementation(project(":sdk-core-impl")) + testImplementation(testingLibs.junit.jupiter) + testImplementation(testingLibs.assertj) + testImplementation(coreLibs.protobuf.java) + testImplementation(coreLibs.protobuf.kotlin) + testImplementation(coreLibs.grpc.stub) + testImplementation(coreLibs.grpc.protobuf) + testImplementation(coreLibs.grpc.kotlin.stub) + testImplementation(coreLibs.log4j.core) + + testImplementation(project(":sdk-core-impl", "testArchive")) + testProtobuf(project(":sdk-core-impl", "testArchive")) +} + +configure { + kotlin { + ktfmt() + targetExclude("build/generated/**/*.kt") + } } -configure { kotlin { ktfmt() } } +protobuf { + plugins { + id("grpc") { artifact = "io.grpc:protoc-gen-grpc-java:${coreLibs.versions.grpc.get()}" } + id("grpckt") { + artifact = "io.grpc:protoc-gen-grpc-kotlin:${coreLibs.versions.grpckt.get()}:jdk8@jar" + } + } + + generateProtoTasks { + ofSourceSet("test").forEach { + it.plugins { + id("grpc") + id("grpckt") + } + it.builtins { id("kotlin") } + } + } +} publishing { publications { diff --git a/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/RestateContextImpl.kt b/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/RestateContextImpl.kt index 1e083b2d..b955f93a 100644 --- a/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/RestateContextImpl.kt +++ b/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/RestateContextImpl.kt @@ -70,8 +70,8 @@ internal class RestateContextImpl internal constructor(private val syscalls: Sys return NonNullAwaitableImpl(syscalls, deferredResult) } - override suspend fun oneWayCall( - methodDescriptor: MethodDescriptor, + override suspend fun oneWayCall( + methodDescriptor: MethodDescriptor, parameter: T ) { return suspendCancellableCoroutine { cont: CancellableContinuation -> @@ -79,8 +79,8 @@ internal class RestateContextImpl internal constructor(private val syscalls: Sys } } - override suspend fun delayedCall( - methodDescriptor: MethodDescriptor, + override suspend fun delayedCall( + methodDescriptor: MethodDescriptor, parameter: T, delay: Duration ) { @@ -90,20 +90,25 @@ internal class RestateContextImpl internal constructor(private val syscalls: Sys } } - override suspend fun sideEffect(typeTag: TypeTag, sideEffectAction: suspend () -> T?): T? { + override suspend fun sideEffect( + typeTag: TypeTag, + sideEffectAction: suspend () -> T + ): T { val exitResult = - suspendCancellableCoroutine { cont: CancellableContinuation> -> + suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.enterSideEffectBlock( typeTag, - object : EnterSideEffectSyscallCallback { + object : EnterSideEffectSyscallCallback { + @Suppress("UNCHECKED_CAST") override fun onResult(t: T?) { - val deferred: CompletableDeferred = CompletableDeferred() - deferred.complete(t) + val deferred: CompletableDeferred = CompletableDeferred() + // This unchecked cast is fine because T is declared as Any? + deferred.complete(t as T) cont.resume(deferred) } override fun onFailure(t: StatusRuntimeException) { - val deferred: CompletableDeferred = CompletableDeferred() + val deferred: CompletableDeferred = CompletableDeferred() deferred.completeExceptionally(t) cont.resume(deferred) } @@ -131,9 +136,11 @@ internal class RestateContextImpl internal constructor(private val syscalls: Sys } val exitCallback = - object : ExitSideEffectSyscallCallback { + object : ExitSideEffectSyscallCallback { + @Suppress("UNCHECKED_CAST") override fun onResult(t: T?) { - exitResult.complete(t) + // This unchecked cast is fine because T is declared as Any? + exitResult.complete(t as T) } override fun onFailure(t: StatusRuntimeException) { diff --git a/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index ff92346e..73a9c5cd 100644 --- a/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -101,8 +101,8 @@ sealed interface RestateContext { * generated `*Grpc` class. * @param parameter the invocation request parameter. */ - suspend fun oneWayCall( - methodDescriptor: MethodDescriptor, + suspend fun oneWayCall( + methodDescriptor: MethodDescriptor, parameter: T ) @@ -117,8 +117,8 @@ sealed interface RestateContext { * @param parameter the invocation request parameter. * @param delay time to wait before executing the call */ - suspend fun delayedCall( - methodDescriptor: MethodDescriptor, + suspend fun delayedCall( + methodDescriptor: MethodDescriptor, parameter: T, delay: Duration ) @@ -134,7 +134,7 @@ sealed interface RestateContext { * @param T type of the return value. * @return value of the side effect operation. */ - suspend fun sideEffect(typeTag: TypeTag, sideEffectAction: suspend () -> T?): T? + suspend fun sideEffect(typeTag: TypeTag, sideEffectAction: suspend () -> T): T /** Like [sideEffect] without a return value. */ suspend fun sideEffect(sideEffectAction: suspend () -> Unit) { diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt new file mode 100644 index 00000000..6973fde9 --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt @@ -0,0 +1,25 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.TypeTag +import dev.restate.sdk.core.impl.AwakeableIdTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import kotlinx.coroutines.Dispatchers + +internal class AwakeableIdTest : AwakeableIdTestSuite() { + private class ReturnAwakeableId : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val id: String = restateContext().awakeable(TypeTag.STRING_UTF8).id + return greetingResponse { message = id } + } + } + + override fun returnAwakeableId(): BindableService { + return ReturnAwakeableId() + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt new file mode 100644 index 00000000..9946551a --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/DeferredTest.kt @@ -0,0 +1,136 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.StateKey +import dev.restate.sdk.core.TypeTag +import dev.restate.sdk.core.impl.DeferredTestSuite +import dev.restate.sdk.core.impl.testservices.* +import io.grpc.BindableService +import kotlinx.coroutines.Dispatchers + +internal class DeferredTest : DeferredTestSuite() { + private class ReverseAwaitOrder : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val a1 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) + val a2 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Till" }) + val a2Res = a2.await().getMessage() + ctx.set(StateKey.of("A2", TypeTag.STRING_UTF8), a2Res) + val a1Res = a1.await().getMessage() + return greetingResponse { message = "$a1Res-$a2Res" } + } + } + + override fun reverseAwaitOrder(): BindableService { + return ReverseAwaitOrder() + } + + private class AwaitTwiceTheSameAwaitable : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val a = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) + return greetingResponse { message = a.await().getMessage() + "-" + a.await().getMessage() } + } + } + + override fun awaitTwiceTheSameAwaitable(): BindableService { + return AwaitTwiceTheSameAwaitable() + } + + private class AwaitAll : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val a1 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) + val a2 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Till" }) + listOf(a1, a2).awaitAll() + + return greetingResponse { message = a1.await().getMessage() + "-" + a2.await().getMessage() } + } + } + + override fun awaitAll(): BindableService { + return AwaitAll() + } + + private class AwaitAny : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val a1 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) + val a2 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Till" }) + return Awaitable.any(a1, a2).await() as GreetingResponse + } + } + + override fun awaitAny(): BindableService { + return AwaitAny() + } + + private class CombineAnyWithAll : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val a1 = ctx.awakeable(TypeTag.STRING_UTF8) + val a2 = ctx.awakeable(TypeTag.STRING_UTF8) + val a3 = ctx.awakeable(TypeTag.STRING_UTF8) + val a4 = ctx.awakeable(TypeTag.STRING_UTF8) + val a12 = Awaitable.any(a1, a2) + val a23 = Awaitable.any(a2, a3) + val a34 = Awaitable.any(a3, a4) + Awaitable.all(a12, a23, a34).await() + + return greetingResponse { + message = a12.await().toString() + a23.await() as String? + a34.await() + } + } + } + + override fun combineAnyWithAll(): BindableService { + return CombineAnyWithAll() + } + + private class AwaitAnyIndex : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val a1 = ctx.awakeable(TypeTag.STRING_UTF8) + val a2 = ctx.awakeable(TypeTag.STRING_UTF8) + val a3 = ctx.awakeable(TypeTag.STRING_UTF8) + val a4 = ctx.awakeable(TypeTag.STRING_UTF8) + + return greetingResponse { + message = Awaitable.any(a1, Awaitable.all(a2, a3), a4).awaitIndex().toString() + } + } + } + + override fun awaitAnyIndex(): BindableService { + return AwaitAnyIndex() + } + + private class AwaitOnAlreadyResolvedAwaitables : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val a1 = ctx.awakeable(TypeTag.STRING_UTF8) + val a2 = ctx.awakeable(TypeTag.STRING_UTF8) + val a12 = Awaitable.all(a1, a2) + val a12and1 = Awaitable.all(a12, a1) + val a121and12 = Awaitable.all(a12and1, a12) + a12and1.await() + a121and12.await() + + return greetingResponse { message = a1.await() + a2.await() } + } + } + + override fun awaitOnAlreadyResolvedAwaitables(): BindableService { + return AwaitOnAlreadyResolvedAwaitables() + } + + override fun awaitWithTimeout(): BindableService { + throw UnsupportedOperationException("Not supported yet") + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt new file mode 100644 index 00000000..371f02c3 --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt @@ -0,0 +1,70 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.StateKey +import dev.restate.sdk.core.TypeTag +import dev.restate.sdk.core.impl.EagerStateTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import kotlinx.coroutines.Dispatchers +import org.assertj.core.api.AssertionsForClassTypes.assertThat + +internal class EagerStateTest : EagerStateTestSuite() { + private class GetEmpty : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val stateIsEmpty = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8)) == null + return greetingResponse { message = stateIsEmpty.toString() } + } + } + + override fun getEmpty(): BindableService { + return GetEmpty() + } + + private class Get : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + return greetingResponse { + message = restateContext().get(StateKey.of("STATE", TypeTag.STRING_UTF8))!! + } + } + } + + override fun get(): BindableService { + return Get() + } + + private class GetAppendAndGet : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val oldState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8))!! + ctx.set(StateKey.of("STATE", TypeTag.STRING_UTF8), oldState + request.getName()) + val newState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8))!! + return greetingResponse { message = newState } + } + } + + override fun getAppendAndGet(): BindableService { + return GetAppendAndGet() + } + + private class GetClearAndGet : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val oldState = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8))!! + ctx.clear(StateKey.of("STATE", TypeTag.STRING_UTF8)) + assertThat(ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8))).isNull() + return greetingResponse { message = oldState } + } + } + + override fun getClearAndGet(): BindableService { + return GetClearAndGet() + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetAndSetStateTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetAndSetStateTest.kt new file mode 100644 index 00000000..00259ccd --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetAndSetStateTest.kt @@ -0,0 +1,33 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.StateKey +import dev.restate.sdk.core.TypeTag +import dev.restate.sdk.core.impl.GetAndSetStateTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import kotlinx.coroutines.Dispatchers + +internal class GetAndSetStateTest : GetAndSetStateTestSuite() { + private class GetAndSetGreeter : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + + val state = ctx.get(StateKey.of("STATE", TypeTag.STRING_UTF8))!! + ctx.set(StateKey.of("STATE", TypeTag.STRING_UTF8), request.getName()) + + return greetingResponse { message = "Hello $state" } + } + } + + override fun getAndSetGreeter(): BindableService { + return GetAndSetGreeter() + } + + override fun setNullState(): BindableService { + throw UnsupportedOperationException("The kotlin type system enforces non null state values") + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetStateTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetStateTest.kt new file mode 100644 index 00000000..94d684b8 --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/GetStateTest.kt @@ -0,0 +1,26 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.StateKey +import dev.restate.sdk.core.TypeTag +import dev.restate.sdk.core.impl.GetStateTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import kotlinx.coroutines.Dispatchers + +internal class GetStateTest : GetStateTestSuite() { + private class GetStateGreeter : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val state: String = + restateContext().get(StateKey.of("STATE", TypeTag.STRING_UTF8)) ?: "Unknown" + return greetingResponse { message = "Hello $state" } + } + } + + override fun getStateGreeter(): BindableService { + return GetStateGreeter() + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/InvocationIdTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/InvocationIdTest.kt new file mode 100644 index 00000000..8465cf36 --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/InvocationIdTest.kt @@ -0,0 +1,23 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.InvocationId +import dev.restate.sdk.core.impl.InvocationIdTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import kotlinx.coroutines.Dispatchers + +internal class InvocationIdTest : InvocationIdTestSuite() { + private class ReturnInvocationId : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + return greetingResponse { message = InvocationId.current().toString() } + } + } + + override fun returnInvocationId(): BindableService { + return ReturnInvocationId() + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/OnlyInputAndOutputTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/OnlyInputAndOutputTest.kt new file mode 100644 index 00000000..de3c5672 --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/OnlyInputAndOutputTest.kt @@ -0,0 +1,22 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.impl.OnlyInputAndOutputTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import kotlinx.coroutines.Dispatchers + +internal class OnlyInputAndOutputTest : OnlyInputAndOutputTestSuite() { + private class NoSyscallsGreeter : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + return greetingResponse { message = "Hello " + request.getName() } + } + } + + override fun noSyscallsGreeter(): BindableService { + return NoSyscallsGreeter() + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt new file mode 100644 index 00000000..a14b80ec --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt @@ -0,0 +1,87 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.TypeTag +import dev.restate.sdk.core.impl.SideEffectTestSuite +import dev.restate.sdk.core.impl.testservices.* +import io.grpc.BindableService +import java.util.* +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.Dispatchers + +internal class SideEffectTest : SideEffectTestSuite() { + private class SideEffect(private val sideEffectOutput: String) : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx: RestateContext = restateContext() + val result = ctx.sideEffect(TypeTag.STRING_UTF8) { sideEffectOutput } + return greetingResponse { message = "Hello $result" } + } + } + + override fun sideEffect(sideEffectOutput: String): BindableService { + return SideEffect(sideEffectOutput) + } + + private class ConsecutiveSideEffect(private val sideEffectOutput: String) : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx: RestateContext = restateContext() + val firstResult = ctx.sideEffect(TypeTag.STRING_UTF8) { sideEffectOutput } + val secondResult = + ctx.sideEffect(TypeTag.STRING_UTF8) { firstResult.uppercase(Locale.getDefault()) } + return greetingResponse { message = "Hello $secondResult" } + } + } + + override fun consecutiveSideEffect(sideEffectOutput: String): BindableService { + return ConsecutiveSideEffect(sideEffectOutput) + } + + private class CheckContextSwitching : + GreeterGrpcKt.GreeterCoroutineImplBase( + Dispatchers.Unconfined + CoroutineName("CheckContextSwitchingTestCoroutine")), + RestateCoroutineService { + + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val sideEffectThread = + restateContext().sideEffect(TypeTag.STRING_UTF8) { Thread.currentThread().name } + check(sideEffectThread.contains("CheckContextSwitchingTestCoroutine")) { + "Side effect thread is not running within the same coroutine context of the handler method: $sideEffectThread" + } + return greetingResponse { message = "Hello" } + } + } + + override fun checkContextSwitching(): BindableService { + return CheckContextSwitching() + } + + private class SideEffectGuard : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + ctx.sideEffect { + ctx.oneWayCall(GreeterGrpcKt.greetMethod, greetingRequest { name = "something" }) + } + throw IllegalStateException("This point should not be reached") + } + } + + override fun sideEffectGuard(): BindableService { + return SideEffectGuard() + } + + private class SideEffectThenAwakeable : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + ctx.sideEffect { throw IllegalStateException("This should be replayed") } + ctx.awakeable(TypeTag.BYTES).await() + return greetingResponse { message = "Hello" } + } + } + + override fun sideEffectThenAwakeable(): BindableService { + return SideEffectThenAwakeable() + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt new file mode 100644 index 00000000..98b5a95b --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt @@ -0,0 +1,42 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.impl.SleepTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import kotlin.time.Duration.Companion.milliseconds +import kotlinx.coroutines.Dispatchers + +class SleepTest : SleepTestSuite() { + private class SleepGreeter : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + ctx.sleep(1000.milliseconds) + return greetingResponse { message = "Hello" } + } + } + + override fun sleepGreeter(): BindableService { + return SleepGreeter() + } + + private class ManySleeps : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + val ctx = restateContext() + val awaitables = mutableListOf>() + for (i in 0..9) { + awaitables.add(ctx.timer(1000.milliseconds)) + } + awaitables.awaitAll() + return greetingResponse {} + } + } + + override fun manySleeps(): BindableService { + return ManySleeps() + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt new file mode 100644 index 00000000..1f90430a --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt @@ -0,0 +1,48 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.StateKey +import dev.restate.sdk.core.TypeTag +import dev.restate.sdk.core.impl.StateMachineFailuresTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import dev.restate.sdk.core.impl.testservices.greetingResponse +import io.grpc.BindableService +import java.nio.charset.StandardCharsets +import kotlinx.coroutines.Dispatchers + +internal class StateMachineFailuresTest : StateMachineFailuresTestSuite() { + private class GetState : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + restateContext().get(STATE) + return greetingResponse { message = "Francesco" } + } + + companion object { + private val STATE = + StateKey.of( + "STATE", + TypeTag.using({ i: Int -> i.toString().toByteArray(StandardCharsets.UTF_8) }) { + b: ByteArray? -> + String(b!!, StandardCharsets.UTF_8).toInt() + }) + } + } + + override fun getState(): BindableService { + return GetState() + } + + private class SideEffectFailure(private val typeTag: TypeTag) : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + restateContext().sideEffect(typeTag) { 0 } + return greetingResponse { message = "Francesco" } + } + } + + override fun sideEffectFailure(typeTag: TypeTag): BindableService { + return SideEffectFailure(typeTag) + } +} diff --git a/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt new file mode 100644 index 00000000..e0d5bff8 --- /dev/null +++ b/sdk-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt @@ -0,0 +1,59 @@ +package dev.restate.sdk.kotlin + +import dev.restate.sdk.core.impl.UserFailuresTestSuite +import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt +import dev.restate.sdk.core.impl.testservices.GreetingRequest +import dev.restate.sdk.core.impl.testservices.GreetingResponse +import io.grpc.BindableService +import io.grpc.Status +import io.grpc.StatusRuntimeException +import java.lang.UnsupportedOperationException +import kotlinx.coroutines.Dispatchers + +internal class UserFailuresTest : UserFailuresTestSuite() { + private class ThrowIllegalStateException : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + throw IllegalStateException("Whatever") + } + } + + override fun throwIllegalStateException(): BindableService { + throw UnsupportedOperationException("https://github.com/restatedev/sdk-java/issues/116") + } + + private class SideEffectThrowIllegalStateException : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + restateContext().sideEffect { throw IllegalStateException("Whatever") } + throw IllegalStateException("Not expected to reach this point") + } + } + + override fun sideEffectThrowIllegalStateException(): BindableService { + return SideEffectThrowIllegalStateException() + } + + private class ThrowStatusRuntimeException(private val status: Status) : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + throw StatusRuntimeException(status) + } + } + + override fun throwStatusRuntimeException(status: Status): BindableService { + return ThrowStatusRuntimeException(status) + } + + private class SideEffectThrowStatusRuntimeException(private val status: Status) : + GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService { + override suspend fun greet(request: GreetingRequest): GreetingResponse { + restateContext().sideEffect { throw StatusRuntimeException(status) } + throw IllegalStateException("Not expected to reach this point") + } + } + + override fun sideEffectThrowStatusRuntimeException(status: Status): BindableService { + return SideEffectThrowStatusRuntimeException(status) + } +}