diff --git a/evals/registry/solvers/together.yaml b/evals/registry/solvers/together.yaml
new file mode 100644
index 0000000000..a8ba40c285
--- /dev/null
+++ b/evals/registry/solvers/together.yaml
@@ -0,0 +1,92 @@
+# --- Direct Solvers ---
+generation/direct/llama-2-13b-chat:
+  class: evals.solvers.together_solver:TogetherSolver
+  args:
+    completion_fn_options:
+      model: meta-llama/Llama-2-13b-chat-hf
+      extra_options:
+        temperature: 1
+        max_tokens: 512
+    postprocessors: &postprocessors
+      - evals.solvers.postprocessors.postprocessors:Strip
+
+generation/direct/llama-2-70b-chat:
+  class: evals.solvers.together_solver:TogetherSolver
+  args:
+    completion_fn_options:
+      model: meta-llama/Llama-2-70b-chat-hf
+      extra_options:
+        temperature: 1
+        max_tokens: 512
+    postprocessors: *postprocessors
+
+generation/direct/mixtral-8x7b-instruct:
+  class: evals.solvers.together_solver:TogetherSolver
+  args:
+    completion_fn_options:
+      model: mistralai/Mixtral-8x7B-Instruct-v0.1
+      extra_options:
+        temperature: 1
+        max_tokens: 512
+    postprocessors: *postprocessors
+# --- COT Solvers ---
+
+generation/cot/llama-2-13b-chat:
+  class: evals.solvers.nested.cot_solver:CoTSolver
+  args:
+    cot_solver:
+      class: evals.solvers.together_solver:TogetherSolver
+      args:
+        completion_fn_options:
+          model: meta-llama/Llama-2-13b-chat-hf
+          extra_options:
+            temperature: 1
+            max_tokens: 512
+    extract_solver:
+      class: evals.solvers.together_solver:TogetherSolver
+      args:
+        completion_fn_options:
+          model: meta-llama/Llama-2-13b-chat-hf
+          extra_options:
+            temperature: 1
+            max_tokens: 512
+
+generation/cot/llama-2-70b-chat:
+  class: evals.solvers.nested.cot_solver:CoTSolver
+  args:
+    cot_solver:
+      class: evals.solvers.together_solver:TogetherSolver
+      args:
+        completion_fn_options:
+          model: meta-llama/Llama-2-70b-chat-hf
+          extra_options:
+            temperature: 1
+            max_tokens: 512
+    extract_solver:
+      class: evals.solvers.together_solver:TogetherSolver
+      args:
+        completion_fn_options:
+          model: meta-llama/Llama-2-70b-chat-hf
+          extra_options:
+            temperature: 1
+            max_tokens: 512
+
+generation/cot/mixtral-8x7b-instruct:
+  class: evals.solvers.nested.cot_solver:CoTSolver
+  args:
+    cot_solver:
+      class: evals.solvers.together_solver:TogetherSolver
+      args:
+        completion_fn_options:
+          model: mistralai/Mixtral-8x7B-Instruct-v0.1
+          extra_options:
+            temperature: 1
+            max_tokens: 512
+    extract_solver:
+      class: evals.solvers.together_solver:TogetherSolver
+      args:
+        completion_fn_options:
+          model: mistralai/Mixtral-8x7B-Instruct-v0.1
+          extra_options:
+            temperature: 1
+            max_tokens: 512
\ No newline at end of file
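For reference, each registry entry above maps directly onto the solver's constructor arguments. A minimal sketch of the equivalent direct construction follows; it is an illustration only, and it assumes that TOGETHER_API_KEY is set in the environment and that the base Solver accepts the same "module:Class" postprocessor strings used in the YAML:

    from evals.solvers.together_solver import TogetherSolver

    # Mirrors generation/direct/mixtral-8x7b-instruct above.
    solver = TogetherSolver(
        completion_fn_options={
            "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "extra_options": {"temperature": 1, "max_tokens": 512},
        },
        postprocessors=["evals.solvers.postprocessors.postprocessors:Strip"],
    )
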
+ """ def __init__( self, @@ -33,6 +39,7 @@ def __init__( registry: Any = None, ): super().__init__(postprocessors=postprocessors) + self.valid_answers = valid_answers self.completion_fn_options = completion_fn_options # Additional options for base model self.fixed_start = fixed_start @@ -43,57 +50,174 @@ def __init__( raise ValueError("OpenAISolver requires a model to be specified.") model = completion_fn_options["model"] - # Infer suitable CompletionFn class from the model name - if is_chat_model(model): - completion_fn_cls = OpenAIChatCompletionFn - if self.fixed_start is not None or self.continue_last_assistant_msg: - raise ValueError( - "OpenAISolver does not support fixed_start or continue_last_assistant_msg with chat models." - ) - else: - if self.fixed_start is not None and self.continue_last_assistant_msg: - raise ValueError( - "OpenAISolver does not support both fixed_start and continue_last_assistant_msg being used." - ) + completion_fn_cls = self._get_completion_fn_cls(model) - completion_fn_cls = OpenAICompletionFn - - # If valid answers were provided, apply logit bias to those tokens - if valid_answers is not None and len(valid_answers) > 0: - self.completion_fn_options["extra_options"]["logit_bias"] = self._make_logit_bias( - valid_answers, - model, - ) + self._preprocess_completion_fn_options() # Create the completion function self.completion_fn = completion_fn_cls( + api_base=self._api_base, + api_key=self._api_key, **self.completion_fn_options, ) @property def model(self) -> str: + """ + Get model name from completion function, e.g. "gpt-3.5-turbo" + This may not always include the full model version, e.g. "gpt-3.5-turbo-0613" + so use `self.model_version` if you need the exact snapshot. + """ return self.completion_fn.model + def name(self) -> str: + return self.completion_fn.model + + @property + def model_version(self) -> Union[str, dict]: + """ + Makes dummy API request to get exact model version from the API + e.g. "gpt-3.5-turbo-0613" + """ + dummy_task_state = TaskState("", "") + solver_result = self(dummy_task_state, **{"max_tokens": 1}) + raw_data = solver_result._metadata["raw_completion_result"].raw_data + return raw_data.model + + def _is_chat_model(self, model: str) -> bool: + """ + Checks in the registry if the model is a chat model. + Implemented as a method to allow for overriding in subclasses + (e.g. TogetherSolver, which uses a different registry of chat models) + """ + return is_chat_model(model) + @property - def is_completion_model(self) -> bool: - return not is_chat_model(self.model) + def _completion_exception(self) -> Exception: + """ + Returns the exception to handle when the completion function fails + via self._handle_completion_exception + """ + return BadRequestError - def _make_logit_bias(self, valid_answers: list[str], model: str) -> dict[int, float]: - if model == "code-davinci-002": - logging.info( - f"Attempting to use logit bias with model {model}, which does not support logit bias." - ) + @property + def _api_base(self) -> Optional[str]: + """The base URL for the API""" + # by default, None, which points to the default API Base which is the OpenAI API + return None - enc = tiktoken.encoding_for_model(model) - token_ids = [] - for answer in valid_answers: - encoded_answer = enc.encode(answer) - if len(encoded_answer) > 1: - raise ValueError( - f"Answer {answer} was encoded to {encoded_answer}, but we expected a single token." 
+ @property + def _api_key(self) -> Optional[str]: + """The API key to use for the API""" + # by default, None, which points to the default API Key which is "OPENAI_API_KEY" + return None + + def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: + raw_msgs = [ + {"role": "system", "content": task_state.task_description}, + ] + [msg.to_dict() for msg in task_state.messages] + + precheck_outcome = self._perform_prechecks(raw_msgs) + if precheck_outcome is not None: + return precheck_outcome + + msgs = self._process_msgs(raw_msgs) + + try: + if self._is_chat_model(self.model): + completion_result = self.completion_fn(prompt=msgs, **kwargs) + + completion_output = completion_result.get_completions()[0] + + # Chat model output is already parsed, just return it + solver_result = SolverResult( + completion_output, raw_completion_result=completion_result ) - token_ids.append(encoded_answer[0]) - return {token_id: 100 for token_id in token_ids} + else: + # Manually render the prompt for completion models so that we can + # implement things like custom render formats and/or fixed_start + prompt = self._render_completion_prompt(msgs) + + stop_sequences = self._get_msg_separators() + if len(stop_sequences) > 4: + logging.warn("Using more than 4 stop sequences is unsupported") + completion_result = self.completion_fn(prompt=prompt, stop=stop_sequences, **kwargs) + + completion_output = completion_result.get_completions()[0] + + # Completion model output needs to be parsed to remove role prefixes + solver_result = SolverResult( + self._parse_completion_response(completion_output), + raw_output=completion_output, + raw_completion_result=completion_result, + ) + except self._completion_exception as e: + solver_result = self._handle_completion_exception(e) + + return solver_result + + def _perform_prechecks(self, msgs: list[dict[str, str]]) -> Optional[SolverResult]: + """ + Check if the prompt exceeds the context length before querying the + API to avoid it contributing to the tokens per minute (TPM) limit + + If `None` is returned, the prompt is within the context length. + """ + enc = tiktoken.encoding_for_model(self.model) + ctx_len = n_ctx_from_model_name(self.model) + n_tokens = 0 + + for msg in msgs: + tokens = enc.encode(msg["content"]) + n_tokens += len(tokens) + + if ctx_len is not None and n_tokens >= ctx_len: + return SolverResult( + output=f"Request too large for {self.model}. Context length: {ctx_len} tokens. Requested: {n_tokens} tokens.", + ) + + return None + + def _process_msgs(self, raw_msgs: list[dict[str, str]]) -> list[dict[str, str]]: + """ + Perform any message processing before querying the API + e.g. converting 'system' roles to 'user' roles + """ + # By default, no message processing is performed, but subclasses can override this + return raw_msgs + + def _handle_completion_exception(self, e: Exception) -> SolverResult: + """ + Handles any expected exceptions from the completion function: + - context_length_exceeded: The prompt exceeds the context length + - too many messages: The prompt has too many messages + + Raises any other exceptions + """ + if ( + e.code == "context_length_exceeded" + or "Please reduce your prompt; or completion length" + in e.message # For context length errors where code is not specified. 
+ ): + logging.warn( + f"OpenAI API context length exceeded, using error message as solver response: {e.message}" + ) + solver_result = SolverResult( + e.message, + error=e.body, + ) + elif "'$.messages' is too long" in e.message: # If we have too many messages + logging.warn( + f"Exceeded maximum chat messages on OpenAI API, using error message as solver response: {e.message}" + ) + solver_result = SolverResult( + e.message, + error=e.body, + ) + else: + raise e + + return solver_result def _render_completion_prompt(self, msgs: list[dict[str, str]]) -> str: # Render messages as a chat dialogue in plaintext (also postfixes "Assistant: " to tee up the model) @@ -129,94 +253,45 @@ def _get_msg_separators(self) -> list[str]: """ return [v.strip() for v in self.role_to_prefix.values() if v.strip() != ""] - def _solve( - self, - task_state: TaskState, - **kwargs, - ) -> SolverResult: - msgs = [ - {"role": "system", "content": task_state.task_description}, - ] + [msg.to_dict() for msg in task_state.messages] - - # Check if the prompt exceeds the context length before querying the - # API to avoid it contributing to the tokens per minute (TPM) limit - enc = tiktoken.encoding_for_model(self.model) - ctx_len = n_ctx_from_model_name(self.model) - n_tokens = 0 - - for msg in msgs: - tokens = enc.encode(msg["content"]) - n_tokens += len(tokens) - - if ctx_len is not None and n_tokens >= ctx_len: - return SolverResult( - output=f"Request too large for {self.model}. Context length: {ctx_len} tokens. Requested: {n_tokens} tokens.", - ) - - try: - if self.is_completion_model: - # Manually render the prompt for completion models so that we can - # implement things like custom render formats and/or fixed_start - prompt = self._render_completion_prompt(msgs) - - stop_sequences = self._get_msg_separators() - if len(stop_sequences) > 4: - logging.warn("Using more than 4 stop sequences is unsupported") - completion_result = self.completion_fn(prompt=prompt, stop=stop_sequences, **kwargs) - - completion_output = completion_result.get_completions()[0] - - # Completion model output needs to be parsed to remove role prefixes - solver_result = SolverResult( - self._parse_completion_response(completion_output), - raw_output=completion_output, - raw_completion_result=completion_result, - ) - else: - completion_result = self.completion_fn(prompt=msgs, **kwargs) - - completion_output = completion_result.get_completions()[0] - - # Chat model output is already parsed, just return it - solver_result = SolverResult( - completion_output, raw_completion_result=completion_result - ) - except BadRequestError as e: - if ( - e.code == "context_length_exceeded" - or "Please reduce your prompt; or completion length" - in e.message # For context length errors where code is not specified. - ): - logging.warn( - f"OpenAI API context length exceeded, using error message as solver response: {e.message}" - ) - solver_result = SolverResult( - e.message, - error=e.body, - ) - elif "'$.messages' is too long" in e.message: # If we have too many messages - logging.warn( - f"Exceeded maximum chat messages on OpenAI API, using error message as solver response: {e.message}" + def _get_completion_fn_cls(self, model: str) -> Any: + # Infer suitable CompletionFn class from the model name + if self._is_chat_model(model): + completion_fn_cls = OpenAIChatCompletionFn + if self.fixed_start is not None or self.continue_last_assistant_msg: + raise ValueError( + "OpenAISolver does not support fixed_start or continue_last_assistant_msg with chat models." 
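To make the logit-bias path above concrete: `_make_logit_bias` maps each single-token valid answer to a +100 bias, which then travels through `extra_options["logit_bias"]`. A rough sketch of the same computation in isolation (the answer strings are hypothetical; the assert stands in for the ValueError raised in the method above):

    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    valid_answers = ["A", "B"]  # hypothetical single-token answers
    token_ids = []
    for answer in valid_answers:
        encoded = enc.encode(answer)
        assert len(encoded) == 1, f"{answer!r} must encode to a single token"
        token_ids.append(encoded[0])
    # Ends up in completion_fn_options["extra_options"]["logit_bias"]
    logit_bias = {token_id: 100 for token_id in token_ids}
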
diff --git a/evals/solvers/together_solver.py b/evals/solvers/together_solver.py
new file mode 100644
index 0000000000..8756d17cf5
--- /dev/null
+++ b/evals/solvers/together_solver.py
@@ -0,0 +1,149 @@
+import os
+import copy
+import logging
+from typing import Optional
+
+from openai import PermissionDeniedError
+
+from evals.solvers.solver import SolverResult
+from evals.solvers.openai_solver import OpenAISolver
+
+
+def is_chat_model(model: str) -> bool:
+    # NOTE: this is just as brittle as evals.registry.is_chat_model
+    # that we use for OpenAI models
+    if model in {
+        "meta-llama/Llama-2-13b-chat-hf",
+        "meta-llama/Llama-2-70b-chat-hf",
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    }:
+        return True
+    elif model in {}:
+        return False
+    else:
+        raise NotImplementedError(
+            f"Model {model} not currently supported by TogetherSolver"
+        )
+
+
+class TogetherSolver(OpenAISolver):
+    """
+    A solver class for the Together API via the OpenAI python SDK completion functions.
+    It leverages the OpenAISolver class, with some overrides.
+
+    Specifically we override:
+    - `_api_base` to point to the Together API
+    - `_api_key` to use the TOGETHER_API_KEY environment variable
+    - `_is_chat_model` to use a different dictionary of supported chat models
+    - `_preprocess_completion_fn_options` to skip completion fn options preprocessing
+    - `_perform_prechecks` to skip prechecks before calling the API
+    - `_process_msgs` to convert message roles to comply with the Together API
+    - `_completion_exception` to use the Together API's error code for context length
+    - `_handle_completion_exception` to handle Together API errors differently
+
+    Additionally, the `valid_answers` parameter is not supported by the Together API
+    """
+
+    def __init__(self, merge_adjacent_msgs: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.merge_adjacent_msgs = merge_adjacent_msgs
+        if self.valid_answers is not None:
+            raise NotImplementedError("`valid_answers` not supported by TogetherSolver")
+
+    @property
+    def _api_base(self) -> Optional[str]:
+        """The base URL for the API"""
+        return "https://api.together.xyz/v1"
+
+    @property
+    def _api_key(self) -> Optional[str]:
+        """The API key to use for the API"""
+        return os.environ.get("TOGETHER_API_KEY")
+
+    @property
+    def _completion_exception(self) -> Exception:
+        """
+        Overrides OpenAISolver implementation;
+        Together API uses a different error code to signal context length issues
+        """
+        return PermissionDeniedError
+
+    def _is_chat_model(self, model: str) -> bool:
+        """
+        Overrides OpenAISolver implementation;
+        need to use a different dictionary of chat models
+        """
+        return is_chat_model(model)
+
+    def _preprocess_completion_fn_options(self) -> dict:
+        """
+        Overrides OpenAISolver implementation; here we do not perform any completion fn
+        options preprocessing since the TogetherSolver does not support the
+        `valid_answers` parameter
+        """
+        pass
+
+    def _perform_prechecks(self, msgs: list[dict[str, str]]) -> Optional[SolverResult]:
+        """
+        Overrides OpenAISolver implementation; here we do not perform any prechecks
+        since the TogetherSolver does not support context length checks due to the lack
+        of a tokenizer.
+        """
+        return None
+
+    def _process_msgs(self, msgs: list[dict[str, str]]) -> list[dict[str, str]]:
+        """
+        Many open-source models, like Llama-2 and Mixtral, expect a stricter message
+        format than we typically send to OpenAI models. In particular:
+        - there should only be a single system prompt, at the start
+        - there should be at least one user prompt
+        - after an optional system prompt, the messages should alternate between
+          user and assistant messages.
+        """
+        msgs = copy.deepcopy(msgs)
+
+        # if there is only a system message, turn it to a user message
+        if len(msgs) == 1 and msgs[0]["role"] == "system":
+            return [{"role": "user", "content": msgs[0]["content"]}]
+
+        # convert all system messages except a possible first one to user messages
+        for i, msg in enumerate(msgs):
+            if msg["role"] == "system" and i > 0:
+                msg["role"] = "user"
+
+        # if the first message is a system message and the second one is an assistant message,
+        # this implies that we previously converted the initial system message to a user message,
+        # so we should convert the initial system message to a user message again for consistency
+        # NOTE: this looks like it'd fail on length 1 messages, but that's handled by the first if
+        # combined with the first statement of this if and lazy evaluation
+        if msgs[0]["role"] == "system" and msgs[1]["role"] == "assistant":
+            msgs[0]["role"] = "user"
+
+        # before returning, we optionally merge all adjacent messages from the same role
+        if self.merge_adjacent_msgs:
+            merged_msgs = []
+            for msg in msgs:
+                if len(merged_msgs) > 0 and merged_msgs[-1]["role"] == msg["role"]:
+                    merged_msgs[-1]["content"] += "\n\n" + msg["content"]
+                else:
+                    merged_msgs.append(msg)
+            msgs = merged_msgs
+        return msgs
+
+    def _handle_completion_exception(self, e: Exception) -> SolverResult:
+        """
+        Overrides OpenAISolver implementation; the Together API is less granular
+        and its errors are parsed differently.
+        """
+        if e.type == "invalid_request_error":
+            logging.warn(
+                f"Together API context length exceeded, using error message as solver response: {e.message}"
+            )
+            solver_result = SolverResult(
+                e.message,
+                error=e.body,
+            )
+        else:
+            raise e
+
+        return solver_result
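The role conversion in `_process_msgs` is easiest to see on a concrete conversation. A small sketch mirroring the unit tests below (no API call is made; `_process_msgs` only rewrites the message list, and the message contents here are invented for illustration):

    from evals.solvers.together_solver import TogetherSolver

    solver = TogetherSolver(
        merge_adjacent_msgs=True,
        completion_fn_options={"model": "meta-llama/Llama-2-70b-chat-hf"},
    )

    msgs = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Hi!"},
        {"role": "system", "content": "Answer in one word."},  # trailing system message
    ]
    # The trailing system message becomes a user message, then the two adjacent
    # user messages are merged:
    # [{'role': 'system', 'content': 'You are a terse assistant.'},
    #  {'role': 'user', 'content': 'Hi!\n\nAnswer in one word.'}]
    print(solver._process_msgs(msgs))
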
+ """ + msgs = copy.deepcopy(msgs) + + # if there is only a system message, turn it to a user message + if len(msgs) == 1 and msgs[0]["role"] == "system": + return [{"role": "user", "content": msgs[0]["content"]}] + + # convert all system messages except a possible first one to user messages + for i, msg in enumerate(msgs): + if msg["role"] == "system" and i > 0: + msg["role"] = "user" + + # if the first message is a system message and the second one is an assistant message, + # this implies that we previously converted the initial system message to a user message, + # so we should convert the initial system message to a user message again for consistency + # NOTE: this looks like it'd fail on length 1 messages, but that's handled by the first if + # combined with the first statement of this if and lazy evaluation + if msgs[0]["role"] == "system" and msgs[1]["role"] == "assistant": + msgs[0]["role"] = "user" + + # before returning, we optionally merge all adjacent messages from the same role + if self.merge_adjacent_msgs: + merged_msgs = [] + for msg in msgs: + if len(merged_msgs) > 0 and merged_msgs[-1]["role"] == msg["role"]: + merged_msgs[-1]["content"] += "\n\n" + msg["content"] + else: + merged_msgs.append(msg) + msgs = merged_msgs + return msgs + + def _handle_completion_exception(self, e: Exception) -> SolverResult: + """ + Overrides OpenAISolver implementation; TogetherSolver is a bit less granular + and the errors are parsed differently. + """ + if e.type == "invalid_request_error": + logging.warn( + f"Together API context length exceeded, using error message as solver response: {e.message}" + ) + solver_result = SolverResult( + e.message, + error=e.body, + ) + else: + raise e + + return solver_result diff --git a/evals/solvers/together_solver_test.py b/evals/solvers/together_solver_test.py new file mode 100644 index 0000000000..d98de8f68f --- /dev/null +++ b/evals/solvers/together_solver_test.py @@ -0,0 +1,117 @@ +import pytest + +from evals.solvers.together_solver import TogetherSolver + + +@pytest.fixture +def llama_solver(): + solver = TogetherSolver( + completion_fn_options={ + "model": "meta-llama/Llama-2-13b-chat-hf", + }, + ) + return solver + + +@pytest.fixture +def llama_solver_merge(): + solver = TogetherSolver( + merge_adjacent_msgs=True, + completion_fn_options={ + "model": "meta-llama/Llama-2-13b-chat-hf", + }, + ) + return solver + + +def test_single_system_msg(llama_solver): + in_msgs = [ + {"role": "system", "content": "Hello"}, + ] + out_msgs = [ + {"role": "user", "content": "Hello"}, + ] + assert llama_solver._process_msgs(in_msgs) == out_msgs + + +def test_system_assistant_msgs(llama_solver): + in_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how are ya?"}, + ] + out_msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how are ya?"}, + ] + assert llama_solver._process_msgs(in_msgs) == out_msgs + + +def test_system_user_msg(llama_solver): + in_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi, how are ya?"}, + ] + out_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi, how are ya?"}, + ] + assert llama_solver._process_msgs(in_msgs) == out_msgs + + +def test_final_system_msg(llama_solver): + in_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi, how are ya?"}, + {"role": "system", "content": "Good, you?"}, + ] + out_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": 
"user", "content": "Hi, how are ya?"}, + {"role": "user", "content": "Good, you?"}, + ] + assert llama_solver._process_msgs(in_msgs) == out_msgs + + +def test_combined(llama_solver): + in_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how are ya?"}, + {"role": "system", "content": "Good, you?"}, + ] + out_msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how are ya?"}, + {"role": "user", "content": "Good, you?"}, + ] + assert llama_solver._process_msgs(in_msgs) == out_msgs + + +def test_merge(llama_solver_merge): + in_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi, how are ya?"}, + {"role": "user", "content": "Good, you?"}, + ] + out_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi, how are ya?\n\nGood, you?"}, + ] + assert llama_solver_merge._process_msgs(in_msgs) == out_msgs + + +def test_advanced_merge(llama_solver_merge): + in_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi, how are ya?"}, + {"role": "user", "content": "Good, you?"}, + {"role": "assistant", "content": "Message 1"}, + {"role": "assistant", "content": "Message 2"}, + {"role": "user", "content": "Message 3"}, + ] + out_msgs = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi, how are ya?\n\nGood, you?"}, + {"role": "assistant", "content": "Message 1\n\nMessage 2"}, + {"role": "user", "content": "Message 3"}, + ] + assert llama_solver_merge._process_msgs(in_msgs) == out_msgs