Skip to content

Commit

Permalink
Add support to provide your own executor (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Oct 16, 2023
1 parent 529f59f commit 0058656
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.http.impl.HttpServerRequestInternal;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.concurrent.Executor;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -57,14 +58,17 @@ public String get(@Nullable MultiMap carrier, String key) {

private final RestateGrpcServer restateGrpcServer;
private final HashSet<String> blockingServices;
private final HashMap<String, Executor> executors;
private final OpenTelemetry openTelemetry;

RequestHttpServerHandler(
RestateGrpcServer restateGrpcServer,
HashSet<String> blockingServices,
HashMap<String, Executor> executors,
OpenTelemetry openTelemetry) {
this.restateGrpcServer = restateGrpcServer;
this.blockingServices = blockingServices;
this.executors = executors;
this.openTelemetry = openTelemetry;
}

Expand All @@ -85,9 +89,9 @@ public void handle(HttpServerRequest request) {
request.response().setStatusCode(NOT_FOUND.code()).end();
return;
}
String service = pathSegments[pathSegments.length - 2];
String method = pathSegments[pathSegments.length - 1];
boolean isBlockingService = blockingServices.contains(service);
String serviceName = pathSegments[pathSegments.length - 2];
String methodName = pathSegments[pathSegments.length - 1];
boolean isBlockingService = blockingServices.contains(serviceName);

// Parse OTEL context and generate span
final io.opentelemetry.context.Context otelContext =
Expand All @@ -105,11 +109,11 @@ public void handle(HttpServerRequest request) {
try {
handler =
restateGrpcServer.resolve(
service,
method,
serviceName,
methodName,
otelContext,
isBlockingService ? currentContextExecutor(vertxCurrentContext) : null,
isBlockingService ? blockingExecutor(vertxCurrentContext) : null);
isBlockingService ? blockingExecutor(serviceName, vertxCurrentContext) : null);
} catch (ProtocolException e) {
LOG.warn("Error when resolving the grpc handler", e);
request
Expand All @@ -122,7 +126,7 @@ public void handle(HttpServerRequest request) {
return;
}

LOG.debug("Handling request to " + service + "/" + method);
LOG.debug("Handling request to " + serviceName + "/" + methodName);

// Prepare the header frame to send in the response.
// Vert.x will send them as soon as we send the first write
Expand All @@ -145,17 +149,20 @@ private Executor currentContextExecutor(Context currentContext) {
return runnable -> currentContext.runOnContext(v -> runnable.run());
}

private Executor blockingExecutor(Context currentContext) {
private Executor blockingExecutor(String serviceName, Context currentContext) {
if (this.executors.containsKey(serviceName)) {
Executor userExecutor = this.executors.get(serviceName);
return runnable -> {
// We need to propagate the gRPC context!
io.grpc.Context ctx = io.grpc.Context.current();
userExecutor.execute(() -> ctx.run(runnable));
};
}
return runnable ->
currentContext.executeBlocking(
promise -> {
try {
runnable.run();
} catch (Throwable e) {
promise.fail(e);
return;
}
promise.complete();
() -> {
runnable.run();
return null;
},
false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.HttpServerOptions;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.*;
import java.util.concurrent.Executor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand All @@ -43,6 +41,7 @@ public class RestateHttpEndpointBuilder {
private final RestateGrpcServer.Builder restateGrpcServerBuilder =
RestateGrpcServer.newBuilder(Discovery.ProtocolMode.BIDI_STREAM);
private final HashSet<String> blockingServices = new HashSet<>();
private final HashMap<String, Executor> executors = new HashMap<>();
private OpenTelemetry openTelemetry = OpenTelemetry.noop();
private HttpServerOptions options =
new HttpServerOptions()
Expand Down Expand Up @@ -84,6 +83,23 @@ public RestateHttpEndpointBuilder withService(
return this;
}

/**
* Add a {@link BindableBlockingService} to the endpoint, specifying the {@code executor} where to
* run the service code.
*
* <p>You can run on virtual threads by using the executor {@code
* Executors.newVirtualThreadPerTaskExecutor()}.
*/
public RestateHttpEndpointBuilder withService(
BindableBlockingService service, Executor executor, ServerInterceptor... interceptors) {
ServerServiceDefinition definition =
ServerInterceptors.intercept(service, Arrays.asList(interceptors));
this.restateGrpcServerBuilder.withService(definition);
this.blockingServices.add(definition.getServiceDescriptor().getName());
this.executors.put(definition.getServiceDescriptor().getName(), executor);
return this;
}

/**
* Add a {@link BindableNonBlockingService} to the endpoint.
*
Expand Down Expand Up @@ -141,7 +157,7 @@ public HttpServer build() {

server.requestHandler(
new RequestHttpServerHandler(
this.restateGrpcServerBuilder.build(), blockingServices, openTelemetry));
this.restateGrpcServerBuilder.build(), blockingServices, executors, openTelemetry));

return server;
}
Expand Down

0 comments on commit 0058656

Please sign in to comment.