From 9d771142f282346bcaedfd8b0f50070863664f4e Mon Sep 17 00:00:00 2001 From: pkt1583 Date: Thu, 22 Feb 2024 14:27:44 +0530 Subject: [PATCH 1/3] Extending to Azure OpenAI implementation --- evals/completion_fns/openai.py | 5 ++-- evals/openai_client/__init__.py | 0 evals/openai_client/openai_client_provider.py | 27 +++++++++++++++++++ evals/registry.py | 9 +++++-- 4 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 evals/openai_client/__init__.py create mode 100644 evals/openai_client/openai_client_provider.py diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py index ed50818630..c2a96c662f 100644 --- a/evals/completion_fns/openai.py +++ b/evals/completion_fns/openai.py @@ -4,6 +4,7 @@ from evals.api import CompletionFn, CompletionResult from evals.base import CompletionFnSpec +from evals.openai_client.openai_client_provider import OpenAIClientProvider from evals.prompt.base import ( ChatCompletionPrompt, CompletionPrompt, @@ -82,7 +83,7 @@ def __call__( openai_create_prompt: OpenAICreatePrompt = prompt.to_formatted_prompt() result = openai_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), + OpenAIClientProvider(api_key=self.api_key).get_client(), model=self.model, prompt=openai_create_prompt, **{**kwargs, **self.extra_options}, @@ -127,7 +128,7 @@ def __call__( openai_create_prompt: OpenAICreateChatPrompt = prompt.to_formatted_prompt() result = openai_chat_completion_create_retrying( - OpenAI(api_key=self.api_key, base_url=self.api_base), + OpenAIClientProvider(api_key=self.api_key).get_client(), model=self.model, messages=openai_create_prompt, **{**kwargs, **self.extra_options}, diff --git a/evals/openai_client/__init__.py b/evals/openai_client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/evals/openai_client/openai_client_provider.py b/evals/openai_client/openai_client_provider.py new file mode 100644 index 0000000000..fbfd312db3 --- /dev/null +++ b/evals/openai_client/openai_client_provider.py @@ -0,0 +1,27 @@ +import os +from abc import ABC + +from openai import OpenAI +from openai.lib.azure import AzureOpenAI + + +class OpenAIClientProvider(ABC): + def __init__(self, api_key, **kwargs): + self.api_key = api_key or os.environ.get("OPENAI_API_KEY") + self.azure_endpoint = kwargs.pop("azure_endpoint", None) or os.environ.get( + "OPENAI_AZURE_ENDPOINT" + ) + self.api_version = kwargs.pop("opena_api_version", None) or os.environ.get( + "OPENAI_API_VERSION" + ) + + self.kwargs = kwargs + + def get_client(self): + if not self.azure_endpoint: + return OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + return AzureOpenAI( + api_key=os.environ.get("OPENAI_API_KEY"), + azure_endpoint=os.environ.get("OPENAI_AZURE_ENDPOINT"), + api_version=os.environ.get("OPENAI_API_VERSION"), + ) diff --git a/evals/registry.py b/evals/registry.py index cb37791cbc..9b48129590 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -15,15 +15,20 @@ import openai import yaml -from openai import OpenAI + from evals import OpenAIChatCompletionFn, OpenAICompletionFn from evals.api import CompletionFn, DummyCompletionFn from evals.base import BaseEvalSpec, CompletionFnSpec, EvalSetSpec, EvalSpec from evals.elsuite.modelgraded.base import ModelGradedSpec +from evals.openai_client.openai_client_provider import OpenAIClientProvider from evals.utils.misc import make_object -client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) +client = OpenAIClientProvider( + api_key=os.environ.get("OPENAI_API_KEY"), + azure_endpoint=os.environ.get("OPENAI_AZURE_ENDPOINT"), + api_version=os.environ.get("OPENAI_API_VERSION"), +).get_client() logger = logging.getLogger(__name__) From cc12b2f88817a4cdd67f8ddb40f3518215c9a77f Mon Sep 17 00:00:00 2001 From: pkt1583 Date: Fri, 23 Feb 2024 09:38:16 +0530 Subject: [PATCH 2/3] Removed passing of Azure environment variable as that is taken care of in ClientProvider --- evals/registry.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/evals/registry.py b/evals/registry.py index 9b48129590..66c17a1288 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -24,11 +24,7 @@ from evals.openai_client.openai_client_provider import OpenAIClientProvider from evals.utils.misc import make_object -client = OpenAIClientProvider( - api_key=os.environ.get("OPENAI_API_KEY"), - azure_endpoint=os.environ.get("OPENAI_AZURE_ENDPOINT"), - api_version=os.environ.get("OPENAI_API_VERSION"), -).get_client() +client = OpenAIClientProvider(api_key=os.environ.get("OPENAI_API_KEY")).get_client() logger = logging.getLogger(__name__) From 560abde422363e0c18c7720fb0b9ec1141755d5c Mon Sep 17 00:00:00 2001 From: pkt1583 Date: Fri, 23 Feb 2024 12:31:00 +0530 Subject: [PATCH 3/3] Added Azure deployment parameter --- evals/openai_client/openai_client_provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/evals/openai_client/openai_client_provider.py b/evals/openai_client/openai_client_provider.py index fbfd312db3..fdf142542a 100644 --- a/evals/openai_client/openai_client_provider.py +++ b/evals/openai_client/openai_client_provider.py @@ -24,4 +24,5 @@ def get_client(self): api_key=os.environ.get("OPENAI_API_KEY"), azure_endpoint=os.environ.get("OPENAI_AZURE_ENDPOINT"), api_version=os.environ.get("OPENAI_API_VERSION"), + azure_deployment=os.environ.get("OPENAI_AZURE_DEPLOYMENT"), )