From 5987cb81f406163784d3e508c3fc8bdd622853d2 Mon Sep 17 00:00:00 2001 From: Sergey Beryozkin Date: Tue, 17 Dec 2024 17:09:28 +0000 Subject: [PATCH] Support Tools for Vertex AI Gemini --- integration-tests/pom.xml | 1 + integration-tests/vertex-ai-gemini/pom.xml | 126 ++++++++++++++++++ .../AssistantWithToolsResource.java | 45 +++++++ .../gemini/aiservices/DummyAuthProvider.java | 15 +++ .../gemini/aiservices/GeminiResource.java | 68 ++++++++++ .../src/main/resources/application.properties | 5 + .../AssistantResourceWithToolsTest.java | 31 +++++ .../runtime/gemini/ContentMapper.java | 78 +++++++++-- .../vertexai/runtime/gemini/FunctionCall.java | 17 ++- .../runtime/gemini/FunctionDeclaration.java | 10 ++ .../gemini/GenerateContentRequest.java | 4 +- .../gemini/GenerateContentResponse.java | 2 +- .../GenerateContentResponseHandler.java | 15 +++ .../vertexai/runtime/gemini/RoleMapper.java | 2 +- .../VertexAiGeminiChatLanguageModel.java | 46 ++++++- 15 files changed, 441 insertions(+), 24 deletions(-) create mode 100644 integration-tests/vertex-ai-gemini/pom.xml create mode 100644 integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/AssistantWithToolsResource.java create mode 100644 integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/DummyAuthProvider.java create mode 100644 integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java create mode 100644 integration-tests/vertex-ai-gemini/src/main/resources/application.properties create mode 100644 integration-tests/vertex-ai-gemini/src/test/java/org/acme/example/gemini/aiservices/AssistantResourceWithToolsTest.java diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 75d6836de..9a6741908 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -19,6 +19,7 @@ azure-openai multiple-providers mistralai + vertex-ai-gemini devui devui-multiple-embedding-models in-process-embedding-models diff --git a/integration-tests/vertex-ai-gemini/pom.xml b/integration-tests/vertex-ai-gemini/pom.xml new file mode 100644 index 000000000..ed33a6634 --- /dev/null +++ b/integration-tests/vertex-ai-gemini/pom.xml @@ -0,0 +1,126 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-test-vertex-ai-gemini + Quarkus LangChain4j - Integration Tests - Vertex AI Gemini + + true + + + + io.quarkus + quarkus-rest-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-vertex-ai-gemini + ${project.version} + + + io.quarkus + quarkus-micrometer + + + io.quarkus + quarkus-smallrye-fault-tolerance + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-vertex-ai-gemini-deployment + ${project.version} + pom + test + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + maven-failsafe-plugin + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + + native-image + + + native + + + + + + maven-surefire-plugin + + ${native.surefire.skip} + + + + + + false + native + + + + diff --git a/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/AssistantWithToolsResource.java b/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/AssistantWithToolsResource.java new file mode 100644 index 000000000..e74aa8f7a --- /dev/null +++ b/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/AssistantWithToolsResource.java @@ -0,0 +1,45 @@ +package org.acme.example.gemini.aiservices; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import org.jboss.resteasy.reactive.RestQuery; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkiverse.langchain4j.RegisterAiService; + +@Path("assistant-with-tool") +public class AssistantWithToolsResource { + + private final Assistant assistant; + + @Inject + AddContentTool tool; + + public AssistantWithToolsResource(Assistant assistant) { + this.assistant = assistant; + } + + @GET + public String get(@RestQuery String message) { + return assistant.chat(message) + "; " + tool.tool1Content; + } + + @RegisterAiService(tools = AddContentTool.class) + public interface Assistant { + String chat(String userMessage); + } + + @Singleton + public static class AddContentTool { + + volatile String tool1Content; + + @Tool("Add content") + void addContent(String content) { + this.tool1Content = "Tool1: " + content; + } + } +} diff --git a/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/DummyAuthProvider.java b/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/DummyAuthProvider.java new file mode 100644 index 000000000..d5824a2a5 --- /dev/null +++ b/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/DummyAuthProvider.java @@ -0,0 +1,15 @@ +package org.acme.example.gemini.aiservices; + +import jakarta.inject.Singleton; + +import io.quarkiverse.langchain4j.auth.ModelAuthProvider; + +@Singleton +public class DummyAuthProvider implements ModelAuthProvider { + + @Override + public String getAuthorization(Input input) { + return "Bearer token"; + } + +} diff --git a/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java b/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java new file mode 100644 index 000000000..63790122e --- /dev/null +++ b/integration-tests/vertex-ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java @@ -0,0 +1,68 @@ +package org.acme.example.gemini.aiservices; + +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; + +@Path("gemini/") +public class GeminiResource { + + @POST + @Path("v1/projects/my_google_project_id/locations/west-europe/publishers/google/models/gemini-pro:generateContent") + @Produces("application/json") + public String get() { + return """ + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Nice to meet you" + } + ] + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.044847902, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.05592617 + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.18877223, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.027324531 + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.15278918, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.045437217 + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.15869519, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.036838707 + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 11, + "candidatesTokenCount": 37, + "totalTokenCount": 48 + } + } + """; + } + +} diff --git a/integration-tests/vertex-ai-gemini/src/main/resources/application.properties b/integration-tests/vertex-ai-gemini/src/main/resources/application.properties new file mode 100644 index 000000000..ecd95455c --- /dev/null +++ b/integration-tests/vertex-ai-gemini/src/main/resources/application.properties @@ -0,0 +1,5 @@ +quarkus.langchain4j.vertexai.gemini.base-url=http://localhost:8081/gemini +quarkus.langchain4j.vertexai.gemini.location=west-europe +quarkus.langchain4j.vertexai.gemini.project-id=my_google_project_id +quarkus.langchain4j.vertexai.gemini.log-requests=true +quarkus.langchain4j.vertexai.gemini.log-responses=true diff --git a/integration-tests/vertex-ai-gemini/src/test/java/org/acme/example/gemini/aiservices/AssistantResourceWithToolsTest.java b/integration-tests/vertex-ai-gemini/src/test/java/org/acme/example/gemini/aiservices/AssistantResourceWithToolsTest.java new file mode 100644 index 000000000..4e8cd8508 --- /dev/null +++ b/integration-tests/vertex-ai-gemini/src/test/java/org/acme/example/gemini/aiservices/AssistantResourceWithToolsTest.java @@ -0,0 +1,31 @@ +package org.acme.example.gemini.aiservices; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.Matchers.equalTo; + +import java.net.URL; + +import org.junit.jupiter.api.Test; + +import io.quarkus.test.common.http.TestHTTPEndpoint; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class AssistantResourceWithToolsTest { + + @TestHTTPEndpoint(AssistantWithToolsResource.class) + @TestHTTPResource + URL url; + + @Test + public void get() { + given() + .baseUri(url.toString()) + .queryParam("message", "This is a test") + .get() + .then() + .statusCode(200) + .body(equalTo("Nice to meet you; tool1: Nice to meet you")); + } +} diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java index 7890a4b7c..70a1209fd 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java @@ -1,15 +1,25 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.Content; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.TextContent; +import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; final class ContentMapper { @@ -37,11 +47,31 @@ static GenerateContentRequest map(List messages, List toolExecutionRequests = am.toolExecutionRequests(); + List toolCalls = new ArrayList<>(toolExecutionRequests.size()); + for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) { + String argumentsStr = toolExecutionRequest.arguments(); + String name = toolExecutionRequest.name(); + // TODO: we need to update LangChain4j to make ToolExecutionRequest use a map instead of a String + Map arguments = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readValue( + argumentsStr, + Map.class); + toolCalls.add(new FunctionCall(name, arguments)); + } + contents.add(new GenerateContentRequest.Content(role, + List.of(new GenerateContentRequest.Content.Part(am.text(), toolCalls)))); + } else { + contents.add(new GenerateContentRequest.Content(role, + List.of(GenerateContentRequest.Content.Part.ofText(am.text())))); + } + } catch (JsonProcessingException e) { + throw new IllegalStateException("Unable to perform conversion of tool response", e); } + } else if (message instanceof ToolExecutionResultMessage) { contents.add(new GenerateContentRequest.Content(role, - List.of(GenerateContentRequest.Content.Part.ofText(am.text())))); + List.of(GenerateContentRequest.Content.Part.ofText(message.text())))); } else { throw new IllegalArgumentException( "The Gemini integration currently does not support " + message.type() + " messages"); @@ -49,18 +79,38 @@ static GenerateContentRequest map(List messages, List tools; - if (toolSpecifications == null || toolSpecifications.isEmpty()) { - tools = null; - } else { - tools = new ArrayList<>(toolSpecifications.size()); - for (GenerateContentRequest.Tool tool : tools) { - // TODO: implement - } - } - return new GenerateContentRequest(contents, - !systemPrompts.isEmpty() ? GenerateContentRequest.SystemInstruction.ofContent(systemPrompts) : null, tools, + !systemPrompts.isEmpty() ? GenerateContentRequest.SystemInstruction.ofContent(systemPrompts) : null, + toTools(toolSpecifications), generationConfig); } + + static List toTools(Collection toolSpecifications) { + if (toolSpecifications == null) { + return null; + } + if (toolSpecifications.isEmpty()) { + return Collections.emptyList(); + } + List result = new ArrayList<>(toolSpecifications.size()); + for (ToolSpecification toolSpecification : toolSpecifications) { + result.add(toTool(toolSpecification)); + } + return result; + } + + private static GenerateContentRequest.Tool toTool(ToolSpecification toolSpecification) { + FunctionDeclaration.Parameters functionParameters = toFunctionParameters(toolSpecification.parameters()); + + return new GenerateContentRequest.Tool( + new FunctionDeclaration(toolSpecification.name(), toolSpecification.description(), functionParameters)); + } + + private static FunctionDeclaration.Parameters toFunctionParameters(JsonObjectSchema parameters) { + if (parameters == null) { + return FunctionDeclaration.Parameters.empty(); + } + return FunctionDeclaration.Parameters.objectType(JsonSchemaElementHelper.toMap(parameters.properties()), + parameters.required()); + } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java index a22f3fc90..aa047b216 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java @@ -2,6 +2,21 @@ import java.util.Map; -public record FunctionCall(String name, Map args) { +import com.fasterxml.jackson.core.JsonProcessingException; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; + +public record FunctionCall(String name, Map arguments) { + public ToolExecutionRequest toToolExecutionRequest() { + try { + return ToolExecutionRequest.builder() + .name(name()) + .arguments(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER + .writeValueAsString(arguments())) + .build(); + } catch (JsonProcessingException e) { + throw new RuntimeException("Unable to parse tool call response", e); + } + } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java index 2c96daa83..be478f493 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java @@ -1,5 +1,6 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -7,5 +8,14 @@ public record FunctionDeclaration(String name, String description, Parameters pa public record Parameters(String type, Map> properties, List required) { + private static final String OBJECT_TYPE = "object"; + + public static Parameters objectType(Map> properties, List required) { + return new Parameters(OBJECT_TYPE, properties, required); + } + + public static Parameters empty() { + return new Parameters(OBJECT_TYPE, Collections.emptyMap(), Collections.emptyList()); + } } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java index 7809c666a..f30f0cde7 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java @@ -7,7 +7,7 @@ public record GenerateContentRequest(List contents, SystemInstruction s public record Content(String role, List parts) { - public record Part(String text, FunctionCall functionCall) { + public record Part(String text, List functionCall) { public static Part ofText(String text) { return new Part(text, null); @@ -26,7 +26,7 @@ public record Part(String text) { } } - public record Tool(List functionDeclarations) { + public record Tool(FunctionDeclaration functionDeclarations) { } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java index 461554696..f9ae38044 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java @@ -10,7 +10,7 @@ public record Content(List parts) { } - public record Part(String text, FunctionCall functionCall) { + public record Part(String text, List functionCall) { } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java index 32b5fc503..b0cbf2456 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java @@ -1,7 +1,9 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; +import java.util.ArrayList; import java.util.List; +import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.model.output.TokenUsage; final class GenerateContentResponseHandler { @@ -42,4 +44,17 @@ static TokenUsage getTokenUsage(GenerateContentResponse.UsageMetadata usageMetad usageMetadata.candidatesTokenCount(), usageMetadata.totalTokenCount()); } + + static List getToolExecutionRequests(GenerateContentResponse response) { + List parts = response.candidates().get(0).content().parts(); + List toolExecutionRequests = new ArrayList<>(); + for (GenerateContentResponse.Candidate.Part part : parts) { + List functionCalls = part.functionCall(); + if (functionCalls == null || functionCalls.isEmpty()) { + continue; + } + toolExecutionRequests.addAll(functionCalls.stream().map(FunctionCall::toToolExecutionRequest).toList()); + } + return toolExecutionRequests; + } } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java index cc623d296..3a82d006c 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java @@ -11,7 +11,7 @@ static String map(ChatMessageType type) { return switch (type) { case USER -> "user"; case AI -> "model"; - case TOOL_EXECUTION_RESULT -> null; + case TOOL_EXECUTION_RESULT -> "tool"; default -> throw new IllegalArgumentException(type + " is not allowed."); }; } diff --git a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java index 7ad42cbf4..82ceb2dfe 100644 --- a/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java +++ b/model-providers/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.vertexai.runtime.gemini; +import static dev.langchain4j.data.message.AiMessage.aiMessage; + import java.net.URI; import java.net.URISyntaxException; import java.time.Duration; @@ -10,6 +12,8 @@ import org.jboss.resteasy.reactive.client.api.LoggingScope; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -57,15 +61,47 @@ private VertexAiGeminiChatLanguageModel(Builder builder) { } @Override - public Response generate(List messages) { - GenerateContentRequest request = ContentMapper.map(messages, Collections.emptyList(), generationConfig); + public dev.langchain4j.model.chat.response.ChatResponse chat(dev.langchain4j.model.chat.request.ChatRequest chatRequest) { + GenerateContentRequest request = ContentMapper.map(chatRequest.messages(), chatRequest.toolSpecifications(), + generationConfig); GenerateContentResponse response = restApi.generateContent(request, apiMetadata); + String text = GenerateContentResponseHandler.getText(response); + List toolExecutionRequests = GenerateContentResponseHandler.getToolExecutionRequests(response); + AiMessage aiMessage = toolExecutionRequests == null || toolExecutionRequests.isEmpty() + ? aiMessage(text) + : aiMessage(text, toolExecutionRequests); + return dev.langchain4j.model.chat.response.ChatResponse.builder() + .aiMessage(aiMessage) + .tokenUsage(GenerateContentResponseHandler.getTokenUsage(response.usageMetadata())) + .finishReason(FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(response))) + .build(); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + var chatResponse = chat(dev.langchain4j.model.chat.request.ChatRequest.builder() + .messages(messages) + .toolSpecifications(toolSpecifications) + .build()); + return Response.from( - AiMessage.from(GenerateContentResponseHandler.getText(response)), - GenerateContentResponseHandler.getTokenUsage(response.usageMetadata()), - FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(response))); + chatResponse.aiMessage(), + chatResponse.tokenUsage(), + chatResponse.finishReason()); + + } + + @Override + public Response generate(List messages) { + return generate(messages, Collections.emptyList()); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, + toolSpecification != null ? Collections.singletonList(toolSpecification) : Collections.emptyList()); } public static Builder builder() {