diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java index 7890a4b7c..70a1209fd 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java @@ -1,15 +1,25 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.Content; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.TextContent; +import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; final class ContentMapper { @@ -37,11 +47,31 @@ static GenerateContentRequest map(List messages, List toolExecutionRequests = am.toolExecutionRequests(); + List toolCalls = new ArrayList<>(toolExecutionRequests.size()); + for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) { + String argumentsStr = toolExecutionRequest.arguments(); + String name = toolExecutionRequest.name(); + // TODO: we need to update LangChain4j to make ToolExecutionRequest use a map instead of a String + Map arguments = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readValue( + argumentsStr, + Map.class); + toolCalls.add(new FunctionCall(name, arguments)); + } + contents.add(new GenerateContentRequest.Content(role, + List.of(new GenerateContentRequest.Content.Part(am.text(), toolCalls)))); + } else { + contents.add(new GenerateContentRequest.Content(role, + List.of(GenerateContentRequest.Content.Part.ofText(am.text())))); + } + } catch (JsonProcessingException e) { + throw new IllegalStateException("Unable to perform conversion of tool response", e); } + } else if (message instanceof ToolExecutionResultMessage) { contents.add(new GenerateContentRequest.Content(role, - List.of(GenerateContentRequest.Content.Part.ofText(am.text())))); + List.of(GenerateContentRequest.Content.Part.ofText(message.text())))); } else { throw new IllegalArgumentException( "The Gemini integration currently does not support " + message.type() + " messages"); @@ -49,18 +79,38 @@ static GenerateContentRequest map(List messages, List tools; - if (toolSpecifications == null || toolSpecifications.isEmpty()) { - tools = null; - } else { - tools = new ArrayList<>(toolSpecifications.size()); - for (GenerateContentRequest.Tool tool : tools) { - // TODO: implement - } - } - return new GenerateContentRequest(contents, - !systemPrompts.isEmpty() ? GenerateContentRequest.SystemInstruction.ofContent(systemPrompts) : null, tools, + !systemPrompts.isEmpty() ? GenerateContentRequest.SystemInstruction.ofContent(systemPrompts) : null, + toTools(toolSpecifications), generationConfig); } + + static List toTools(Collection toolSpecifications) { + if (toolSpecifications == null) { + return null; + } + if (toolSpecifications.isEmpty()) { + return Collections.emptyList(); + } + List result = new ArrayList<>(toolSpecifications.size()); + for (ToolSpecification toolSpecification : toolSpecifications) { + result.add(toTool(toolSpecification)); + } + return result; + } + + private static GenerateContentRequest.Tool toTool(ToolSpecification toolSpecification) { + FunctionDeclaration.Parameters functionParameters = toFunctionParameters(toolSpecification.parameters()); + + return new GenerateContentRequest.Tool( + new FunctionDeclaration(toolSpecification.name(), toolSpecification.description(), functionParameters)); + } + + private static FunctionDeclaration.Parameters toFunctionParameters(JsonObjectSchema parameters) { + if (parameters == null) { + return FunctionDeclaration.Parameters.empty(); + } + return FunctionDeclaration.Parameters.objectType(JsonSchemaElementHelper.toMap(parameters.properties()), + parameters.required()); + } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java index a22f3fc90..aa047b216 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java @@ -2,6 +2,21 @@ import java.util.Map; -public record FunctionCall(String name, Map args) { +import com.fasterxml.jackson.core.JsonProcessingException; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; + +public record FunctionCall(String name, Map arguments) { + public ToolExecutionRequest toToolExecutionRequest() { + try { + return ToolExecutionRequest.builder() + .name(name()) + .arguments(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER + .writeValueAsString(arguments())) + .build(); + } catch (JsonProcessingException e) { + throw new RuntimeException("Unable to parse tool call response", e); + } + } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java index 2c96daa83..be478f493 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java @@ -1,5 +1,6 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -7,5 +8,14 @@ public record FunctionDeclaration(String name, String description, Parameters pa public record Parameters(String type, Map> properties, List required) { + private static final String OBJECT_TYPE = "object"; + + public static Parameters objectType(Map> properties, List required) { + return new Parameters(OBJECT_TYPE, properties, required); + } + + public static Parameters empty() { + return new Parameters(OBJECT_TYPE, Collections.emptyMap(), Collections.emptyList()); + } } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java index 7809c666a..f30f0cde7 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java @@ -7,7 +7,7 @@ public record GenerateContentRequest(List contents, SystemInstruction s public record Content(String role, List parts) { - public record Part(String text, FunctionCall functionCall) { + public record Part(String text, List functionCall) { public static Part ofText(String text) { return new Part(text, null); @@ -26,7 +26,7 @@ public record Part(String text) { } } - public record Tool(List functionDeclarations) { + public record Tool(FunctionDeclaration functionDeclarations) { } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java index 461554696..f9ae38044 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java @@ -10,7 +10,7 @@ public record Content(List parts) { } - public record Part(String text, FunctionCall functionCall) { + public record Part(String text, List functionCall) { } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java index 32b5fc503..b0cbf2456 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java @@ -1,7 +1,9 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; +import java.util.ArrayList; import java.util.List; +import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.model.output.TokenUsage; final class GenerateContentResponseHandler { @@ -42,4 +44,17 @@ static TokenUsage getTokenUsage(GenerateContentResponse.UsageMetadata usageMetad usageMetadata.candidatesTokenCount(), usageMetadata.totalTokenCount()); } + + static List getToolExecutionRequests(GenerateContentResponse response) { + List parts = response.candidates().get(0).content().parts(); + List toolExecutionRequests = new ArrayList<>(); + for (GenerateContentResponse.Candidate.Part part : parts) { + List functionCalls = part.functionCall(); + if (functionCalls == null || functionCalls.isEmpty()) { + continue; + } + toolExecutionRequests.addAll(functionCalls.stream().map(FunctionCall::toToolExecutionRequest).toList()); + } + return toolExecutionRequests; + } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java index cc623d296..3a82d006c 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java @@ -11,7 +11,7 @@ static String map(ChatMessageType type) { return switch (type) { case USER -> "user"; case AI -> "model"; - case TOOL_EXECUTION_RESULT -> null; + case TOOL_EXECUTION_RESULT -> "tool"; default -> throw new IllegalArgumentException(type + " is not allowed."); }; } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java index 7ad42cbf4..82ceb2dfe 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; +import static dev.langchain4j.data.message.AiMessage.aiMessage; + import java.net.URI; import java.net.URISyntaxException; import java.time.Duration; @@ -10,6 +12,8 @@ import org.jboss.resteasy.reactive.client.api.LoggingScope; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -57,15 +61,47 @@ private VertexAiGeminiChatLanguageModel(Builder builder) { } @Override - public Response generate(List messages) { - GenerateContentRequest request = ContentMapper.map(messages, Collections.emptyList(), generationConfig); + public dev.langchain4j.model.chat.response.ChatResponse chat(dev.langchain4j.model.chat.request.ChatRequest chatRequest) { + GenerateContentRequest request = ContentMapper.map(chatRequest.messages(), chatRequest.toolSpecifications(), + generationConfig); GenerateContentResponse response = restApi.generateContent(request, apiMetadata); + String text = GenerateContentResponseHandler.getText(response); + List toolExecutionRequests = GenerateContentResponseHandler.getToolExecutionRequests(response); + AiMessage aiMessage = toolExecutionRequests == null || toolExecutionRequests.isEmpty() + ? aiMessage(text) + : aiMessage(text, toolExecutionRequests); + return dev.langchain4j.model.chat.response.ChatResponse.builder() + .aiMessage(aiMessage) + .tokenUsage(GenerateContentResponseHandler.getTokenUsage(response.usageMetadata())) + .finishReason(FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(response))) + .build(); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + var chatResponse = chat(dev.langchain4j.model.chat.request.ChatRequest.builder() + .messages(messages) + .toolSpecifications(toolSpecifications) + .build()); + return Response.from( - AiMessage.from(GenerateContentResponseHandler.getText(response)), - GenerateContentResponseHandler.getTokenUsage(response.usageMetadata()), - FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(response))); + chatResponse.aiMessage(), + chatResponse.tokenUsage(), + chatResponse.finishReason()); + + } + + @Override + public Response generate(List messages) { + return generate(messages, Collections.emptyList()); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, + toolSpecification != null ? Collections.singletonList(toolSpecification) : Collections.emptyList()); } public static Builder builder() {