Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output guardrails should support structured output #1201

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
Expand Down Expand Up @@ -337,32 +337,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,

log.debug("Attempting to obtain AI response");

Optional<JsonSchema> jsonSchema = Optional.empty();
if (supportsJsonSchema) {
jsonSchema = methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema();
}

Response<AiMessage> response;
if (jsonSchema.isPresent()) {
ChatRequest chatRequest = ChatRequest.builder()
.messages(messagesToSend)
.toolSpecifications(toolSpecifications)
.responseFormat(ResponseFormat.builder()
.type(JSON)
.jsonSchema(jsonSchema.get())
.build())
.build();

ChatResponse chatResponse = context.chatModel.chat(chatRequest);
response = new Response<>(
chatResponse.aiMessage(),
chatResponse.tokenUsage(),
chatResponse.finishReason());
} else {
response = toolSpecifications == null
? context.chatModel.generate(messagesToSend)
: context.chatModel.generate(messagesToSend, toolSpecifications);
}
var response = executeRequest(context, methodCreateInfo, messagesToSend, toolSpecifications);

log.debug("AI response obtained");
if (audit != null) {
Expand Down Expand Up @@ -450,6 +425,46 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
methodCreateInfo, responseAugmenterParam);
}

/**
 * Builds a chat request whose response format is of type {@code JSON} and carries the given
 * schema, sends it through {@code chatModel.chat(...)}, and repackages the outcome.
 *
 * @param jsonSchema the schema set on the request's response format
 * @param messagesToSend the messages placed on the request
 * @param chatModel the model the request is sent to
 * @param toolSpecifications the tool specifications placed on the request, passed through unchanged
 * @return a {@code Response} holding the AI message, token usage and finish reason of the chat response
 */
private static Response<AiMessage> executeRequest(JsonSchema jsonSchema, List<ChatMessage> messagesToSend,
        ChatLanguageModel chatModel, List<ToolSpecification> toolSpecifications) {
    // Describe the expected output shape first, then attach it to the request.
    ResponseFormat structuredFormat = ResponseFormat.builder()
            .type(JSON)
            .jsonSchema(jsonSchema)
            .build();

    ChatRequest request = ChatRequest.builder()
            .messages(messagesToSend)
            .toolSpecifications(toolSpecifications)
            .responseFormat(structuredFormat)
            .build();

    var chatResponse = chatModel.chat(request);

    // Callers work with Response<AiMessage>, so copy the three relevant parts over.
    return new Response<>(chatResponse.aiMessage(), chatResponse.tokenUsage(), chatResponse.finishReason());
}

/**
 * Invokes {@code chatModel.generate(...)} without any response format, picking the overload
 * according to whether tool specifications were supplied.
 *
 * @param messagesToSend the messages handed to the model
 * @param chatModel the model to invoke
 * @param toolSpecifications the tool specifications to forward; when {@code null} the
 *        messages-only overload is used
 * @return whatever the selected {@code generate} overload returns
 */
private static Response<AiMessage> executeRequest(List<ChatMessage> messagesToSend, ChatLanguageModel chatModel,
        List<ToolSpecification> toolSpecifications) {
    if (toolSpecifications != null) {
        return chatModel.generate(messagesToSend, toolSpecifications);
    }
    return chatModel.generate(messagesToSend);
}

/**
 * Executes a request against the given model, using the JSON-schema variant only when the model
 * reports support for it and the method's response schema info provides a structured output schema.
 * In every other case the plain variant is used.
 *
 * @param methodCreateInfo source of the response schema info for the AI service method
 * @param messagesToSend the messages to send
 * @param chatModel the model to invoke
 * @param toolSpecifications the tool specifications forwarded to the chosen variant
 * @return the response produced by the chosen variant
 */
