-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make response handling more generic. Resulting in generic cost handling.
- Loading branch information
Showing
18 changed files
with
327 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
.../runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostEstimatorResponseListener.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package io.quarkiverse.langchain4j.cost; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Comparator; | ||
import java.util.List; | ||
|
||
import jakarta.inject.Inject; | ||
|
||
import dev.langchain4j.model.output.TokenUsage; | ||
import io.quarkiverse.langchain4j.response.ResponseListener; | ||
import io.quarkiverse.langchain4j.response.ResponseRecord; | ||
import io.quarkus.arc.All; | ||
import io.smallrye.common.annotation.Experimental; | ||
|
||
/** | ||
* Allows for user code to provide a custom strategy for estimating the cost of API calls | ||
*/ | ||
@Experimental("This feature is experimental and the API is subject to change") | ||
public class CostEstimatorResponseListener implements ResponseListener { | ||
|
||
private final CostEstimatorService service; | ||
private final List<CostListener> listeners; | ||
|
||
@Inject | ||
public CostEstimatorResponseListener(CostEstimatorService service, @All List<CostListener> listeners) { | ||
this.service = service; | ||
this.listeners = new ArrayList<>(listeners); | ||
this.listeners.sort(Comparator.comparingInt(CostListener::order)); | ||
} | ||
|
||
@Override | ||
public void onResponse(ResponseRecord rr) { | ||
String model = rr.model(); | ||
TokenUsage tokenUsage = rr.tokenUsage(); | ||
CostEstimator.CostContext context = new MyCostContext(tokenUsage, model); | ||
Cost cost = service.estimate(context); | ||
if (cost != null) { | ||
for (CostListener cl : listeners) { | ||
cl.handleCost(model, tokenUsage, cost); | ||
} | ||
} | ||
} | ||
|
||
private record MyCostContext(TokenUsage tokenUsage, String model) implements CostEstimator.CostContext { | ||
@Override | ||
public Integer inputTokens() { | ||
return tokenUsage().inputTokenCount(); | ||
} | ||
|
||
@Override | ||
public Integer outputTokens() { | ||
return tokenUsage().outputTokenCount(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 14 additions & 0 deletions
14
core/runtime/src/main/java/io/quarkiverse/langchain4j/cost/CostListener.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package io.quarkiverse.langchain4j.cost; | ||
|
||
import dev.langchain4j.model.output.TokenUsage; | ||
|
||
/** | ||
* Allows for user code to handle estimate cost; e.g. some simple accounting | ||
*/ | ||
public interface CostListener { | ||
void handleCost(String model, TokenUsage tokenUsage, Cost cost); | ||
|
||
default int order() { | ||
return 0; | ||
} | ||
} |
47 changes: 47 additions & 0 deletions
47
core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package io.quarkiverse.langchain4j.response; | ||
|
||
import java.util.Map; | ||
|
||
import jakarta.annotation.Priority; | ||
import jakarta.interceptor.AroundInvoke; | ||
import jakarta.interceptor.Interceptor; | ||
import jakarta.interceptor.InvocationContext; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.chat.listener.ChatModelResponse; | ||
import dev.langchain4j.model.chat.response.ChatResponse; | ||
import dev.langchain4j.model.output.Response; | ||
|
||
/** | ||
* Simple (Chat)Response interceptor, to be applied directly on the model. | ||
*/ | ||
@Interceptor | ||
@ResponseInterceptorBinding | ||
@Priority(0) | ||
public class ResponseInterceptor extends ResponseInterceptorBase { | ||
|
||
@AroundInvoke | ||
public Object intercept(InvocationContext context) throws Exception { | ||
Object result = context.proceed(); | ||
ResponseRecord rr = null; | ||
if (result instanceof Response<?> response) { | ||
Object content = response.content(); | ||
if (content instanceof AiMessage am) { | ||
rr = new ResponseRecord(getModel(context.getTarget()), am, response.tokenUsage(), response.finishReason(), | ||
response.metadata()); | ||
} | ||
} else if (result instanceof ChatResponse response) { | ||
rr = new ResponseRecord(getModel(context.getTarget()), response.aiMessage(), response.tokenUsage(), | ||
response.finishReason(), Map.of()); | ||
} else if (result instanceof ChatModelResponse response) { | ||
rr = new ResponseRecord(response.model(), response.aiMessage(), response.tokenUsage(), response.finishReason(), | ||
Map.of("id", response.id())); | ||
} | ||
if (rr != null) { | ||
for (ResponseListener l : getListeners()) { | ||
l.onResponse(rr); | ||
} | ||
} | ||
return result; | ||
} | ||
} |
41 changes: 41 additions & 0 deletions
41
core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBase.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package io.quarkiverse.langchain4j.response; | ||
|
||
import java.lang.reflect.Method; | ||
import java.util.Comparator; | ||
import java.util.List; | ||
|
||
import jakarta.enterprise.inject.Any; | ||
import jakarta.enterprise.inject.spi.CDI; | ||
|
||
/** | ||
* Simple (Chat)Response interceptor base, to be applied directly on the model. | ||
*/ | ||
public abstract class ResponseInterceptorBase { | ||
|
||
private volatile String model; | ||
private volatile List<ResponseListener> listeners; | ||
|
||
// TODO -- uh uh ... reflection ... puke | ||
protected String getModel(Object target) { | ||
if (model == null) { | ||
try { | ||
Class<?> clazz = target.getClass(); | ||
Method method = clazz.getMethod("modelName"); | ||
model = (String) method.invoke(target); | ||
} catch (Exception e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
return model; | ||
} | ||
|
||
protected List<ResponseListener> getListeners() { | ||
if (listeners == null) { | ||
listeners = CDI.current().select(ResponseListener.class, Any.Literal.INSTANCE) | ||
.stream() | ||
.sorted(Comparator.comparing(ResponseListener::order)) | ||
.toList(); | ||
} | ||
return listeners; | ||
} | ||
} |
14 changes: 14 additions & 0 deletions
14
...runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBinding.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package io.quarkiverse.langchain4j.response; | ||
|
||
import java.lang.annotation.ElementType; | ||
import java.lang.annotation.Retention; | ||
import java.lang.annotation.RetentionPolicy; | ||
import java.lang.annotation.Target; | ||
|
||
import jakarta.interceptor.InterceptorBinding; | ||
|
||
@InterceptorBinding | ||
@Target({ ElementType.TYPE, ElementType.METHOD }) | ||
@Retention(RetentionPolicy.RUNTIME) | ||
public @interface ResponseInterceptorBinding { | ||
} |
5 changes: 5 additions & 0 deletions
5
...e/src/main/java/io/quarkiverse/langchain4j/response/ResponseInterceptorBindingSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package io.quarkiverse.langchain4j.response; | ||
|
||
@ResponseInterceptorBinding | ||
public abstract class ResponseInterceptorBindingSource { | ||
} |
12 changes: 12 additions & 0 deletions
12
core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseListener.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
package io.quarkiverse.langchain4j.response; | ||
|
||
/** | ||
* Simple ResponseRecord listener, to be implemented by the (advanced) users. | ||
*/ | ||
public interface ResponseListener { | ||
void onResponse(ResponseRecord response); | ||
|
||
default int order() { | ||
return 0; | ||
} | ||
} |
18 changes: 18 additions & 0 deletions
18
core/runtime/src/main/java/io/quarkiverse/langchain4j/response/ResponseRecord.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package io.quarkiverse.langchain4j.response; | ||
|
||
import java.util.Map; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.output.FinishReason; | ||
import dev.langchain4j.model.output.TokenUsage; | ||
|
||
/** | ||
* Abstract away Response vs ChatResponse. | ||
*/ | ||
public record ResponseRecord( | ||
String model, | ||
AiMessage content, | ||
TokenUsage tokenUsage, | ||
FinishReason finishReason, | ||
Map<String, Object> metadata) { | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.