
Commit

Merge pull request #137 from quarkiverse/azure-remaining
Add remaining models to Azure module
geoand authored Dec 12, 2023
2 parents cf88fc3 + ea90644 commit d9b008a
Showing 4 changed files with 90 additions and 29 deletions.
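For orientation before the diff: this commit makes the Azure OpenAI extension register synthetic CDI beans for the streaming chat model and the embedding model, alongside the existing chat model bean. The sketch below is illustrative only and is not part of the commit; the ModelConsumer class and its method are hypothetical, and it assumes the synthetic beans are exposed under the standard langchain4j interfaces that the STREAMING_CHAT_MODEL and EMBEDDING_MODEL DotNames refer to.

// Hypothetical consumer of the beans wired up by this commit. The langchain4j
// types and Quarkus annotations are real; the class and method are illustrative.
package org.acme;

import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

@ApplicationScoped
public class ModelConsumer {

    @Inject
    StreamingChatLanguageModel streamingChatModel; // backed by OpenAiStreamingChatModel

    @Inject
    EmbeddingModel embeddingModel; // backed by OpenAiEmbeddingModel

    public int embeddingDimension(String text) {
        // Assumes the Response<Embedding> return shape; embed() has changed
        // shape across langchain4j versions.
        return embeddingModel.embed(text).content().vector().length;
    }
}

The injection works because generateBeans() below produces SyntheticBeanBuildItems for STREAMING_CHAT_MODEL and EMBEDDING_MODEL whenever the Azure provider is the selected candidate.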
@@ -1,6 +1,8 @@
package io.quarkiverse.langchain4j.azure.openai.deployment;

import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.STREAMING_CHAT_MODEL;

import java.util.Optional;

@@ -43,9 +45,6 @@ public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
}
if (config.moderationModel().enabled().isEmpty() || config.moderationModel().enabled().get()) {
moderationProducer.produce(new ModerationModelProviderCandidateBuildItem(PROVIDER));
}
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@@ -66,33 +65,23 @@ void generateBeans(AzureOpenAiRecorder recorder,
.supplier(recorder.chatModel(config))
.done());

// beanProducer.produce(SyntheticBeanBuildItem
// .configure(STREAMING_CHAT_MODEL)
// .setRuntimeInit()
// .defaultBean()
// .scope(ApplicationScoped.class)
// .supplier(recorder.streamingChatModel(config))
// .done());
beanProducer.produce(SyntheticBeanBuildItem
.configure(STREAMING_CHAT_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.streamingChatModel(config))
.done());
}

if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
// beanProducer.produce(SyntheticBeanBuildItem
// .configure(EMBEDDING_MODEL)
// .setRuntimeInit()
// .defaultBean()
// .scope(ApplicationScoped.class)
// .supplier(recorder.embeddingModel(config))
// .done());
}

if (selectedModeration.isPresent() && PROVIDER.equals(selectedModeration.get().getProvider())) {
// beanProducer.produce(SyntheticBeanBuildItem
// .configure(MODERATION_MODEL)
// .setRuntimeInit()
// .defaultBean()
// .scope(ApplicationScoped.class)
// .supplier(recorder.moderationModel(config))
// .done());
beanProducer.produce(SyntheticBeanBuildItem
.configure(EMBEDDING_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.embeddingModel(config))
.done());
}
}

@@ -3,7 +3,10 @@
import java.util.function.Supplier;

import dev.langchain4j.model.azure.AzureOpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig;
import io.quarkiverse.langchain4j.openai.QuarkusOpenAiClient;
import io.quarkus.runtime.ShutdownContext;
@@ -15,8 +18,7 @@ public class AzureOpenAiRecorder {
public Supplier<?> chatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
ChatModelConfig chatModelConfig = runtimeConfig.chatModel();
var builder = AzureOpenAiChatModel.builder()
.baseUrl(String.format("https://%s.openai.azure.com/openai/deployments/%s", runtimeConfig.resourceName(),
runtimeConfig.deploymentId()))
.baseUrl(getBaseUrl(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.apiVersion(runtimeConfig.apiVersion())
.timeout(runtimeConfig.timeout())
@@ -41,6 +43,57 @@ public Object get() {
};
}

public Supplier<?> streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
ChatModelConfig chatModelConfig = runtimeConfig.chatModel();
var builder = OpenAiStreamingChatModel.builder()
.baseUrl(getBaseUrl(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())

.temperature(chatModelConfig.temperature())
.topP(chatModelConfig.topP())
.presencePenalty(chatModelConfig.presencePenalty())
.frequencyPenalty(chatModelConfig.frequencyPenalty());

if (chatModelConfig.maxTokens().isPresent()) {
builder.maxTokens(chatModelConfig.maxTokens().get());
}

return new Supplier<>() {
@Override
public Object get() {
return builder.build();
}
};
}

public Supplier<?> embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
EmbeddingModelConfig embeddingModelConfig = runtimeConfig.embeddingModel();
var builder = OpenAiEmbeddingModel.builder()
.baseUrl(getBaseUrl(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.maxRetries(runtimeConfig.maxRetries())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())

.modelName(embeddingModelConfig.modelName());

return new Supplier<>() {
@Override
public Object get() {
return builder.build();
}
};
}

private String getBaseUrl(Langchain4jAzureOpenAiConfig runtimeConfig) {
return String.format("https://%s.openai.azure.com/openai/deployments/%s", runtimeConfig.resourceName(),
runtimeConfig.deploymentId());
}

public void cleanUp(ShutdownContext shutdown) {
shutdown.addShutdownTask(new Runnable() {
@Override
@@ -0,0 +1,14 @@
package io.quarkiverse.langchain4j.azure.openai.runtime.config;

import io.quarkus.runtime.annotations.ConfigGroup;
import io.smallrye.config.WithDefault;

@ConfigGroup
public interface EmbeddingModelConfig {

/**
* Model name to use
*/
@WithDefault("text-embedding-ada-002")
String modelName();
}
@@ -61,4 +61,9 @@ public interface Langchain4jAzureOpenAiConfig {
* Chat model related settings
*/
ChatModelConfig chatModel();

/**
* Embedding model related settings
*/
EmbeddingModelConfig embeddingModel();
}
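As a rough standalone sketch of what AzureOpenAiRecorder#embeddingModel() assembles at runtime: getBaseUrl() builds the Azure deployment URL, which is handed to langchain4j's OpenAiEmbeddingModel builder together with the EmbeddingModelConfig default model name. The resource name, deployment id, and environment variable below are placeholders, and only builder methods that appear in the diff above are used.

// Minimal sketch mirroring the embeddingModel() supplier outside of Quarkus.
// Resource name, deployment id, and the env var name are placeholders.
import java.time.Duration;

import dev.langchain4j.model.openai.OpenAiEmbeddingModel;

public class AzureEmbeddingSketch {

    public static void main(String[] args) {
        String resourceName = "my-resource";   // hypothetical Azure resource
        String deploymentId = "my-deployment"; // hypothetical deployment id
        String baseUrl = String.format(
                "https://%s.openai.azure.com/openai/deployments/%s", resourceName, deploymentId);

        OpenAiEmbeddingModel model = OpenAiEmbeddingModel.builder()
                .baseUrl(baseUrl)
                .apiKey(System.getenv("AZURE_OPENAI_API_KEY"))
                .timeout(Duration.ofSeconds(10))
                .maxRetries(3)
                .logRequests(false)
                .logResponses(false)
                .modelName("text-embedding-ada-002") // default from EmbeddingModelConfig
                .build();

        // The return type of embed() differs across langchain4j versions, so the
        // result is just printed here.
        System.out.println(model.embed("Azure OpenAI embedding smoke test"));
    }
}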