static Response<AiMessage> executeRequest(AiServiceMethodCreateInfo methodCreateInfo, List<ChatMessage> messagesToSend,
        ChatLanguageModel chatModel, List<ToolSpecification> toolSpecifications) {
    // The schema is only looked up when the model can actually honour it.
    if (supportsJsonSchema(chatModel)) {
        Optional<JsonSchema> structuredSchema = methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema();
        if (structuredSchema.isPresent()) {
            return executeRequest(structuredSchema.get(), messagesToSend, chatModel, toolSpecifications);
        }
    }

    return executeRequest(messagesToSend, chatModel, toolSpecifications);
}

/**
 * Convenience overload that takes the chat model from {@code context.chatModel} and delegates to
 * the overload accepting the model directly.
 *
 * @param context the AI service context whose {@code chatModel} is used
 * @param methodCreateInfo metadata of the AI service method, forwarded unchanged
 * @param messagesToSend the messages to send, forwarded unchanged
 * @param toolSpecifications the tool specifications, forwarded unchanged
 * @return the response returned by the delegate
 */
static Response<AiMessage> executeRequest(QuarkusAiServiceContext context, AiServiceMethodCreateInfo methodCreateInfo,
        List<ChatMessage> messagesToSend, List<ToolSpecification> toolSpecifications) {
    ChatLanguageModel contextModel = context.chatModel;
    return executeRequest(methodCreateInfo, messagesToSend, contextModel, toolSpecifications);
}

private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context,
Audit audit, Optional<SystemMessage> systemMessage, UserMessage userMessage,
Object memoryId, Type returnType, Map<String, Object> templateVariables) {
Expand Down Expand Up @@ -547,9 +562,12 @@ private static List<ChatMessage> createMessagesToSendForNoMemory(Optional<System
return result;
}

/**
 * Tells whether the given model lists {@code RESPONSE_FORMAT_JSON_SCHEMA} among its supported
 * capabilities.
 *
 * @param chatModel the model to inspect; may be {@code null}
 * @return {@code false} for a {@code null} model, otherwise whether the capability is present
 */
private static boolean supportsJsonSchema(ChatLanguageModel chatModel) {
    if (chatModel == null) {
        return false;
    }
    return chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
}

/**
 * Tells whether the chat model held by the given context supports JSON-schema response formats.
 * The null check and capability lookup live in {@code supportsJsonSchema(ChatLanguageModel)};
 * this overload only extracts the model from the context, so the logic is not duplicated.
 *
 * @param context the AI service context whose {@code chatModel} is inspected
 * @return the result of the model-based check for {@code context.chatModel}
 */
private static boolean supportsJsonSchema(AiServiceContext context) {
    return supportsJsonSchema(context.chatModel);
}

private static Future<Moderation> triggerModerationIfNeeded(AiServiceContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,18 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
if (!result.isSuccess()) {
if (!result.isRetry()) {
throw new GuardrailException(result.toString(), result.getFirstFailureException());
} else if (result.getReprompt() != null) {
// Retry with re-prompting
chatMemory.add(userMessage(result.getReprompt()));
if (toolSpecifications == null) {
response = chatModel.generate(chatMemory.messages());
} else {
response = chatModel.generate(chatMemory.messages(), toolSpecifications);
}
chatMemory.add(response.content());
} else {
// Retry without re-prompting
if (toolSpecifications == null) {
response = chatModel.generate(chatMemory.messages());
} else {
response = chatModel.generate(chatMemory.messages(), toolSpecifications);
// Retry
if (result.getReprompt() != null) {
// Retry with reprompting
chatMemory.add(userMessage(result.getReprompt()));
}

response = AiServiceMethodImplementationSupport.executeRequest(methodCreateInfo, chatMemory.messages(),
chatModel, toolSpecifications);
chatMemory.add(response.content());
}

attempt++;
output = new OutputGuardrailParams(response.content(), output.memory(),
output.augmentationResult(), output.userMessageTemplate(), output.variables());
Expand Down
Loading