diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java index 872525073..ae836050b 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ListenersProcessor.java @@ -2,6 +2,12 @@ import java.util.Optional; +import jakarta.inject.Singleton; + +import org.jboss.jandex.DotName; + +import io.quarkiverse.langchain4j.cost.CostEstimatorResponseListener; +import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig; import io.quarkiverse.langchain4j.runtime.listeners.MetricsChatModelListener; import io.quarkiverse.langchain4j.runtime.listeners.SpanChatModelListener; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; @@ -14,6 +20,20 @@ public class ListenersProcessor { + @BuildStep + public void costListener( + LangChain4jBuildConfig config, + BuildProducer additionalBeanProducer) { + if (config.costListener()) { + additionalBeanProducer.produce( + AdditionalBeanBuildItem.builder() + .addBeanClass(CostEstimatorResponseListener.class) + .setDefaultScope(DotName.createSimple(Singleton.class)) + .setUnremovable() + .build()); + } + } + @BuildStep public void spanListeners(Capabilities capabilities, Optional metricsCapability, diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java index 1eb4448d1..b0040fec5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java @@ -43,6 +43,12 @@ public interface LangChain4jBuildConfig { @WithDefault("true") boolean responseSchema(); + /** + * Configuration property to enable or disable generic cost listener + */ + @WithDefault("false") + boolean costListener(); + interface BaseConfig { /** * Chat model diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java new file mode 100644 index 000000000..2fb73fb13 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java @@ -0,0 +1,55 @@ +package io.quarkiverse.langchain4j.cost; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import jakarta.inject.Inject; + +import dev.langchain4j.model.output.TokenUsage; +import io.quarkiverse.langchain4j.response.ResponseListener; +import io.quarkiverse.langchain4j.response.ResponseRecord; +import io.quarkus.arc.All; +import io.smallrye.common.annotation.Experimental; + +/** + * Allows for user code to provide a custom strategy for estimating the cost of API calls + */ +@Experimental("This feature is experimental and the API is subject to change") +public class CostEstimatorResponseListener implements ResponseListener { + + private final CostEstimatorService service; + private final List listeners; + + @Inject + public CostEstimatorResponseListener(CostEstimatorService service, @All List listeners) { + this.service = service; + this.listeners = new ArrayList<>(listeners); + this.listeners.sort(Comparator.comparingInt(CostListener::order)); + } + + @Override + public void onResponse(ResponseRecord rr) { + String model = rr.model(); + TokenUsage tokenUsage = rr.tokenUsage(); + CostEstimator.CostContext context = new MyCostContext(tokenUsage, model); + Cost cost = service.estimate(context); + if (cost != null) { + for (CostListener cl : listeners) { + cl.handleCost(model, tokenUsage, cost); + } + } + } + + private record MyCostContext(TokenUsage tokenUsage, String model) implements CostEstimator.CostContext { + @Override + public Integer inputTokens() { + return tokenUsage().inputTokenCount(); + } + + @Override + public Integer outputTokens() { + return tokenUsage().outputTokenCount(); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java index 7c499fdd6..cb36a5ed3 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorService.java @@ -28,10 +28,13 @@ public CostEstimatorService(@All List costEstimators) { public Cost estimate(ChatModelResponseContext response) { TokenUsage tokenUsage = response.response().tokenUsage(); CostEstimator.CostContext costContext = new MyCostContext(tokenUsage, response); + return estimate(costContext); + } + public Cost estimate(CostEstimator.CostContext context) { for (CostEstimator costEstimator : costEstimators) { - if (costEstimator.supports(costContext)) { - CostEstimator.CostResult costResult = costEstimator.estimate(costContext); + if (costEstimator.supports(context)) { + CostEstimator.CostResult costResult = costEstimator.estimate(context); if (costResult != null) { BigDecimal totalCost = costResult.inputTokensCost().add(costResult.outputTokensCost()); return new Cost(totalCost, costResult.currency()); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java new file mode 100644 index 000000000..bd21c6b90 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.cost; + +import dev.langchain4j.model.output.TokenUsage; + +/** + * Allows for user code to handle estimate cost; e.g. some simple accounting + */ +public interface CostListener { + void handleCost(String model, TokenUsage tokenUsage, Cost cost); + + default int order() { + return 0; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java new file mode 100644 index 000000000..d8cfcbda2 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java @@ -0,0 +1,47 @@ +package io.quarkiverse.langchain4j.response; + +import java.util.Map; + +import jakarta.annotation.Priority; +import jakarta.interceptor.AroundInvoke; +import jakarta.interceptor.Interceptor; +import jakarta.interceptor.InvocationContext; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.listener.ChatModelResponse; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.output.Response; + +/** + * Simple (Chat)Response interceptor, to be applied directly on the model. + */ +@Interceptor +@ResponseInterceptorBinding +@Priority(0) +public class ResponseInterceptor extends ResponseInterceptorBase { + + @AroundInvoke + public Object intercept(InvocationContext context) throws Exception { + Object result = context.proceed(); + ResponseRecord rr = null; + if (result instanceof Response response) { + Object content = response.content(); + if (content instanceof AiMessage am) { + rr = new ResponseRecord(getModel(context.getTarget()), am, response.tokenUsage(), response.finishReason(), + response.metadata()); + } + } else if (result instanceof ChatResponse response) { + rr = new ResponseRecord(getModel(context.getTarget()), response.aiMessage(), response.tokenUsage(), + response.finishReason(), Map.of()); + } else if (result instanceof ChatModelResponse response) { + rr = new ResponseRecord(response.model(), response.aiMessage(), response.tokenUsage(), response.finishReason(), + Map.of("id", response.id())); + } + if (rr != null) { + for (ResponseListener l : getListeners()) { + l.onResponse(rr); + } + } + return result; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java new file mode 100644 index 000000000..b3ace1fb7 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java @@ -0,0 +1,41 @@ +package io.quarkiverse.langchain4j.response; + +import java.lang.reflect.Method; +import java.util.Comparator; +import java.util.List; + +import jakarta.enterprise.inject.Any; +import jakarta.enterprise.inject.spi.CDI; + +/** + * Simple (Chat)Response interceptor base, to be applied directly on the model. + */ +public abstract class ResponseInterceptorBase { + + private volatile String model; + private volatile List listeners; + + // TODO -- uh uh ... reflection ... puke + protected String getModel(Object target) { + if (model == null) { + try { + Class clazz = target.getClass(); + Method method = clazz.getMethod("modelName"); + model = (String) method.invoke(target); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return model; + } + + protected List getListeners() { + if (listeners == null) { + listeners = CDI.current().select(ResponseListener.class, Any.Literal.INSTANCE) + .stream() + .sorted(Comparator.comparing(ResponseListener::order)) + .toList(); + } + return listeners; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java new file mode 100644 index 000000000..986c1d2db --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.response; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import jakarta.interceptor.InterceptorBinding; + +@InterceptorBinding +@Target({ ElementType.TYPE, ElementType.METHOD }) +@Retention(RetentionPolicy.RUNTIME) +public @interface ResponseInterceptorBinding { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java new file mode 100644 index 000000000..05fac2f3d --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java @@ -0,0 +1,5 @@ +package io.quarkiverse.langchain4j.response; + +@ResponseInterceptorBinding +public abstract class ResponseInterceptorBindingSource { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java new file mode 100644 index 000000000..ab7966f68 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java @@ -0,0 +1,12 @@ +package io.quarkiverse.langchain4j.response; + +/** + * Simple ResponseRecord listener, to be implemented by the (advanced) users. + */ +public interface ResponseListener { + void onResponse(ResponseRecord response); + + default int order() { + return 0; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java new file mode 100644 index 000000000..81c717174 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java @@ -0,0 +1,18 @@ +package io.quarkiverse.langchain4j.response; + +import java.util.Map; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.TokenUsage; + +/** + * Abstract away Response vs ChatResponse. + */ +public record ResponseRecord( + String model, + AiMessage content, + TokenUsage tokenUsage, + FinishReason finishReason, + Map metadata) { +} diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc index cd612e282..332fa2c73 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core.adoc @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[] |boolean |`true` +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]## + +[.description] +-- +Configuration property to enable or disable generic cost listener + + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++` +endif::add-copy-button-to-env-var[] +-- +|boolean +|`false` + a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]## [.description] diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc index cd612e282..332fa2c73 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-core_quarkus.langchain4j.adoc @@ -101,6 +101,23 @@ endif::add-copy-button-to-env-var[] |boolean |`true` +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j-core_quarkus-langchain4j-cost-listener]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-cost-listener[`quarkus.langchain4j.cost-listener`]## + +[.description] +-- +Configuration property to enable or disable generic cost listener + + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_COST_LISTENER+++` +endif::add-copy-button-to-env-var[] +-- +|boolean +|`false` + a| [[quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages]] [.property-path]##link:#quarkus-langchain4j-core_quarkus-langchain4j-chat-memory-memory-window-max-messages[`quarkus.langchain4j.chat-memory.memory-window.max-messages`]## [.description] diff --git a/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java b/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java index 0bdfa48ca..2c0438a3d 100644 --- a/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java +++ b/model-providers/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java @@ -35,6 +35,7 @@ import io.quarkiverse.langchain4j.openai.QuarkusOpenAiStreamingChatModelBuilderFactory; import io.quarkiverse.langchain4j.openai.runtime.OpenAiRecorder; import io.quarkiverse.langchain4j.openai.runtime.config.LangChain4jOpenAiConfig; +import io.quarkiverse.langchain4j.response.ResponseInterceptorBindingSource; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; @@ -96,7 +97,8 @@ void generateBeans(OpenAiRecorder recorder, .scope(ApplicationScoped.class) .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null)) - .createWith(recorder.chatModel(config, configName)); + .createWith(recorder.chatModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); @@ -122,7 +124,8 @@ void generateBeans(OpenAiRecorder recorder, .defaultBean() .unremovable() .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config, configName)); + .createWith(recorder.embeddingModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } @@ -136,7 +139,8 @@ void generateBeans(OpenAiRecorder recorder, .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(recorder.moderationModel(config, configName)); + .createWith(recorder.moderationModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } @@ -150,7 +154,8 @@ void generateBeans(OpenAiRecorder recorder, .setRuntimeInit() .defaultBean() .scope(ApplicationScoped.class) - .supplier(recorder.imageModel(config, configName)); + .createWith(recorder.imageModel(config, configName)) + .injectInterceptionProxy(ResponseInterceptorBindingSource.class); addQualifierIfNecessary(builder, configName); beanProducer.produce(builder.done()); } diff --git a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java index 625315a7f..5033e2112 100644 --- a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java +++ b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java @@ -44,6 +44,7 @@ import io.quarkiverse.langchain4j.openai.runtime.config.LangChain4jOpenAiConfig; import io.quarkiverse.langchain4j.openai.runtime.config.ModerationModelConfig; import io.quarkiverse.langchain4j.runtime.NamedConfigUtil; +import io.quarkus.arc.InterceptionProxy; import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.runtime.ShutdownContext; import io.quarkus.runtime.annotations.Recorder; @@ -102,7 +103,8 @@ public Function, ChatLanguageModel public ChatLanguageModel apply(SyntheticCreationalContext context) { builder.listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream() .collect(Collectors.toList())); - return builder.build(); + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { @@ -175,7 +177,8 @@ public StreamingChatLanguageModel apply( } } - public Supplier embeddingModel(LangChain4jOpenAiConfig runtimeConfig, String configName) { + public Function, EmbeddingModel> embeddingModel( + LangChain4jOpenAiConfig runtimeConfig, String configName) { LangChain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, configName); if (openAiConfig.enableIntegration()) { @@ -206,17 +209,18 @@ public Supplier embeddingModel(LangChain4jOpenAiConfig runtimeCo new InetSocketAddress(host, openAiConfig.proxyPort()))); }); - return new Supplier<>() { + return new Function<>() { @Override - public EmbeddingModel get() { - return builder.build(); + public EmbeddingModel apply(SyntheticCreationalContext context) { + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { - return new Supplier<>() { + return new Function<>() { @Override - public EmbeddingModel get() { + public EmbeddingModel apply(SyntheticCreationalContext context) { return new DisabledEmbeddingModel(); } @@ -224,7 +228,8 @@ public EmbeddingModel get() { } } - public Supplier moderationModel(LangChain4jOpenAiConfig runtimeConfig, String configName) { + public Function, ModerationModel> moderationModel( + LangChain4jOpenAiConfig runtimeConfig, String configName) { LangChain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, configName); if (openAiConfig.enableIntegration()) { @@ -251,17 +256,18 @@ public Supplier moderationModel(LangChain4jOpenAiConfig runtime new InetSocketAddress(host, openAiConfig.proxyPort()))); }); - return new Supplier<>() { + return new Function<>() { @Override - public ModerationModel get() { - return builder.build(); + public ModerationModel apply(SyntheticCreationalContext context) { + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { - return new Supplier<>() { + return new Function<>() { @Override - public ModerationModel get() { + public ModerationModel apply(SyntheticCreationalContext context) { return new DisabledModerationModel(); } @@ -269,7 +275,8 @@ public ModerationModel get() { } } - public Supplier imageModel(LangChain4jOpenAiConfig runtimeConfig, String configName) { + public Function, ImageModel> imageModel(LangChain4jOpenAiConfig runtimeConfig, + String configName) { LangChain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, configName); if (openAiConfig.enableIntegration()) { @@ -316,17 +323,18 @@ public Optional get() { builder.persistDirectory(persistDirectory); - return new Supplier<>() { + return new Function<>() { @Override - public ImageModel get() { - return builder.build(); + public ImageModel apply(SyntheticCreationalContext context) { + InterceptionProxy proxy = context.getInterceptionProxy(); + return proxy.create(builder.build()); } }; } else { - return new Supplier<>() { + return new Function<>() { @Override - public ImageModel get() { + public ImageModel apply(SyntheticCreationalContext context) { return new DisabledImageModel(); } }; diff --git a/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java b/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java index 6f0dfe3b6..81d97d90e 100644 --- a/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java +++ b/model-providers/openai/openai-vanilla/runtime/src/test/java/io/quarkiverse/langchain4j/openai/runtime/DisabledModelsOpenAiRecorderTest.java @@ -45,21 +45,21 @@ void disabledStreamingChatModel() { @Test void disabledEmbeddingModel() { - assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledEmbeddingModel.class); } @Test void disabledImageModel() { - assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledImageModel.class); } @Test void disabledModerationModel() { - assertThat(recorder.moderationModel(config, NamedConfigUtil.DEFAULT_NAME).get()) + assertThat(recorder.moderationModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null)) .isNotNull() .isExactlyInstanceOf(DisabledModerationModel.class); } diff --git a/samples/sql-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/MovieMuseCostListener.java b/samples/sql-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/MovieMuseCostListener.java new file mode 100644 index 000000000..12de240c8 --- /dev/null +++ b/samples/sql-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/MovieMuseCostListener.java @@ -0,0 +1,15 @@ +package io.quarkiverse.langchain4j.sample.chatbot; + +import dev.langchain4j.model.output.TokenUsage; +import io.quarkiverse.langchain4j.cost.Cost; +import io.quarkiverse.langchain4j.cost.CostListener; +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class MovieMuseCostListener implements CostListener { + public void handleCost(String model, TokenUsage tokenUsage, Cost cost) { + System.out.println("model = " + model); + System.out.println("tokenUsage = " + tokenUsage); + System.out.println("cost = " + cost); + } +} diff --git a/samples/sql-chatbot/src/main/resources/application.properties b/samples/sql-chatbot/src/main/resources/application.properties index 8fb28bb44..75ba3b789 100644 --- a/samples/sql-chatbot/src/main/resources/application.properties +++ b/samples/sql-chatbot/src/main/resources/application.properties @@ -1,6 +1,8 @@ quarkus.langchain4j.timeout=60s csv.file=src/main/resources/data/movies.csv +quarkus.langchain4j.cost-listener=true + quarkus.hibernate-orm.database.generation=drop-and-create # if you want to log the requests and responses that go to OpenAI: