Skip to content

Commit

Permalink
Replace vertx.executeBlocking with manually provided executor (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Oct 30, 2023
1 parent ccff20a commit 26105f1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@
import io.netty.util.AsciiString;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.context.propagation.TextMapGetter;
import io.vertx.core.Context;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.*;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.http.impl.HttpServerRequestInternal;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.concurrent.Executor;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -57,18 +54,15 @@ public String get(@Nullable MultiMap carrier, String key) {
};

private final RestateGrpcServer restateGrpcServer;
private final HashSet<String> blockingServices;
private final HashMap<String, Executor> executors;
private final HashMap<String, Executor> blockingServices;
private final OpenTelemetry openTelemetry;

RequestHttpServerHandler(
RestateGrpcServer restateGrpcServer,
HashSet<String> blockingServices,
HashMap<String, Executor> executors,
HashMap<String, Executor> blockingServices,
OpenTelemetry openTelemetry) {
this.restateGrpcServer = restateGrpcServer;
this.blockingServices = blockingServices;
this.executors = executors;
this.openTelemetry = openTelemetry;
}

Expand All @@ -91,7 +85,7 @@ public void handle(HttpServerRequest request) {
}
String serviceName = pathSegments[pathSegments.length - 2];
String methodName = pathSegments[pathSegments.length - 1];
boolean isBlockingService = blockingServices.contains(serviceName);
boolean isBlockingService = blockingServices.containsKey(serviceName);

// Parse OTEL context and generate span
final io.opentelemetry.context.Context otelContext =
Expand All @@ -113,7 +107,7 @@ public void handle(HttpServerRequest request) {
methodName,
otelContext,
isBlockingService ? currentContextExecutor(vertxCurrentContext) : null,
isBlockingService ? blockingExecutor(serviceName, vertxCurrentContext) : null);
isBlockingService ? blockingExecutor(serviceName) : null);
} catch (ProtocolException e) {
LOG.warn("Error when resolving the grpc handler", e);
request
Expand Down Expand Up @@ -149,22 +143,13 @@ private Executor currentContextExecutor(Context currentContext) {
return runnable -> currentContext.runOnContext(v -> runnable.run());
}

private Executor blockingExecutor(String serviceName, Context currentContext) {
if (this.executors.containsKey(serviceName)) {
Executor userExecutor = this.executors.get(serviceName);
return runnable -> {
// We need to propagate the gRPC context!
io.grpc.Context ctx = io.grpc.Context.current();
userExecutor.execute(() -> ctx.run(runnable));
};
}
return runnable ->
currentContext.executeBlocking(
() -> {
runnable.run();
return null;
},
false);
private Executor blockingExecutor(String serviceName) {
Executor userExecutor = this.blockingServices.get(serviceName);
return runnable -> {
// We need to propagate the gRPC context!
io.grpc.Context ctx = io.grpc.Context.current();
userExecutor.execute(ctx.wrap(runnable));
};
}

private void handleDiscoveryRequest(HttpServerRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.vertx.core.http.HttpServerOptions;
import java.util.*;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand All @@ -40,8 +41,8 @@ public class RestateHttpEndpointBuilder {
private final Vertx vertx;
private final RestateGrpcServer.Builder restateGrpcServerBuilder =
RestateGrpcServer.newBuilder(Discovery.ProtocolMode.BIDI_STREAM);
private final HashSet<String> blockingServices = new HashSet<>();
private final HashMap<String, Executor> executors = new HashMap<>();
private final Executor defaultExecutor = Executors.newCachedThreadPool();
private final HashMap<String, Executor> blockingServices = new HashMap<>();
private OpenTelemetry openTelemetry = OpenTelemetry.noop();
private HttpServerOptions options =
new HttpServerOptions()
Expand Down Expand Up @@ -76,11 +77,7 @@ public RestateHttpEndpointBuilder withOptions(HttpServerOptions options) {
*/
public RestateHttpEndpointBuilder withService(
BindableBlockingService service, ServerInterceptor... interceptors) {
ServerServiceDefinition definition =
ServerInterceptors.intercept(service, Arrays.asList(interceptors));
this.restateGrpcServerBuilder.withService(definition);
this.blockingServices.add(definition.getServiceDescriptor().getName());
return this;
return this.withService(service, defaultExecutor, interceptors);
}

/**
Expand All @@ -95,8 +92,7 @@ public RestateHttpEndpointBuilder withService(
ServerServiceDefinition definition =
ServerInterceptors.intercept(service, Arrays.asList(interceptors));
this.restateGrpcServerBuilder.withService(definition);
this.blockingServices.add(definition.getServiceDescriptor().getName());
this.executors.put(definition.getServiceDescriptor().getName(), executor);
this.blockingServices.put(definition.getServiceDescriptor().getName(), executor);
return this;
}

Expand Down Expand Up @@ -157,7 +153,7 @@ public HttpServer build() {

server.requestHandler(
new RequestHttpServerHandler(
this.restateGrpcServerBuilder.build(), blockingServices, executors, openTelemetry));
this.restateGrpcServerBuilder.build(), blockingServices, openTelemetry));

return server;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package dev.restate.sdk.http.vertx

import com.google.protobuf.ByteString
import dev.restate.generated.sdk.java.Java.SideEffectEntryMessage
import dev.restate.sdk.blocking.StateTest
import dev.restate.sdk.blocking.RestateBlockingService
import dev.restate.sdk.core.impl.ProtoUtils.*
import dev.restate.sdk.core.impl.TestDefinitions.*
import dev.restate.sdk.core.impl.TestRunner
Expand All @@ -11,6 +11,7 @@ import dev.restate.sdk.core.impl.testservices.GreeterGrpcKt
import dev.restate.sdk.core.impl.testservices.GreetingRequest
import dev.restate.sdk.core.impl.testservices.GreetingResponse
import dev.restate.sdk.kotlin.RestateCoroutineService
import io.grpc.stub.StreamObserver
import io.vertx.core.Vertx
import java.util.stream.Stream
import kotlinx.coroutines.Dispatchers
Expand All @@ -35,8 +36,8 @@ class HttpVertxTests : TestRunner() {
return Stream.of(HttpVertxTestExecutor(vertx))
}

class VertxKotlinTest : TestSuite {
private class CheckCorrectThread :
class VertxExecutorsTest : TestSuite {
private class CheckNonBlockingServiceTrampolineEventLoopContext :
GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService {
override suspend fun greet(request: GreetingRequest): GreetingResponse {
check(Vertx.currentContext().isEventLoopContext)
Expand All @@ -46,9 +47,35 @@ class HttpVertxTests : TestRunner() {
}
}

private class CheckBlockingServiceTrampolineExecutor :
GreeterGrpc.GreeterImplBase(), RestateBlockingService {
override fun greet(
request: GreetingRequest,
responseObserver: StreamObserver<GreetingResponse>
) {
val id = Thread.currentThread().id
check(Vertx.currentContext() == null)
restateContext().sideEffect {
check(Thread.currentThread().id == id)
check(Vertx.currentContext() == null)
}
check(Thread.currentThread().id == id)
check(Vertx.currentContext() == null)
responseObserver.onNext(GreetingResponse.getDefaultInstance())
responseObserver.onCompleted()
}
}

override fun definitions(): Stream<TestDefinition> {
return Stream.of(
testInvocation(CheckCorrectThread(), GreeterGrpc.getGreetMethod())
testInvocation(
CheckNonBlockingServiceTrampolineEventLoopContext(), GreeterGrpc.getGreetMethod())
.withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance()))
.onlyUnbuffered()
.expectingOutput(
SideEffectEntryMessage.newBuilder().setValue(ByteString.EMPTY),
outputMessage(GreetingResponse.getDefaultInstance())),
testInvocation(CheckBlockingServiceTrampolineExecutor(), GreeterGrpc.getGreetMethod())
.withInput(startMessage(1), inputMessage(GreetingRequest.getDefaultInstance()))
.onlyUnbuffered()
.expectingOutput(
Expand All @@ -57,23 +84,12 @@ class HttpVertxTests : TestRunner() {
}
}

// Assert unconfined dispatcher
private class CheckCorrectThread :
GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateCoroutineService {
override suspend fun greet(request: GreetingRequest): GreetingResponse {
check(Vertx.currentContext().isEventLoopContext)
restateContext().sideEffect { check(Vertx.currentContext().isEventLoopContext) }
check(Vertx.currentContext().isEventLoopContext)
return GreetingResponse.getDefaultInstance()
}
}

override fun definitions(): Stream<TestSuite> {
return Stream.of(
dev.restate.sdk.blocking.AwakeableIdTest(),
dev.restate.sdk.blocking.DeferredTest(),
dev.restate.sdk.blocking.EagerStateTest(),
StateTest(),
dev.restate.sdk.blocking.StateTest(),
dev.restate.sdk.blocking.InvocationIdTest(),
dev.restate.sdk.blocking.OnlyInputAndOutputTest(),
dev.restate.sdk.blocking.SideEffectTest(),
Expand All @@ -90,6 +106,6 @@ class HttpVertxTests : TestRunner() {
dev.restate.sdk.kotlin.SleepTest(),
dev.restate.sdk.kotlin.StateMachineFailuresTest(),
dev.restate.sdk.kotlin.UserFailuresTest(),
VertxKotlinTest())
VertxExecutorsTest())
}
}

0 comments on commit 26105f1

Please sign in to comment.