Skip to content

Commit

Permalink
Support Tools for Vertex AI Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
sberyozkin committed Dec 23, 2024
1 parent d56aa6b commit c7edcc5
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
package io.quarkiverse.langchain4j.vertexai.runtime.gemini;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.core.JsonProcessingException;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;

final class ContentMapper {

Expand Down Expand Up @@ -37,30 +47,70 @@ static GenerateContentRequest map(List<ChatMessage> messages, List<ToolSpecifica
}
contents.add(new GenerateContentRequest.Content(role, parts));
} else if (message instanceof AiMessage am) {
if (am.hasToolExecutionRequests()) {
throw new IllegalArgumentException("The Gemini integration currently does not support tools");
try {
if (am.hasToolExecutionRequests()) {
List<ToolExecutionRequest> toolExecutionRequests = am.toolExecutionRequests();
List<FunctionCall> toolCalls = new ArrayList<>(toolExecutionRequests.size());
for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
String argumentsStr = toolExecutionRequest.arguments();
String name = toolExecutionRequest.name();
// TODO: we need to update LangChain4j to make ToolExecutionRequest use a map instead of a String
Map<String, Object> arguments = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readValue(
argumentsStr,
Map.class);
toolCalls.add(new FunctionCall(name, arguments));
}
contents.add(new GenerateContentRequest.Content(role,
List.of(new GenerateContentRequest.Content.Part(am.text(), toolCalls))));
} else {
contents.add(new GenerateContentRequest.Content(role,
List.of(GenerateContentRequest.Content.Part.ofText(am.text()))));
}
} catch (JsonProcessingException e) {
throw new IllegalStateException("Unable to perform conversion of tool response", e);
}
} else if (message instanceof ToolExecutionResultMessage) {
contents.add(new GenerateContentRequest.Content(role,
List.of(GenerateContentRequest.Content.Part.ofText(am.text()))));
List.of(GenerateContentRequest.Content.Part.ofText(message.text()))));
} else {
throw new IllegalArgumentException(
"The Gemini integration currently does not support " + message.type() + " messages");
}
}
}

List<GenerateContentRequest.Tool> tools;
if (toolSpecifications == null || toolSpecifications.isEmpty()) {
tools = null;
} else {
tools = new ArrayList<>(toolSpecifications.size());
for (GenerateContentRequest.Tool tool : tools) {
// TODO: implement
}
}

return new GenerateContentRequest(contents,
!systemPrompts.isEmpty() ? GenerateContentRequest.SystemInstruction.ofContent(systemPrompts) : null, tools,
!systemPrompts.isEmpty() ? GenerateContentRequest.SystemInstruction.ofContent(systemPrompts) : null,
toTools(toolSpecifications),
generationConfig);
}

/**
 * Converts LangChain4j tool specifications into the tool entries of a Gemini request.
 *
 * @param toolSpecifications the specifications to convert; may be {@code null}
 * @return {@code null} when the input is {@code null}, an empty list when the input is empty,
 *         otherwise a list with one tool per specification, in iteration order
 */
static List<GenerateContentRequest.Tool> toTools(Collection<ToolSpecification> toolSpecifications) {
    if (toolSpecifications == null) {
        return null;
    }
    if (toolSpecifications.isEmpty()) {
        return Collections.emptyList();
    }
    // Presized to the number of specifications since each one yields exactly one tool.
    List<GenerateContentRequest.Tool> tools = new ArrayList<>(toolSpecifications.size());
    toolSpecifications.forEach(specification -> tools.add(toTool(specification)));
    return tools;
}

/**
 * Wraps a single tool specification into a request tool carrying its function declaration.
 *
 * @param toolSpecification the specification supplying the name, description and parameters
 * @return a tool holding the function declaration built from the specification
 */
private static GenerateContentRequest.Tool toTool(ToolSpecification toolSpecification) {
    // NOTE(review): every specification becomes its own Tool entry; confirm the remote API
    // accepts several tools rather than one tool listing several function declarations.
    FunctionDeclaration declaration = new FunctionDeclaration(
            toolSpecification.name(),
            toolSpecification.description(),
            toFunctionParameters(toolSpecification.parameters()));
    return new GenerateContentRequest.Tool(declaration);
}

