Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(interactive): Support Http Gremlin Service #4394

Merged
merged 4 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.Attribute;
import io.netty.util.AttributeMap;

Expand All @@ -42,6 +44,7 @@
import org.apache.tinkerpop.gremlin.server.auth.Authenticator;
import org.apache.tinkerpop.gremlin.server.authz.Authorizer;
import org.apache.tinkerpop.gremlin.server.handler.AbstractAuthenticationHandler;
import org.apache.tinkerpop.gremlin.server.handler.HttpHandlerUtils;
import org.apache.tinkerpop.gremlin.server.handler.SaslAuthenticationHandler;
import org.apache.tinkerpop.gremlin.server.handler.StateKey;
import org.slf4j.Logger;
Expand All @@ -50,6 +53,7 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -80,140 +84,201 @@
this.settings = settings;
}

@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof RequestMessage) {
final RequestMessage requestMessage = (RequestMessage) msg;

final Attribute<Authenticator.SaslNegotiator> negotiator =
((AttributeMap) ctx).attr(StateKey.NEGOTIATOR);
final Attribute<RequestMessage> request =
((AttributeMap) ctx).attr(StateKey.REQUEST_MESSAGE);

if (negotiator.get() == null) {
try {
// First time through so save the request and send an AUTHENTICATE challenge
// with no data
negotiator.set(authenticator.newSaslNegotiator(getRemoteInetAddress(ctx)));
request.set(requestMessage);
// the authentication flag is off, just pass the original message down the
// pipeline for processing
if (!authenticator.requireAuthentication()) {
ctx.pipeline().remove(this);
final RequestMessage original = request.get();
ctx.fireChannelRead(original);
} else {
final ResponseMessage authenticate =
ResponseMessage.build(requestMessage)
.code(ResponseStatusCode.AUTHENTICATE)
.create();
ctx.writeAndFlush(authenticate);
}
} catch (Exception ex) {
// newSaslNegotiator can cause troubles - if we don't catch and respond nicely
// the driver seems
// to hang until timeout which isn't so nice. treating this like a server error
// as it means that
// the Authenticator isn't really ready to deal with requests for some reason.
logger.error(
String.format(
"%s is not ready to handle requests - check its configuration"
+ " or related services",
authenticator.getClass().getSimpleName()),
ex);

final ResponseMessage error =
ResponseMessage.build(requestMessage)
.statusMessage("Authenticator is not ready to handle requests")
.code(ResponseStatusCode.SERVER_ERROR)
.create();
ctx.writeAndFlush(error);
}
} else {
if (requestMessage.getOp().equals(Tokens.OPS_AUTHENTICATION)
&& requestMessage.getArgs().containsKey(Tokens.ARGS_SASL)) {

final Object saslObject = requestMessage.getArgs().get(Tokens.ARGS_SASL);
final byte[] saslResponse;

if (saslObject instanceof String) {
saslResponse = BASE64_DECODER.decode((String) saslObject);
} else {
final ResponseMessage error =
ResponseMessage.build(request.get())
.statusMessage(
"Incorrect type for : "
+ Tokens.ARGS_SASL
+ " - base64 encoded String is expected")
.code(ResponseStatusCode.REQUEST_ERROR_MALFORMED_REQUEST)
.create();
ctx.writeAndFlush(error);
return;
}

try {
final byte[] saslMessage = negotiator.get().evaluateResponse(saslResponse);
if (negotiator.get().isComplete()) {
final AuthenticatedUser user = negotiator.get().getAuthenticatedUser();
ctx.channel().attr(StateKey.AUTHENTICATED_USER).set(user);
// Username logged with the remote socket address and authenticator
// classname for audit logging
if (settings.enableAuditLog || settings.authentication.enableAuditLog) {
String address = ctx.channel().remoteAddress().toString();
if (address.startsWith("/") && address.length() > 1)
address = address.substring(1);
final String[] authClassParts =
authenticator.getClass().toString().split("[.]");
auditLogger.info(
"User {} with address {} authenticated by {}",
user.getName(),
address,
authClassParts[authClassParts.length - 1]);
}
// If we have got here we are authenticated so remove the handler and
// pass
// the original message down the pipeline for processing
ctx.pipeline().remove(this);
final RequestMessage original = request.get();
ctx.fireChannelRead(original);
} else {
// not done here - send back the sasl message for next challenge.
final Map<String, Object> metadata = new HashMap<>();
metadata.put(
Tokens.ARGS_SASL, BASE64_ENCODER.encodeToString(saslMessage));
final ResponseMessage authenticate =
ResponseMessage.build(requestMessage)
.statusAttributes(metadata)
.code(ResponseStatusCode.AUTHENTICATE)
.create();
ctx.writeAndFlush(authenticate);
}
} catch (AuthenticationException ae) {
final ResponseMessage error =
ResponseMessage.build(request.get())
.statusMessage(ae.getMessage())
.code(ResponseStatusCode.UNAUTHORIZED)
.create();
ctx.writeAndFlush(error);
}
} else {
final ResponseMessage error =
ResponseMessage.build(requestMessage)
.statusMessage("Failed to authenticate")
.code(ResponseStatusCode.UNAUTHORIZED)
.create();
ctx.writeAndFlush(error);
}
}
} else if (msg instanceof FullHttpMessage) { // add Authentication for HTTP requests
FullHttpMessage request = (FullHttpMessage) msg;

if (!authenticator.requireAuthentication()) {
ctx.fireChannelRead(request);
return;
}

String errorMsg =
"Invalid HTTP Header for Authentication. Expected format: 'Authorization: Basic"
+ " <Base64(user:password)>'";

if (!request.headers().contains("Authorization")) {
sendError(ctx, errorMsg, request);
return;
}

String authorizationHeader = request.headers().get("Authorization");
if (!authorizationHeader.startsWith("Basic ")) {
sendError(ctx, errorMsg, request);
return;
}

String authorization;
byte[] decodedUserPass;
try {
authorization = authorizationHeader.substring("Basic ".length());
decodedUserPass = BASE64_DECODER.decode(authorization);
} catch (Exception e) {
sendError(ctx, errorMsg, request);
return;
}

authorization = new String(decodedUserPass, Charset.forName("UTF-8"));
String[] split = authorization.split(":");
if (split.length != 2) {
sendError(
ctx,
"Invalid username or password after decoding the Base64 Authorization"
+ " header.",
request);
return;
}

Map<String, String> credentials = new HashMap();
credentials.put("username", split[0]);
credentials.put("password", split[1]);
String address = ctx.channel().remoteAddress().toString();
if (address.startsWith("/") && address.length() > 1) {
address = address.substring(1);
}

credentials.put("address", address);

try {
AuthenticatedUser user = authenticator.authenticate(credentials);
ctx.channel().attr(StateKey.AUTHENTICATED_USER).set(user);
ctx.fireChannelRead(request);
} catch (AuthenticationException e) {
sendError(ctx, e.getMessage(), request);
}
} else {
logger.warn(
"{} only processes RequestMessage instances - received {} - channel closing",
"{} received invalid request message {} - channel closing",
this.getClass().getSimpleName(),
msg.getClass());
ctx.close();
}
}

