From 779ef418380013a7e117d85e94677752148d1554 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 23 Jul 2024 17:55:18 +0200 Subject: [PATCH] Update `InferenceEndpointsLLM` to use `chat_completion` method (#815) * Update `InferenceEndpointsLLM` to use `chat_completion` method * Fix unit test * Fix unit test for python 3.8 --- pyproject.toml | 2 +- .../llms/huggingface/inference_endpoints.py | 444 ++++++++++++------ .../huggingface/test_inference_endpoints.py | 186 +++++--- 3 files changed, 409 insertions(+), 223 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b94387ff51..43a084df4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ anthropic = ["anthropic >= 0.20.0"] argilla = ["argilla >= 1.29.0"] cohere = ["cohere >= 5.2.0"] groq = ["groq >= 0.4.1"] -hf-inference-endpoints = ["huggingface_hub >= 0.19.0"] +hf-inference-endpoints = ["huggingface_hub >= 0.22.0"] hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"] instructor = ["instructor >= 1.2.3"] litellm = ["litellm >= 1.30.0"] diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 16139b72c3..55df0f234a 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -14,8 +14,9 @@ import os import random +import sys import warnings -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from pydantic import ( Field, @@ -25,7 +26,7 @@ model_validator, validate_call, ) -from typing_extensions import override +from typing_extensions import Annotated, override from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput @@ -42,15 +43,13 @@ if TYPE_CHECKING: from huggingface_hub import AsyncInferenceClient - from openai import AsyncOpenAI from transformers import PreTrainedTokenizer class InferenceEndpointsLLM(AsyncLLM): """InferenceEndpoints LLM implementation running the async API client. - This LLM will internally use `huggingface_hub.AsyncInferenceClient` or `openai.AsyncOpenAI` - depending on the `use_openai_client` attribute. + This LLM will internally use `huggingface_hub.AsyncInferenceClient`. Attributes: model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which @@ -63,7 +62,9 @@ class InferenceEndpointsLLM(AsyncLLM): tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub. Defaults to `None`, but defining one is recommended to properly format the prompt. model_display_name: the model display name to use for the LLM. Defaults to `None`. - use_openai_client: whether to use the OpenAI client instead of the Hugging Face client. + structured_output: a dictionary containing the structured output configuration or + if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`. + Defaults to None. Icon: `:hugging:` @@ -114,6 +115,29 @@ class InferenceEndpointsLLM(AsyncLLM): output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]]) ``` + + Generate structured data: + + ```python + from pydantic import BaseModel + from distilabel.llms import InferenceEndpointsLLM + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct", + api_key="api.key", + structured_output={"format": "json", "schema": User.model_json_schema()} + ) + + llm.load() + + output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]]) + ``` """ model_id: Optional[str] = None @@ -137,7 +161,6 @@ class InferenceEndpointsLLM(AsyncLLM): tokenizer_id: Optional[str] = None model_display_name: Optional[str] = None - use_openai_client: bool = False structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( default=None, @@ -149,7 +172,7 @@ class InferenceEndpointsLLM(AsyncLLM): _model_name: Optional[str] = PrivateAttr(default=None) _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME) - _aclient: Optional[Union["AsyncInferenceClient", "AsyncOpenAI"]] = PrivateAttr(...) + _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) @model_validator(mode="after") # type: ignore def only_one_of_model_id_endpoint_name_or_base_url_provided( @@ -161,11 +184,19 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( if self.base_url and (self.model_id or self.endpoint_name): self._logger.warning( # type: ignore - f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`" - " is also provided, the `base_url` will either be ignored or overwritten with the one generated" - " from either of those args, for serverless or dedicated inference endpoints, respectively." + f"Since the `base_url={self.base_url}` is available and either one of `model_id`" + " or `endpoint_name` is also provided, the `base_url` will either be ignored" + " or overwritten with the one generated from either of those args, for serverless" + " or dedicated inference endpoints, respectively." ) + if ( + self.model_id + and self.tokenizer_id is None + and self.structured_output is not None + ): + self.tokenizer_id = self.model_id + if self.base_url and not (self.model_id or self.endpoint_name): return self @@ -176,19 +207,16 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( return self raise ValidationError( - "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too," - " it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name}," - f" and `base_url`={self.base_url}." + f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is" + f" provided too, it will be overwritten instead. Found `model_id`={self.model_id}," + f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}." ) def load(self) -> None: # noqa: C901 - """Loads the either the `AsyncInferenceClient` or the `AsyncOpenAI` client to benefit - from async requests, running the Hugging Face Inference Endpoint underneath via the - `/v1/chat/completions` endpoint, exposed for the models running on TGI using the - `text-generation` task. + """Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference + Endpoint. Raises: - ImportError: if the `openai` Python client is not installed. ImportError: if the `huggingface-hub` Python client is not installed. ValueError: if the model is not currently deployed or is not running the TGI framework. ImportError: if the `transformers` Python client is not installed. @@ -236,31 +264,16 @@ def load(self) -> None: # noqa: C901 ) if client.status in ["paused", "scaledToZero"]: client.resume().wait(timeout=300) - elif client.status in ["initializing"]: + elif client.status == "initializing": client.wait(timeout=300) self.base_url = client.url self._model_name = client.repository - if self.use_openai_client: - try: - from openai import AsyncOpenAI - except ImportError as ie: - raise ImportError( - "OpenAI Python client is not installed. Please install it using" - " `pip install openai`." - ) from ie - - self._aclient = AsyncOpenAI( - base_url=self.base_url, - api_key=self.api_key.get_secret_value(), - max_retries=6, - ) - else: - self._aclient = AsyncInferenceClient( - model=self.base_url, - token=self.api_key.get_secret_value(), - ) + self._aclient = AsyncInferenceClient( + model=self.base_url, + token=self.api_key.get_secret_value(), + ) if self.tokenizer_id: try: @@ -285,113 +298,51 @@ def model_name(self) -> Union[str, None]: # type: ignore or self.base_url ) - async def _openai_agenerate( - self, - input: "StandardInput", - max_new_tokens: int = 128, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - temperature: float = 1.0, - top_p: Optional[float] = None, - stop: Optional[Union[str, List[str]]] = None, - ) -> GenerateOutput: - """Generates completions for the given input using the OpenAI async client.""" - completion = await self._aclient.chat.completions.create( # type: ignore - messages=input, # type: ignore - model="tgi", - max_tokens=max_new_tokens, - n=1, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - stop=stop, - timeout=50, + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + + return self._tokenizer.apply_chat_template( # type: ignore + conversation=input, # type: ignore + tokenize=False, + add_generation_prompt=True, ) - if completion.choices[0].message.content is None: - self._logger.warning( # type: ignore - f"⚠️ Received no response using OpenAI client (model: '{self.model_name}')." - f" Finish reason was: {completion.choices[0].finish_reason}" - ) - return [completion.choices[0].message.content] - @validate_call - async def agenerate( # type: ignore - self, - input: FormattedInput, - max_new_tokens: int = 128, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repetition_penalty: Optional[float] = None, - temperature: float = 1.0, - do_sample: bool = False, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - stop_sequences: Optional[Union[str, List[str]]] = None, - return_full_text: bool = False, - seed: Optional[int] = None, - watermark: bool = False, - ) -> GenerateOutput: - """Generates completions for the given input using the OpenAI async client. + def _get_structured_output( + self, input: FormattedInput + ) -> Union[Dict[str, Any], None]: + """Gets the structured output (if any) for the given input. Args: input: a single input in chat format to generate responses for. - max_new_tokens: the maximum number of new tokens that the model will generate. - Defaults to `128`. - frequency_penalty: the repetition penalty to use for the generation. Defaults - to `0.0`. Only applies if `use_openai_client=True`. - presence_penalty: the presence penalty to use for the generation. Defaults to - `0.0`. Only applies if `use_openai_client=True`. - repetition_penalty: the repetition penalty to use for the generation. Defaults - to `None`. Only applies if `use_openai_client=False`. - temperature: the temperature to use for the generation. Defaults to `1.0`. - do_sample: whether to use sampling for the generation. Defaults to `False`. - Only applies if `use_openai_client=False`. - top_k: the top-k value to use for the generation. Defaults to `0.8`, since neither - `0.0` nor `1.0` are valid values in TGI. - top_p: the top-p value to use for the generation. Defaults to `1.0`. - typical_p: the typical-p value to use for the generation. Defaults to `0.5`. - stop_sequences: either a single string or a list of strings containing the sequences - to stop the generation at. Defaults to `None`, but will be set to the - `tokenizer.eos_token` if available. - return_full_text: whether to return the full text of the completion or just the - generated text. Defaults to `False`, meaning that only the generated text will be - returned. - seed: the seed to use for the generation. Defaults to `None`. - watermark: whether to add the watermark to the generated text. Defaults to `None`. Returns: - A list of lists of strings containing the generated responses for each input. + The structured output that will be passed as `grammer` to the inference endpoint + or `None` if not required. """ - if stop_sequences is not None: - if isinstance(stop_sequences, str): - stop_sequences = [stop_sequences] - if len(stop_sequences) > 4: - warnings.warn( - "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.", - UserWarning, - stacklevel=2, - ) - stop_sequences = stop_sequences[:4] - structured_output = None + + # Specific structured output per input if isinstance(input, tuple): input, structured_output = input structured_output = { - "type": structured_output["format"], - "value": structured_output["schema"], + "type": structured_output["format"], # type: ignore + "value": structured_output["schema"], # type: ignore } - # NOTE: `self.structured_output` applies to all the generations, while `structured_output` i.e. the - # value included within the tuple provided as `input` to this method, is intended to be different per - # each input, so those should not be used together. Meaning that it should be either provided at attribute - # level i.e. self, or via a column within each input i.e. row. + # Same structured output for all the inputs if structured_output is None and self.structured_output is not None: try: structured_output = { - "type": self.structured_output["format"], - "value": self.structured_output["schema"], + "type": self.structured_output["format"], # type: ignore + "value": self.structured_output["schema"], # type: ignore } except KeyError as e: raise ValueError( @@ -399,50 +350,241 @@ async def agenerate( # type: ignore "the `structured_output` attribute." ) from e - if self.use_openai_client: - return await self._openai_agenerate( - input=input, - max_new_tokens=max_new_tokens, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - stop=stop_sequences, - ) + return structured_output - if self._tokenizer is not None: - prompt = self._tokenizer.apply_chat_template( # type: ignore - conversation=input, # type: ignore - tokenize=False, - add_generation_prompt=True, - ) - else: - # TODO: should we apply a default chat template here instead? e.g. ChatML - prompt = "\n".join([message["content"] for message in input]) + async def _generate_with_text_generation( + self, + input: FormattedInput, + max_new_tokens: int = 128, + repetition_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + temperature: float = 1.0, + do_sample: bool = False, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + stop_sequences: Union[List[str], None] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + watermark: bool = False, + ) -> Union[str, None]: + structured_output = self._get_structured_output(input) completion = None try: completion = await self._aclient.text_generation( # type: ignore - prompt=prompt, # type: ignore + prompt=self.prepare_input(input), # type: ignore max_new_tokens=max_new_tokens, do_sample=do_sample, typical_p=typical_p, repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, temperature=temperature, top_p=top_p, top_k=top_k, stop_sequences=stop_sequences, return_full_text=return_full_text, + # NOTE: here to ensure that the cache is not used and a different response is + # generated every time + seed=seed or random.randint(0, sys.maxsize), watermark=watermark, grammar=structured_output, # type: ignore + ) + except Exception as e: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." + f" Finish reason was: {e}" + ) + return completion + + async def _generate_with_chat_completion( + self, + input: "StandardInput", + max_new_tokens: int = 128, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: float = 1.0, + tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_p: Optional[float] = None, + ) -> Union[str, None]: + message = None + try: + completion = await self._aclient.chat_completion( # type: ignore + messages=input, # type: ignore + max_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + presence_penalty=presence_penalty, # NOTE: here to ensure that the cache is not used and a different response is # generated every time - seed=seed or random.randint(0, 2147483647), + seed=seed or random.randint(0, sys.maxsize), + stop=stop_sequences, + temperature=temperature, + tool_choice=tool_choice, # type: ignore + tool_prompt=tool_prompt, + tools=tools, # type: ignore + top_p=top_p, ) + choice = completion.choices[0] + if (message := choice.message.content) is None: + self._logger.warning( # type: ignore + f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." + f" Finish reason was: {choice.finish_reason}" + ) except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) + return message + + def _check_stop_sequences( + self, + stop_sequences: Optional[Union[str, List[str]]] = None, + ) -> Union[List[str], None]: + """Checks that no more than 4 stop sequences are provided. + + Args: + stop_sequences: the stop sequences to be checked. + + Returns: + The stop sequences. + """ + if stop_sequences is not None: + if isinstance(stop_sequences, str): + stop_sequences = [stop_sequences] + if len(stop_sequences) > 4: + warnings.warn( + "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.", + UserWarning, + stacklevel=2, + ) + stop_sequences = stop_sequences[:4] + return stop_sequences - return [completion] + @validate_call + async def agenerate( # type: ignore + self, + input: FormattedInput, + max_new_tokens: int = 128, + frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, + logit_bias: Optional[List[float]] = None, + presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: float = 1.0, + tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_p: Optional[float] = None, + do_sample: bool = False, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + top_k: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + ) -> GenerateOutput: + """Generates completions for the given input using the async client. This method + uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`. + `chat_completion` method will be used only if no `tokenizer_id` has been specified. + Some arguments of this function are specific to the `text_generation` method, while + some others are specific to the `chat_completion` method. + + Args: + input: a single input in chat format to generate responses for. + max_new_tokens: the maximum number of new tokens that the model will generate. + Defaults to `128`. + frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize + new tokens based on their existing frequency in the text so far, decreasing + model's likelihood to repeat the same line verbatim. Defauls to `None`. + logit_bias: modify the likelihood of specified tokens appearing in the completion. + This argument is exclusive to the `chat_completion` method and will be used + only if `tokenizer_id` is `None`. + Defaults to `None`. + presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize + new tokens based on whether they appear in the text so far, increasing the + model likelihood to talk about new topics. This argument is exclusive to + the `chat_completion` method and will be used only if `tokenizer_id` is + `None`. Defauls to `None`. + seed: the seed to use for the generation. Defaults to `None`. + stop_sequences: either a single string or a list of strings containing the sequences + to stop the generation at. Defaults to `None`, but will be set to the + `tokenizer.eos_token` if available. + temperature: the temperature to use for the generation. Defaults to `1.0`. + tool_choice: the name of the tool the model should call. It can be a dictionary + like `{"function_name": "my_tool"}` or "auto". If not provided, then the + model won't use any tool. This argument is exclusive to the `chat_completion` + method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. + tool_prompt: A prompt to be appended before the tools. This argument is exclusive + to the `chat_completion` method and will be used only if `tokenizer_id` + is `None`. Defauls to `None`. + tools: a list of tools definitions that the LLM can use. + This argument is exclusive to the `chat_completion` method and will be used + only if `tokenizer_id` is `None`. Defaults to `None`. + top_p: the top-p value to use for the generation. Defaults to `1.0`. + do_sample: whether to use sampling for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` is not + `None`. Defaults to `False`. + repetition_penalty: the repetition penalty to use for the generation. This argument + is exclusive of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. + return_full_text: whether to return the full text of the completion or just + the generated text. Defaults to `False`, meaning that only the generated + text will be returned. This argument is exclusive of the `text_generation` + method and will be only used if `tokenizer_id` is not `None`. + top_k: the top-k value to use for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `0.8`, since neither `0.0` nor `1.0` are valid + values in TGI. + typical_p: the typical-p value to use for the generation. This argument is exclusive + of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. + watermark: whether to add the watermark to the generated text. This argument + is exclusive of the `text_generation` method and will be only used if `tokenizer_id` + is not `None`. Defaults to `None`. + + Returns: + A list of lists of strings containing the generated responses for each input. + """ + stop_sequences = self._check_stop_sequences(stop_sequences) + + if self.tokenizer_id is None: + return [ + await self._generate_with_chat_completion( + input=input, # type: ignore + max_new_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + presence_penalty=presence_penalty, + seed=seed, + stop_sequences=stop_sequences, + temperature=temperature, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, + top_p=top_p, + ) + ] + + return [ + await self._generate_with_text_generation( + input=input, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + typical_p=typical_p, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=stop_sequences, + return_full_text=return_full_text, + seed=seed, + watermark=watermark, + ) + ] diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 87a890a38c..b0e3adcbd1 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -12,21 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import random +from typing import Generator from unittest import mock -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import nest_asyncio import pytest from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM +from huggingface_hub import ( + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputMessage, + ChatCompletionOutputUsage, +) + + +@pytest.fixture(autouse=True) +def mock_hf_token_env_variable() -> Generator[None, None, None]: + with patch.dict(os.environ, {"HF_TOKEN": "hf_token"}): + yield @patch("huggingface_hub.AsyncInferenceClient") -@patch("openai.AsyncOpenAI") class TestInferenceEndpointsLLM: - def test_load_no_api_key( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + def test_tokenizer_id_set_if_model_id_and_structured_output( + self, mock_inference_client: MagicMock ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, + ) + + assert llm.tokenizer_id == llm.model_id + + def test_load_no_api_key(self, mock_inference_client: MagicMock) -> None: + del os.environ["HF_TOKEN"] + llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" ) @@ -40,22 +63,19 @@ def test_load_no_api_key( ): llm.load() - def test_load_with_cached_token( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock - ) -> None: - llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" - ) + def test_load_with_cached_token(self, mock_inference_client: MagicMock) -> None: + llm = InferenceEndpointsLLM(base_url="http://localhost:8000") # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to exist - with mock.patch("pathlib.Path.exists", return_value=True), mock.patch( - "builtins.open", new_callable=mock.mock_open, read_data="hf_token" - ): - # Should not raise any errors - llm.load() + with mock.patch("pathlib.Path.exists", return_value=True): + with mock.patch( + "builtins.open", new_callable=mock.mock_open, read_data="hf_token" + ): + # Should not raise any errors + llm.load() def test_serverless_inference_endpoints_llm( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral" @@ -65,7 +85,7 @@ def test_serverless_inference_endpoints_llm( assert llm.model_name == "distilabel-internal-testing/tiny-random-mistral" def test_dedicated_inference_endpoints_llm( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( endpoint_name="tiny-random-mistral", @@ -76,11 +96,12 @@ def test_dedicated_inference_endpoints_llm( assert llm.model_name == "tiny-random-mistral" def test_dedicated_inference_endpoints_llm_via_url( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( base_url="https://api-inference.huggingface.co/models/distilabel-internal-testing/tiny-random-mistral" ) + llm.load() assert isinstance(llm, InferenceEndpointsLLM) assert ( @@ -89,13 +110,14 @@ def test_dedicated_inference_endpoints_llm_via_url( ) @pytest.mark.asyncio - async def test_agenerate_via_inference_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + async def test_agenerate_with_text_generation( + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" + model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) - llm._aclient = mock_inference_client + llm.load() llm._aclient.text_generation = AsyncMock( return_value=" Aenean hendrerit aliquam velit. ..." @@ -111,23 +133,38 @@ async def test_agenerate_via_inference_client( ) == [" Aenean hendrerit aliquam velit. ..."] @pytest.mark.asyncio - async def test_agenerate_via_openai_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + async def test_agenerate_with_chat_completion( + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - use_openai_client=True, ) - llm._aclient = mock_openai_client - - mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + llm.load() + + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="length", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=" Aenean hendrerit aliquam velit. ...", + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) assert await llm.agenerate( input=[ - {"role": "system", "content": ""}, { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", @@ -136,48 +173,57 @@ async def test_agenerate_via_openai_client( ) == [" Aenean hendrerit aliquam velit. ..."] @pytest.mark.asyncio - async def test_generate_via_inference_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock + async def test_agenerate_with_chat_completion_fails( + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( - model_id="distilabel-internal-testing/tiny-random-mistral" + model_id="distilabel-internal-testing/tiny-random-mistral", ) - llm._aclient = mock_inference_client - - llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + llm.load() + + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) ) - nest_asyncio.apply() - - assert llm.generate( - inputs=[ - [ - {"role": "system", "content": ""}, - { - "role": "user", - "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - }, - ] + assert await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, ] - ) == [(" Aenean hendrerit aliquam velit. ...",)] + ) == [None] @pytest.mark.asyncio - async def test_generate_via_openai_client( - self, mock_inference_client: MagicMock, mock_openai_client: MagicMock - ) -> None: + async def test_generate(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - use_openai_client=True, + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) - llm._aclient = mock_openai_client + llm.load() - mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + llm._aclient.text_generation = AsyncMock( + return_value=" Aenean hendrerit aliquam velit. ..." ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) - ... nest_asyncio.apply() assert llm.generate( @@ -194,13 +240,14 @@ async def test_generate_via_openai_client( @pytest.mark.asyncio async def test_agenerate_with_structured_output( - self, mock_inference_client: MagicMock, _: MagicMock + self, mock_inference_client: MagicMock ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"}, ) - llm._aclient = mock_inference_client + llm.load() llm._aclient.text_generation = AsyncMock( return_value=" Aenean hendrerit aliquam velit. ..." @@ -220,29 +267,27 @@ async def test_agenerate_with_structured_output( ) == [" Aenean hendrerit aliquam velit. ..."] kwargs = { - "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + "prompt": "[INST] Lorem ipsum dolor sit amet, consectetur adipiscing elit. [/INST]", "max_new_tokens": 128, "do_sample": False, "typical_p": None, "repetition_penalty": None, + "frequency_penalty": None, "temperature": 1.0, "top_p": None, "top_k": None, "stop_sequences": None, "return_full_text": False, + "seed": 2053695854357871005, # pre-computed random value with `random.seed(42)` "watermark": False, "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, - "seed": 478163327, # pre-computed random value with `random.seed(42)` } - mock_inference_client.text_generation.assert_called_with(**kwargs) + llm._aclient.text_generation.assert_called_with(**kwargs) # type: ignore - def test_serialization( - self, - mock_inference_client: MagicMock, - mock_openai_client: MagicMock, - ) -> None: + def test_serialization(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) _dump = { @@ -250,11 +295,10 @@ def test_serialization( "endpoint_name": None, "endpoint_namespace": None, "base_url": None, - "tokenizer_id": None, + "tokenizer_id": "distilabel-internal-testing/tiny-random-mistral", "generation_kwargs": {}, "structured_output": None, "model_display_name": None, - "use_openai_client": False, "type_info": { "module": "distilabel.llms.huggingface.inference_endpoints", "name": "InferenceEndpointsLLM",