Skip to content

Commit

Permalink
Add Kotlin tests (#117)
Browse files Browse the repository at this point in the history
* Split test suite code and test services. With this change we can now reuse the same tests for the kotlin interface.
* Add tests to kotlin interface
* A bunch of fixes for issues that came up while testing the kotlin sdk.
  • Loading branch information
slinkydeveloper authored Oct 17, 2023
1 parent 57b062a commit 8223cd6
Show file tree
Hide file tree
Showing 46 changed files with 1,968 additions and 971 deletions.
1 change: 0 additions & 1 deletion sdk-core-impl/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ dependencies {

testCompileOnly(coreLibs.javax.annotation.api)

testImplementation(project(":sdk-java-blocking"))
testImplementation(testingLibs.junit.jupiter)
testImplementation(testingLibs.assertj)
testImplementation(coreLibs.grpc.stub)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ private <T> T deserialize(TypeTag<T> ty, ByteString bytes) {
return (T) bytes.toByteArray();
} else if (ByteString.class.equals(typeTag)) {
return (T) bytes;
} else if (Void.class.equals(typeTag)) {
} else if (Void.class.equals(typeTag) || Void.TYPE.equals(typeTag)) {
// Amazing JVM foot-gun here: Void.TYPE is the primitive type, Void.class is the boxed type.
// For us, they're the same but for the equality they aren't, so we check both
return null;
}
return serde.deserialize(typeTag, bytes.toByteArray());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,28 @@
package dev.restate.sdk.core.impl;

import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation;
import static dev.restate.sdk.core.impl.ProtoUtils.*;
import static dev.restate.sdk.core.impl.ProtoUtils.inputMessage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.type;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.blocking.RestateBlockingService;
import dev.restate.sdk.core.TypeTag;
import dev.restate.sdk.core.impl.testservices.GreeterGrpc;
import dev.restate.sdk.core.impl.testservices.GreetingRequest;
import dev.restate.sdk.core.impl.testservices.GreetingResponse;
import io.grpc.stub.StreamObserver;
import io.grpc.BindableService;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.UUID;
import java.util.stream.Stream;

class AwakeableIdTest extends CoreTestRunner {
public abstract class AwakeableIdTestSuite extends CoreTestRunner {

private static class ReturnAwakeableId extends GreeterGrpc.GreeterImplBase
implements RestateBlockingService {

@Override
public void greet(GreetingRequest request, StreamObserver<GreetingResponse> responseObserver) {
String id = restateContext().awakeable(TypeTag.STRING_UTF8).id();
responseObserver.onNext(greetingResponse(id));
responseObserver.onCompleted();
}
}
protected abstract BindableService returnAwakeableId();

@Override
Stream<TestDefinition> definitions() {
protected Stream<TestDefinition> definitions() {
UUID id = UUID.randomUUID();
String debugId = id.toString();
byte[] serializedId = serializeUUID(id);
Expand All @@ -46,7 +35,7 @@ Stream<TestDefinition> definitions() {
Base64.getUrlEncoder().encodeToString(expectedAwakeableId.array());

return Stream.of(
testInvocation(new ReturnAwakeableId(), GreeterGrpc.getGreetMethod())
testInvocation(this::returnAwakeableId, GreeterGrpc.getGreetMethod())
.withInput(
Protocol.StartMessage.newBuilder()
.setDebugId(debugId)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package dev.restate.sdk.core.impl;

import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation;
import static dev.restate.sdk.core.impl.ProtoUtils.headerFromMessage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.params.provider.Arguments.arguments;
Expand All @@ -17,29 +16,29 @@
import io.grpc.MethodDescriptor;
import io.grpc.ServerServiceDefinition;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
abstract class CoreTestRunner {
public abstract class CoreTestRunner {

abstract Stream<TestDefinition> definitions();
protected abstract Stream<TestDefinition> definitions();

Stream<Arguments> source() {
return definitions()
.filter(TestDefinition::isValid)
.flatMap(
c ->
c.getThreadingModels().stream()
Expand Down Expand Up @@ -125,12 +124,12 @@ void executeTest(
}
}

enum ThreadingModel {
public enum ThreadingModel {
BUFFERED_SINGLE_THREAD,
UNBUFFERED_MULTI_THREAD
}

interface TestDefinition {
public interface TestDefinition {
ServerServiceDefinition getService();

String getMethod();
Expand All @@ -142,38 +141,49 @@ interface TestDefinition {
BiConsumer<FutureSubscriber<MessageLite>, Duration> getOutputAssert();

String testCaseName();

boolean isValid();
}

/** Builder for the test cases */
static class TestCaseBuilder {
public static class TestCaseBuilder {

static TestInvocationBuilder testInvocation(BindableService svc, String method) {
public static TestInvocationBuilder testInvocation(BindableService svc, String method) {
return new TestInvocationBuilder(svc, method);
}

static TestInvocationBuilder testInvocation(
public static TestInvocationBuilder testInvocation(
BindableService svc, MethodDescriptor<?, ?> method) {
return testInvocation(svc, method.getBareMethodName());
}

static class TestInvocationBuilder {
protected final BindableService svc;
public static TestInvocationBuilder testInvocation(
Supplier<BindableService> svc, MethodDescriptor<?, ?> method) {
try {
return testInvocation(svc.get(), method.getBareMethodName());
} catch (UnsupportedOperationException e) {
return testInvocation(null, method.getBareMethodName());
}
}

public static class TestInvocationBuilder {
protected final @Nullable BindableService svc;
protected final String method;

TestInvocationBuilder(BindableService svc, String method) {
TestInvocationBuilder(@Nullable BindableService svc, String method) {
this.svc = svc;
this.method = method;
}

WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
public WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
return new WithInputBuilder(
svc,
method,
List.of(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg)));
}

WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
public WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
return new WithInputBuilder(
svc,
method,
Expand All @@ -187,21 +197,21 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
}
}

static class WithInputBuilder extends TestInvocationBuilder {
public static class WithInputBuilder extends TestInvocationBuilder {
private final List<InvocationInput> input;

WithInputBuilder(BindableService svc, String method, List<InvocationInput> input) {
super(svc, method);
this.input = new ArrayList<>(input);
}

WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
public WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
this.input.add(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg));
return this;
}

WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
public WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
this.input.addAll(
Arrays.stream(messages)
.map(
Expand All @@ -213,18 +223,18 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
return this;
}

UsingThreadingModelsBuilder usingThreadingModels(ThreadingModel... threadingModels) {
public UsingThreadingModelsBuilder usingThreadingModels(ThreadingModel... threadingModels) {
return new UsingThreadingModelsBuilder(
this.svc, this.method, input, new HashSet<>(Arrays.asList(threadingModels)));
}

UsingThreadingModelsBuilder usingAllThreadingModels() {
public UsingThreadingModelsBuilder usingAllThreadingModels() {
return usingThreadingModels(ThreadingModel.values());
}
}

static class UsingThreadingModelsBuilder {
private final BindableService svc;
public static class UsingThreadingModelsBuilder {
private final @Nullable BindableService svc;
private final String method;
private final List<InvocationInput> input;
private final HashSet<ThreadingModel> threadingModels;
Expand All @@ -240,42 +250,47 @@ static class UsingThreadingModelsBuilder {
this.threadingModels = threadingModels;
}

ExpectingOutputMessages expectingOutput(MessageLiteOrBuilder... messages) {
public ExpectingOutputMessages expectingOutput(MessageLiteOrBuilder... messages) {
List<MessageLite> builtMessages =
Arrays.stream(messages).map(ProtoUtils::build).collect(Collectors.toList());
return assertingOutput(actual -> assertThat(actual).asList().isEqualTo(builtMessages));
}

ExpectingOutputMessages assertingOutput(Consumer<List<MessageLite>> messages) {
public ExpectingOutputMessages assertingOutput(Consumer<List<MessageLite>> messages) {
return new ExpectingOutputMessages(svc, method, input, threadingModels, messages);
}

ExpectingFailure assertingFailure(Class<? extends Throwable> tClass) {
public ExpectingFailure assertingFailure(Class<? extends Throwable> tClass) {
return assertingFailure(t -> assertThat(t).isInstanceOf(tClass));
}

ExpectingFailure assertingFailure(Consumer<Throwable> assertFailure) {
public ExpectingFailure assertingFailure(Consumer<Throwable> assertFailure) {
return new ExpectingFailure(svc, method, input, threadingModels, assertFailure);
}
}

public abstract static class BaseTestDefinition implements TestDefinition {
protected final BindableService svc;
protected final @Nullable BindableService svc;
protected final String method;
protected final List<InvocationInput> input;
protected final HashSet<ThreadingModel> threadingModels;
protected final String named;

public BaseTestDefinition(
BindableService svc,
@Nullable BindableService svc,
String method,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels) {
this(svc, method, input, threadingModels, svc.getClass().getSimpleName());
this(
svc,
method,
input,
threadingModels,
svc != null ? svc.getClass().getSimpleName() : "invalid");
}

public BaseTestDefinition(
BindableService svc,
@Nullable BindableService svc,
String method,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Expand All @@ -289,7 +304,7 @@ public BaseTestDefinition(

@Override
public ServerServiceDefinition getService() {
return svc.bindService();
return Objects.requireNonNull(svc).bindService();
}

@Override
Expand All @@ -313,11 +328,11 @@ public String testCaseName() {
}
}

static class ExpectingOutputMessages extends BaseTestDefinition {
public static class ExpectingOutputMessages extends BaseTestDefinition {
private final Consumer<List<MessageLite>> messagesAssert;

ExpectingOutputMessages(
BindableService svc,
@Nullable BindableService svc,
String method,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Expand All @@ -327,7 +342,7 @@ static class ExpectingOutputMessages extends BaseTestDefinition {
}

ExpectingOutputMessages(
BindableService svc,
@Nullable BindableService svc,
String method,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Expand All @@ -337,14 +352,14 @@ static class ExpectingOutputMessages extends BaseTestDefinition {
this.messagesAssert = messagesAssert;
}

ExpectingOutputMessages named(String name) {
public ExpectingOutputMessages named(String name) {
return new TestCaseBuilder.ExpectingOutputMessages(
svc,
method,
input,
threadingModels,
messagesAssert,
svc.getClass().getSimpleName() + ": " + name);
svc != null ? (svc.getClass().getSimpleName() + ": " + name) : "invalid");
}

@Override
Expand All @@ -366,13 +381,18 @@ public BiConsumer<FutureSubscriber<MessageLite>, Duration> getOutputAssert() {
Protocol.ErrorMessage.class);
};
}

@Override
public boolean isValid() {
return this.svc != null;
}
}

static class ExpectingFailure extends BaseTestDefinition {
public static class ExpectingFailure extends BaseTestDefinition {
private final Consumer<Throwable> throwableAssert;

ExpectingFailure(
BindableService svc,
@Nullable BindableService svc,
String method,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Expand All @@ -392,14 +412,14 @@ static class ExpectingFailure extends BaseTestDefinition {
this.throwableAssert = throwableAssert;
}

ExpectingFailure named(String name) {
public ExpectingFailure named(String name) {
return new ExpectingFailure(
svc,
method,
input,
threadingModels,
throwableAssert,
svc.getClass().getSimpleName() + ": " + name);
svc != null ? (svc.getClass().getSimpleName() + ": " + name) : "invalid");
}

@Override
Expand All @@ -416,6 +436,11 @@ public BiConsumer<FutureSubscriber<MessageLite>, Duration> getOutputAssert() {
Protocol.OutputStreamEntryMessage.class, Protocol.SuspensionMessage.class);
};
}

@Override
public boolean isValid() {
return this.svc != null;
}
}
}
}
Loading

0 comments on commit 8223cd6

Please sign in to comment.