Check notice on line 281 in interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/auth/IrAuthenticationHandler.java

View check run for this annotation

codefactor.io / CodeFactor

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/auth/IrAuthenticationHandler.java#L87-L281

Complex Method
private InetAddress getRemoteInetAddress(final ChannelHandlerContext ctx) {
final Channel channel = ctx.channel();

Expand All @@ -226,4 +291,17 @@

return ((InetSocketAddress) genericSocketAddr).getAddress();
}

private void sendError(
final ChannelHandlerContext ctx, String errorMsg, FullHttpMessage request) {
HttpHandlerUtils.sendError(ctx, HttpResponseStatus.UNAUTHORIZED, errorMsg, false);
if (request.refCnt() > 0) {
boolean fullyReleased = request.release();
if (!fullyReleased) {
logger.warn(
"http request message was not fully released, may cause a"
+ " memory leak");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
*
* * Copyright 2020 Alibaba Group Holding Limited.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/

package com.alibaba.graphscope.gremlin.plugin.processor;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.*;

import org.apache.tinkerpop.gremlin.driver.message.RequestMessage;
import org.apache.tinkerpop.gremlin.driver.message.ResponseMessage;
import org.apache.tinkerpop.gremlin.driver.message.ResponseStatusCode;
import org.apache.tinkerpop.gremlin.driver.ser.MessageTextSerializer;
import org.apache.tinkerpop.gremlin.driver.ser.SerializationException;
import org.apache.tinkerpop.gremlin.groovy.engine.GremlinExecutor;
import org.apache.tinkerpop.gremlin.server.Context;
import org.apache.tinkerpop.gremlin.server.GraphManager;
import org.apache.tinkerpop.gremlin.server.Settings;
import org.apache.tinkerpop.gremlin.server.handler.HttpHandlerUtils;
import org.javatuples.Pair;

import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicReference;

/**
* Maintain the gremlin execution context for http request.
*/
public class HttpContext extends Context {
private final Pair<String, MessageTextSerializer<?>> serializer;
private final boolean keepAlive;
private final AtomicReference<Boolean> headerSent;

public HttpContext(
RequestMessage requestMessage,
ChannelHandlerContext ctx,
Settings settings,
GraphManager graphManager,
GremlinExecutor gremlinExecutor,
ScheduledExecutorService scheduledExecutorService,
Pair<String, MessageTextSerializer<?>> serializer,
boolean keepAlive) {
super(
requestMessage,
ctx,
settings,
graphManager,
gremlinExecutor,
scheduledExecutorService);
this.serializer = Objects.requireNonNull(serializer);
this.keepAlive = keepAlive;
this.headerSent = new AtomicReference<>(false);
}

/**
* serialize the response message to http response and write to http channel.
* @param responseMessage
*/
@Override
public void writeAndFlush(final ResponseMessage responseMessage) {
try {
// send header once
if (!headerSent.compareAndSet(false, true)) {
FullHttpResponse chunkedResponse =
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
chunkedResponse.headers().set(HttpHeaderNames.CONTENT_TYPE, serializer.getValue0());
chunkedResponse
.headers()
.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
this.getChannelHandlerContext().writeAndFlush(chunkedResponse);
}
ByteBuf byteBuf =
Unpooled.wrappedBuffer(
serializer
.getValue1()
.serializeResponseAsString(responseMessage)
.getBytes(StandardCharsets.UTF_8));
FullHttpResponse response =
new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.OK, byteBuf);
ChannelFuture channelFuture = this.getChannelHandlerContext().writeAndFlush(response);
ResponseStatusCode statusCode = responseMessage.getStatus().getCode();
if (!keepAlive && statusCode.isFinalResponse()) {
channelFuture.addListener(ChannelFutureListener.CLOSE);
}
} catch (SerializationException e) {
HttpHandlerUtils.sendError(
this.getChannelHandlerContext(),
HttpResponseStatus.INTERNAL_SERVER_ERROR,
e.getMessage(),
keepAlive);
}
}
}
Loading
Loading