Skip to content

Commit

Permalink
[OpenAI] Add base url (#1757)
Browse files Browse the repository at this point in the history
* [OpenAI] Add base url

* [OpenAI] Wip
  • Loading branch information
assouktim authored Oct 10, 2024
1 parent 67b96ee commit 903e3c2
Show file tree
Hide file tree
Showing 14 changed files with 92 additions and 35 deletions.
3 changes: 2 additions & 1 deletion bot/admin/server/src/test/kotlin/service/RAGServiceTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class RAGServiceTest : AbstractTest() {
apiKey = "apikey",
model = MODEL,
prompt = PROMPT,
temperature = TEMPERATURE
temperature = TEMPERATURE,
baseUrl = "https://api.openai.com/v1"
),
emSetting = AzureOpenAIEMSettingDTO(
apiKey = "apiKey",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class RAGValidationServiceTest {
}

private val openAILLMSetting = OpenAILLMSetting(
apiKey = "123-abc", model = "unavailable-model", temperature = "0.4", prompt = "How to bike in the rain"
apiKey = "123-abc", model = "unavailable-model", temperature = "0.4", prompt = "How to bike in the rain",
baseUrl = "https://api.openai.com/v1",
)

private val azureOpenAIEMSetting = AzureOpenAIEMSettingDTO(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,24 @@ Answer in '{{locale}}' (language locale).

export const EngineConfigurations: EnginesConfiguration[] = [
{
label: 'OpenAi',
label: 'OpenAI',
key: LLMProvider.OpenAI,
params: [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated' },
{ key: 'baseUrl', label: 'Base url', type: 'text', defaultValue: 'https://api.openai.com/v1' },
{ key: 'model', label: 'Model name', type: 'openlist', source: OpenAIModelsList },
{ key: 'temperature', label: 'Temperature', type: 'number', inputScale: 'fullwidth' },
{ key: 'prompt', label: 'Prompt', type: 'prompt', inputScale: 'fullwidth', defaultValue: DefaultPrompt }
]
},
{
label: 'Azure OpenAi',
label: 'Azure OpenAI',
key: LLMProvider.AzureOpenAIService,
params: [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated' },
{ key: 'apiVersion', label: 'Api version', type: 'openlist', source: AzureOpenAiApiVersionsList },
{ key: 'deploymentName', label: 'Deployment name', type: 'text' },
{ key: 'apiBase', label: 'Private endpoint base url', type: 'obfuscated' },
{ key: 'apiBase', label: 'Base url', type: 'obfuscated' },
{ key: 'temperature', label: 'Temperature', type: 'number', inputScale: 'fullwidth' },
{ key: 'prompt', label: 'Prompt', type: 'prompt', inputScale: 'fullwidth', defaultValue: DefaultPrompt }
]
Expand All @@ -65,7 +66,7 @@ export const EngineConfigurations: EnginesConfiguration[] = [
label: 'Ollama',
key: LLMProvider.Ollama,
params: [
{ key: 'baseUrl', label: 'BaseUrl', type: 'text', defaultValue: 'http://localhost:11434' },
{ key: 'baseUrl', label: 'Base url', type: 'text', defaultValue: 'http://localhost:11434' },
{ key: 'model', label: 'Model', type: 'openlist', source: OllamaLlmModelsList, defaultValue: 'llama2' },
{ key: 'temperature', label: 'Temperature', type: 'number', inputScale: 'fullwidth', defaultValue: 0.7 },
{ key: 'prompt', label: 'Prompt', type: 'prompt', inputScale: 'fullwidth', defaultValue: DefaultPrompt }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,24 @@ Answer in {locale}.

const EnginesConfigurations_Llm: EnginesConfiguration[] = [
{
label: 'OpenAi',
label: 'OpenAI',
key: LLMProvider.OpenAI,
params: [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated' },
{ key: 'baseUrl', label: 'Base url', type: 'text', defaultValue: 'https://api.openai.com/v1' },
{ key: 'model', label: 'Model name', type: 'openlist', source: OpenAIModelsList },
{ key: 'temperature', label: 'Temperature', type: 'number', inputScale: 'fullwidth' },
{ key: 'prompt', label: 'Prompt', type: 'prompt', inputScale: 'fullwidth', defaultValue: DefaultPrompt }
]
},
{
label: 'Azure OpenAi',
label: 'Azure OpenAI',
key: LLMProvider.AzureOpenAIService,
params: [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated' },
{ key: 'apiVersion', label: 'Api version', type: 'openlist', source: AzureOpenAiApiVersionsList },
{ key: 'deploymentName', label: 'Deployment name', type: 'text' },
{ key: 'apiBase', label: 'Private endpoint base url', type: 'obfuscated' },
{ key: 'apiBase', label: 'Base url', type: 'obfuscated' },
{ key: 'temperature', label: 'Temperature', type: 'number', inputScale: 'fullwidth' },
{ key: 'prompt', label: 'Prompt', type: 'prompt', inputScale: 'fullwidth', defaultValue: DefaultPrompt }
]
Expand All @@ -74,21 +75,22 @@ const EnginesConfigurations_Llm: EnginesConfiguration[] = [

const EnginesConfigurations_Embedding: EnginesConfiguration[] = [
{
label: 'OpenAi',
label: 'OpenAI',
key: LLMProvider.OpenAI,
params: [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated' },
{ key: 'baseUrl', label: 'Base url', type: 'text', defaultValue: 'https://api.openai.com/v1' },
{ key: 'model', label: 'Model name', type: 'openlist', source: OpenAIEmbeddingModel }
]
},
{
label: 'Azure OpenAi',
label: 'Azure OpenAI',
key: LLMProvider.AzureOpenAIService,
params: [
{ key: 'apiKey', label: 'Api key', type: 'obfuscated' },
{ key: 'apiVersion', label: 'Api version', type: 'openlist', source: AzureOpenAiApiVersionsList },
{ key: 'deploymentName', label: 'Deployment name', type: 'text' },
{ key: 'apiBase', label: 'Private endpoint base url', type: 'obfuscated' }
{ key: 'apiBase', label: 'Base url', type: 'obfuscated' }
]
},
{
Expand Down
6 changes: 5 additions & 1 deletion bot/admin/web/src/app/shared/model/ai-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,16 @@ export const AzureOpenAiApiVersionsList: string[] = [
];

export const OpenAIModelsList: string[] = [
'gpt-4o',
'gpt-4o-mini',

'gpt-4',
'gpt-4-0314',
'gpt-4-0613',
'gpt-4-32k',
'gpt-4-32k-0314',
'gpt-4-32k-0613',
'gpt-4-turbo',

'gpt-3.5-turbo',
'gpt-3.5-turbo-0613',
Expand All @@ -71,7 +75,7 @@ export const OpenAIModelsList: string[] = [
'davinci-002'
];

export const OpenAIEmbeddingModel: string[] = ['text-embedding-ada-002'];
export const OpenAIEmbeddingModel: string[] = ['text-embedding-3-small', 'text-embedding-3-large', 'text-embedding-ada-002'];

export const OllamaLlmModelsList: string[] = ['llama2', 'llama3', 'llama3.1', 'llama3.1:8b', 'llama3.2'];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ internal class BotRAGConfigurationMongoDAOTest : AbstractTest() {
apiKey = RawSecretKey("apiKey1"),
model = "modelName1",
temperature = "1F",
prompt = "prompt1"
prompt = "prompt1",
baseUrl = "https://api.openai.com/v1"
),
emSetting = OpenAIEMSetting(
apiKey = RawSecretKey("apiKey1"),
model = "modelName1"
model = "modelName1",
baseUrl = "https://api.openai.com/v1"
),
noAnswerSentence = "no answer sentence"
)
Expand All @@ -73,11 +75,13 @@ internal class BotRAGConfigurationMongoDAOTest : AbstractTest() {
llmSetting = OpenAILLMSetting(
apiKey = RawSecretKey("apiKey1"),
model = "modelName1",
baseUrl = "https://api.openai.com/v1",
temperature = "1F",
prompt = "prompt1"
),
emSetting = OpenAIEMSetting(
apiKey = RawSecretKey("apiKey1"),
baseUrl = "https://api.openai.com/v1",
model = "modelName1"
),
noAnswerSentence = "no answer sentence1"
Expand All @@ -92,10 +96,12 @@ internal class BotRAGConfigurationMongoDAOTest : AbstractTest() {
apiKey = RawSecretKey("apiKey1"),
model = "modelName1",
temperature = "1F",
baseUrl = "https://api.openai.com/v1",
prompt = "prompt1"
),
emSetting = OpenAIEMSetting(
apiKey = RawSecretKey("apiKey1"),
baseUrl = "https://api.openai.com/v1",
model = "modelName1"
),
noAnswerSentence = "no answer sentence1"
Expand Down Expand Up @@ -124,10 +130,12 @@ internal class BotRAGConfigurationMongoDAOTest : AbstractTest() {
apiKey = RawSecretKey("apiKey1"),
model = "modelName1",
temperature = "1F",
baseUrl = "https://api.openai.com/v1",
prompt = "prompt1"
),
emSetting = OpenAIEMSetting(
apiKey = RawSecretKey("apiKey1"),
baseUrl = "https://api.openai.com/v1",
model = "modelName1"
),
noAnswerSentence = "no answer sentence"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,20 @@ object EMSettingMapper {
with(entity){
when(this){
is OpenAIEMSetting ->
OpenAIEMSetting(SecurityUtils.fetchSecretKeyValue(apiKey), model)
OpenAIEMSetting(
apiKey = SecurityUtils.fetchSecretKeyValue(apiKey),
model = model,
baseUrl = baseUrl
)
is AzureOpenAIEMSetting ->
AzureOpenAIEMSetting(SecurityUtils.fetchSecretKeyValue(apiKey), apiBase, deploymentName, apiVersion)
AzureOpenAIEMSetting(
apiKey = SecurityUtils.fetchSecretKeyValue(apiKey),
apiBase = apiBase,
deploymentName = deploymentName,
apiVersion = apiVersion
)
is OllamaEMSetting ->
OllamaEMSetting(model, baseUrl)
OllamaEMSetting(model = model, baseUrl = baseUrl)
else ->
throw IllegalArgumentException("Unsupported EM Setting")
}
Expand All @@ -56,16 +65,20 @@ object EMSettingMapper {
with(dto){
when(this){
is OpenAIEMSetting ->
OpenAIEMSetting(SecurityUtils.createSecretKey(namespace, botId, feature, apiKey), model)
OpenAIEMSetting(
apiKey = SecurityUtils.createSecretKey(namespace, botId, feature, apiKey),
model = model,
baseUrl = baseUrl
)
is AzureOpenAIEMSetting ->
AzureOpenAIEMSetting(
SecurityUtils.createSecretKey(namespace, botId, feature, apiKey),
apiBase,
deploymentName,
apiVersion
apiBase = apiBase,
deploymentName = deploymentName,
apiVersion = apiVersion
)
is OllamaEMSetting ->
OllamaEMSetting(model, baseUrl)
OllamaEMSetting(model = model, baseUrl = baseUrl)
else ->
throw IllegalArgumentException("Unsupported EM Setting")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,29 @@ object LLMSettingMapper {

when (this) {
is OpenAILLMSetting ->
OpenAILLMSetting(SecurityUtils.fetchSecretKeyValue(apiKey), temperature, prompt, model)
OpenAILLMSetting(
apiKey = SecurityUtils.fetchSecretKeyValue(apiKey),
temperature = temperature,
prompt = prompt,
model = model,
baseUrl = baseUrl
)
is AzureOpenAILLMSetting ->
AzureOpenAILLMSetting(
SecurityUtils.fetchSecretKeyValue(apiKey),
temperature,
prompt,
apiBase,
deploymentName,
apiVersion
apiKey = SecurityUtils.fetchSecretKeyValue(apiKey),
temperature = temperature,
prompt = prompt,
apiBase = apiBase,
deploymentName = deploymentName,
apiVersion = apiVersion
)
is OllamaLLMSetting ->
OllamaLLMSetting(temperature, prompt, model, baseUrl)
OllamaLLMSetting(
temperature = temperature,
prompt = prompt,
model = model,
baseUrl = baseUrl
)
else ->
throw IllegalArgumentException("Unsupported LLM Setting")
}
Expand All @@ -64,10 +75,11 @@ object LLMSettingMapper {
when (this) {
is OpenAILLMSetting ->
OpenAILLMSetting(
SecurityUtils.createSecretKey(namespace, botId, feature, apiKey),
temperature,
prompt,
model
apiKey = SecurityUtils.createSecretKey(namespace, botId, feature, apiKey),
temperature = temperature,
prompt = prompt,
model = model,
baseUrl = baseUrl
)
is AzureOpenAILLMSetting ->
AzureOpenAILLMSetting(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package ai.tock.genai.orchestratorcore.models.em
data class OpenAIEMSetting<T>(
override val apiKey: T,
val model: String,
val baseUrl: String,
) : EMSettingBase<T>(EMProvider.OpenAI, apiKey)

typealias OpenAIEMSettingDTO = OpenAIEMSetting<String>
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ data class OpenAILLMSetting<T>(
override val temperature: String,
override val prompt: String,
val model: String,
val baseUrl: String,
) : LLMSettingBase<T>(LLMProvider.OpenAI, apiKey, temperature, prompt) {
override fun copyWithTemperature(temperature: String): LLMSettingBase<T> {
return this.copy(temperature=temperature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@ class OpenAIEMSetting(BaseEMSetting):
examples=[RawSecretKey(value='ab7-14Ed2-dfg2F-A1IV4B')]
)
model: str = Field(description='The model id', examples=['text-embedding-ada-002'])
base_url: str = Field(
description='The OpenAI endpoint base URL',
examples=["https://api.openai.com/v1"],
default="https://api.openai.com/v1"
)
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,8 @@ class OpenAILLMSetting(BaseLLMSetting):
model: str = Field(
description='The model id', examples=['gpt-3.5-turbo'], min_length=1
)
base_url: str = Field(
description='The OpenAI endpoint base URL',
examples=["https://api.openai.com/v1"],
default="https://api.openai.com/v1"
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from langchain.embeddings.base import Embeddings
from langchain_openai import OpenAIEmbeddings
from openai import base_url

from gen_ai_orchestrator.configurations.environment.settings import (
application_settings,
Expand Down Expand Up @@ -46,6 +47,7 @@ class OpenAIEMFactory(LangChainEMFactory):
def get_embedding_model(self) -> Embeddings:
return OpenAIEmbeddings(
openai_api_key=fetch_secret_key_value(self.setting.api_key),
base_url=self.setting.base_url,
model=self.setting.model,
timeout=application_settings.em_provider_timeout,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class OpenAILLMFactory(LangChainLLMFactory):
def get_language_model(self) -> BaseLanguageModel:
return ChatOpenAI(
openai_api_key=fetch_secret_key_value(self.setting.api_key),
base_url=self.setting.base_url,
model_name=self.setting.model,
temperature=self.setting.temperature,
request_timeout=application_settings.llm_provider_timeout,
Expand Down

0 comments on commit 903e3c2

Please sign in to comment.