diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java index 8c3ed3dc..5c5a033f 100644 --- a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java +++ b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java @@ -20,6 +20,7 @@ import io.vertx.core.http.HttpServerResponse; import io.vertx.core.http.impl.HttpServerRequestInternal; import java.net.URI; +import java.util.HashMap; import java.util.HashSet; import java.util.concurrent.Executor; import java.util.regex.Pattern; @@ -57,14 +58,17 @@ public String get(@Nullable MultiMap carrier, String key) { private final RestateGrpcServer restateGrpcServer; private final HashSet blockingServices; + private final HashMap executors; private final OpenTelemetry openTelemetry; RequestHttpServerHandler( RestateGrpcServer restateGrpcServer, HashSet blockingServices, + HashMap executors, OpenTelemetry openTelemetry) { this.restateGrpcServer = restateGrpcServer; this.blockingServices = blockingServices; + this.executors = executors; this.openTelemetry = openTelemetry; } @@ -85,9 +89,9 @@ public void handle(HttpServerRequest request) { request.response().setStatusCode(NOT_FOUND.code()).end(); return; } - String service = pathSegments[pathSegments.length - 2]; - String method = pathSegments[pathSegments.length - 1]; - boolean isBlockingService = blockingServices.contains(service); + String serviceName = pathSegments[pathSegments.length - 2]; + String methodName = pathSegments[pathSegments.length - 1]; + boolean isBlockingService = blockingServices.contains(serviceName); // Parse OTEL context and generate span final io.opentelemetry.context.Context otelContext = @@ -105,11 +109,11 @@ public void handle(HttpServerRequest request) { try { handler = restateGrpcServer.resolve( - service, - method, + serviceName, + methodName, otelContext, isBlockingService ? currentContextExecutor(vertxCurrentContext) : null, - isBlockingService ? blockingExecutor(vertxCurrentContext) : null); + isBlockingService ? blockingExecutor(serviceName, vertxCurrentContext) : null); } catch (ProtocolException e) { LOG.warn("Error when resolving the grpc handler", e); request @@ -122,7 +126,7 @@ public void handle(HttpServerRequest request) { return; } - LOG.debug("Handling request to " + service + "/" + method); + LOG.debug("Handling request to " + serviceName + "/" + methodName); // Prepare the header frame to send in the response. // Vert.x will send them as soon as we send the first write @@ -145,17 +149,20 @@ private Executor currentContextExecutor(Context currentContext) { return runnable -> currentContext.runOnContext(v -> runnable.run()); } - private Executor blockingExecutor(Context currentContext) { + private Executor blockingExecutor(String serviceName, Context currentContext) { + if (this.executors.containsKey(serviceName)) { + Executor userExecutor = this.executors.get(serviceName); + return runnable -> { + // We need to propagate the gRPC context! + io.grpc.Context ctx = io.grpc.Context.current(); + userExecutor.execute(() -> ctx.run(runnable)); + }; + } return runnable -> currentContext.executeBlocking( - promise -> { - try { - runnable.run(); - } catch (Throwable e) { - promise.fail(e); - return; - } - promise.complete(); + () -> { + runnable.run(); + return null; }, false); } diff --git a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java index 7b60d0dc..9188cbd9 100644 --- a/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java +++ b/sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java @@ -13,10 +13,8 @@ import io.vertx.core.Vertx; import io.vertx.core.http.HttpServer; import io.vertx.core.http.HttpServerOptions; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Objects; -import java.util.Optional; +import java.util.*; +import java.util.concurrent.Executor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -43,6 +41,7 @@ public class RestateHttpEndpointBuilder { private final RestateGrpcServer.Builder restateGrpcServerBuilder = RestateGrpcServer.newBuilder(Discovery.ProtocolMode.BIDI_STREAM); private final HashSet blockingServices = new HashSet<>(); + private final HashMap executors = new HashMap<>(); private OpenTelemetry openTelemetry = OpenTelemetry.noop(); private HttpServerOptions options = new HttpServerOptions() @@ -84,6 +83,23 @@ public RestateHttpEndpointBuilder withService( return this; } + /** + * Add a {@link BindableBlockingService} to the endpoint, specifying the {@code executor} where to + * run the service code. + * + *

You can run on virtual threads by using the executor {@code + * Executors.newVirtualThreadPerTaskExecutor()}. + */ + public RestateHttpEndpointBuilder withService( + BindableBlockingService service, Executor executor, ServerInterceptor... interceptors) { + ServerServiceDefinition definition = + ServerInterceptors.intercept(service, Arrays.asList(interceptors)); + this.restateGrpcServerBuilder.withService(definition); + this.blockingServices.add(definition.getServiceDescriptor().getName()); + this.executors.put(definition.getServiceDescriptor().getName(), executor); + return this; + } + /** * Add a {@link BindableNonBlockingService} to the endpoint. * @@ -141,7 +157,7 @@ public HttpServer build() { server.requestHandler( new RequestHttpServerHandler( - this.restateGrpcServerBuilder.build(), blockingServices, openTelemetry)); + this.restateGrpcServerBuilder.build(), blockingServices, executors, openTelemetry)); return server; }