/**
 * Translates a tool's JSON object schema into function declaration parameters.
 *
 * @param parameters the schema describing the tool arguments; may be {@code null}
 * @return empty object-typed parameters when no schema is given, otherwise object-typed
 *         parameters built from the schema's properties and required names
 */
private static FunctionDeclaration.Parameters toFunctionParameters(JsonObjectSchema parameters) {
    return parameters == null
            ? FunctionDeclaration.Parameters.empty()
            : FunctionDeclaration.Parameters.objectType(
                    JsonSchemaElementHelper.toMap(parameters.properties()),
                    parameters.required());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@

import java.util.Map;

public record FunctionCall(String name, Map<String, String> args) {
import com.fasterxml.jackson.core.JsonProcessingException;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;

/**
 * A function (tool) call consisting of the function name and its arguments.
 *
 * @param name the name of the function to invoke
 * @param arguments the call arguments keyed by parameter name; may be {@code null}
 */
public record FunctionCall(String name, Map<String, Object> arguments) {

    /**
     * Converts this call into a LangChain4j {@link ToolExecutionRequest}, encoding the
     * arguments as a JSON string.
     *
     * @return a request carrying this call's name and its JSON-encoded arguments
     * @throws IllegalStateException if the arguments cannot be serialized to JSON
     */
    public ToolExecutionRequest toToolExecutionRequest() {
        // A call without arguments is encoded as an empty JSON object instead of the literal
        // "null", so consumers that parse the arguments always receive an object.
        Map<String, Object> effectiveArguments = arguments() != null ? arguments() : Map.of();
        try {
            return ToolExecutionRequest.builder()
                    .name(name())
                    .arguments(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER
                            .writeValueAsString(effectiveArguments))
                    .build();
        } catch (JsonProcessingException e) {
            // This path serializes (it does not parse), so the message says so and names the call.
            throw new IllegalStateException("Unable to serialize the arguments of tool call '" + name() + "'", e);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
package io.quarkiverse.langchain4j.vertexai.runtime.gemini;

import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Declaration of a callable function: its name, a description and its parameters.
 *
 * @param name the function name
 * @param description a description of the function
 * @param parameters the parameters the function accepts
 */
public record FunctionDeclaration(String name, String description, Parameters parameters) {

    /**
     * Parameter schema of a function declaration.
     *
     * @param type the schema type; the factories in this record always use {@code "object"}
     * @param properties property definitions keyed by property name
     * @param required names of the required properties
     */
    public record Parameters(String type, Map<String, Map<String, Object>> properties, List<String> required) {

        private static final String TYPE_OBJECT = "object";

        /**
         * Creates parameters with no properties and no required names.
         *
         * @return object-typed parameters backed by empty collections
         */
        public static Parameters empty() {
            return objectType(Collections.emptyMap(), Collections.emptyList());
        }

        /**
         * Creates object-typed parameters from the given properties and required names.
         *
         * @param properties property definitions keyed by property name
         * @param required names of the required properties
         * @return parameters whose type is {@code "object"}
         */
        public static Parameters objectType(Map<String, Map<String, Object>> properties, List<String> required) {
            return new Parameters(TYPE_OBJECT, properties, required);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public record GenerateContentRequest(List<Content> contents, SystemInstruction s

public record Content(String role, List<Part> parts) {

public record Part(String text, FunctionCall functionCall) {
public record Part(String text, List<FunctionCall> functionCall) {

public static Part ofText(String text) {
return new Part(text, null);
Expand All @@ -26,7 +26,7 @@ public record Part(String text) {
}
}

public record Tool(List<FunctionDeclaration> functionDeclarations) {
public record Tool(FunctionDeclaration functionDeclarations) {

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public record Content(List<Part> parts) {

}

public record Part(String text, FunctionCall functionCall) {
public record Part(String text, List<FunctionCall> functionCall) {

}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.quarkiverse.langchain4j.vertexai.runtime.gemini;

import java.util.ArrayList;
import java.util.List;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.model.output.TokenUsage;

final class GenerateContentResponseHandler {
Expand Down Expand Up @@ -42,4 +44,17 @@ static TokenUsage getTokenUsage(GenerateContentResponse.UsageMetadata usageMetad
usageMetadata.candidatesTokenCount(),
usageMetadata.totalTokenCount());
}

/**
 * Collects the tool execution requests carried by the first candidate of a response.
 *
 * @param response the generate-content response to inspect
 * @return the requests converted from every function call found in the first candidate's parts,
 *         in encounter order; an empty list when the response has no candidates, the first
 *         candidate has no content or parts, or no part carries a function call
 */
static List<ToolExecutionRequest> getToolExecutionRequests(GenerateContentResponse response) {
    List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>();

    // A response without candidates, content or parts simply carries no tool calls;
    // report that as an empty result instead of failing on get(0) or a null dereference.
    var candidates = response.candidates();
    if (candidates == null || candidates.isEmpty()) {
        return toolExecutionRequests;
    }
    var content = candidates.get(0).content();
    if (content == null || content.parts() == null) {
        return toolExecutionRequests;
    }

    for (GenerateContentResponse.Candidate.Part part : content.parts()) {
        List<FunctionCall> functionCalls = part.functionCall();
        if (functionCalls == null || functionCalls.isEmpty()) {
            continue;
        }
        for (FunctionCall functionCall : functionCalls) {
            toolExecutionRequests.add(functionCall.toToolExecutionRequest());
        }
    }
    return toolExecutionRequests;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ static String map(ChatMessageType type) {
return switch (type) {
case USER -> "user";
case AI -> "model";
case TOOL_EXECUTION_RESULT -> null;
case TOOL_EXECUTION_RESULT -> "tool";
default -> throw new IllegalArgumentException(type + " is not allowed.");
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkiverse.langchain4j.vertexai.runtime.gemini;

import static dev.langchain4j.data.message.AiMessage.aiMessage;

import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
Expand All @@ -10,6 +12,8 @@

import org.jboss.resteasy.reactive.client.api.LoggingScope;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand Down Expand Up @@ -57,15 +61,47 @@ private VertexAiGeminiChatLanguageModel(Builder builder) {
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
GenerateContentRequest request = ContentMapper.map(messages, Collections.emptyList(), generationConfig);
public dev.langchain4j.model.chat.response.ChatResponse chat(dev.langchain4j.model.chat.request.ChatRequest chatRequest) {
    // Map the chat request (messages plus tool specifications) to the request payload and send it.
    GenerateContentResponse response = restApi.generateContent(
            ContentMapper.map(chatRequest.messages(), chatRequest.toolSpecifications(), generationConfig),
            apiMetadata);

    String text = GenerateContentResponseHandler.getText(response);
    List<ToolExecutionRequest> requestedTools = GenerateContentResponseHandler.getToolExecutionRequests(response);

    // Attach tool execution requests to the AI message only when the response contains any.
    AiMessage reply;
    if (requestedTools != null && !requestedTools.isEmpty()) {
        reply = aiMessage(text, requestedTools);
    } else {
        reply = aiMessage(text);
    }

    var tokenUsage = GenerateContentResponseHandler.getTokenUsage(response.usageMetadata());
    var finishReason = FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(response));
    return dev.langchain4j.model.chat.response.ChatResponse.builder()
            .aiMessage(reply)
            .tokenUsage(tokenUsage)
            .finishReason(finishReason)
            .build();
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
var chatResponse = chat(dev.langchain4j.model.chat.request.ChatRequest.builder()
.messages(messages)
.toolSpecifications(toolSpecifications)
.build());

return Response.from(
AiMessage.from(GenerateContentResponseHandler.getText(response)),
GenerateContentResponseHandler.getTokenUsage(response.usageMetadata()),
FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(response)));
chatResponse.aiMessage(),
chatResponse.tokenUsage(),
chatResponse.finishReason());

}

/**
 * Generates a response for the given messages without offering any tools.
 *
 * @param messages the conversation messages
 * @return the result of the tool-aware overload invoked with an empty tool list
 */
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
    List<ToolSpecification> noTools = Collections.emptyList();
    return generate(messages, noTools);
}

/**
 * Generates a response for the given messages, offering at most one tool.
 *
 * @param messages the conversation messages
 * @param toolSpecification the single tool to offer; {@code null} means no tools
 * @return the result of the list-based overload
 */
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
    List<ToolSpecification> toolSpecifications;
    if (toolSpecification == null) {
        toolSpecifications = Collections.emptyList();
    } else {
        toolSpecifications = Collections.singletonList(toolSpecification);
    }
    return generate(messages, toolSpecifications);
}

public static Builder builder() {
Expand Down

0 comments on commit c7edcc5

Please sign in to comment.