Skip to content

Commit

Permalink
feat(azure): support for the Realtime API (#1963)
Browse files Browse the repository at this point in the history
  • Loading branch information
kristapratico authored Dec 19, 2024
1 parent 3dee863 commit 9fda141
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
coerce_integer as coerce_integer,
file_from_path as file_from_path,
parse_datetime as parse_datetime,
is_azure_client as is_azure_client,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
is_async_azure_client as is_async_azure_client,
)
from ._typing import (
is_list_type as is_list_type,
Expand Down
16 changes: 16 additions & 0 deletions src/openai/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import functools
from typing import (
TYPE_CHECKING,
Any,
Tuple,
Mapping,
Expand All @@ -30,6 +31,9 @@
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])

if TYPE_CHECKING:
from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI


def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
Expand Down Expand Up @@ -412,3 +416,15 @@ def json_safe(data: object) -> object:
return data.isoformat()

return data


def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
from ..lib.azure import AzureOpenAI

return isinstance(client, AzureOpenAI)


def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
from ..lib.azure import AsyncAzureOpenAI

return isinstance(client, AsyncAzureOpenAI)
32 changes: 31 additions & 1 deletion src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import httpx

from .._types import NOT_GIVEN, Omit, Timeout, NotGiven
from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
from .._utils import is_given, is_mapping
from .._client import OpenAI, AsyncOpenAI
from .._compat import model_copy
Expand Down Expand Up @@ -307,6 +307,21 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:

return options

def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": model,
}
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers


class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@overload
Expand Down Expand Up @@ -555,3 +570,18 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp
raise ValueError("Unable to handle auth")

return options

async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": model,
}
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers
20 changes: 16 additions & 4 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
)
from ...._types import NOT_GIVEN, Query, Headers, NotGiven
from ...._utils import (
is_azure_client,
maybe_transform,
strip_not_given,
async_maybe_transform,
is_async_azure_client,
)
from ...._compat import cached_property
from ...._models import construct_type_unchecked
Expand Down Expand Up @@ -319,11 +321,16 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
except ImportError as exc:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**self.__extra_query,
**extra_query,
},
)
log.debug("Connecting to %s", url)
Expand All @@ -336,7 +343,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
user_agent_header=self.__client.user_agent,
additional_headers=_merge_mappings(
{
**self.__client.auth_headers,
**auth_headers,
"OpenAI-Beta": "realtime=v1",
},
self.__extra_headers,
Expand Down Expand Up @@ -496,11 +503,16 @@ def __enter__(self) -> RealtimeConnection:
except ImportError as exc:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**self.__extra_query,
**extra_query,
},
)
log.debug("Connecting to %s", url)
Expand All @@ -513,7 +525,7 @@ def __enter__(self) -> RealtimeConnection:
user_agent_header=self.__client.user_agent,
additional_headers=_merge_mappings(
{
**self.__client.auth_headers,
**auth_headers,
"OpenAI-Beta": "realtime=v1",
},
self.__extra_headers,
Expand Down

0 comments on commit 9fda141

Please sign in to comment.