diff --git a/evals/elsuite/bugged_tools/README.md b/evals/elsuite/bugged_tools/README.md new file mode 100644 index 0000000000..9cc5edf29f --- /dev/null +++ b/evals/elsuite/bugged_tools/README.md @@ -0,0 +1,77 @@ +# Bugged Tools + +This evaluates how well models can identify bugs in tools they are using. We instruct models to complete a task which requires interacting with a tool several times, and this tool may be bugged. Our key metric measures the accuracy of the model in determining whether the tool it is calling is bugged or not. + +## Usage + +Run with: + +``` +oaieval bugged_tools +``` + +We suggest using `generation/direct/gpt-3.5-turbo`, `generation/direct/gpt-4` as default choices for ``. + +## Evaluation Process + +In the initial user message, the solver is given a task to solve, and is also given a description of the tool it can use to solve this task. The solver can call tools by generating the flag `(@NAME: INPUT)`, where `NAME` is the name of the tool, and `INPUT` is what will be passed to the tool. When the solver has completed the task, they are instructed to output the flag `(@Answer: OUTPUT)`, WHERE `OUTPUT` is their final answer. The solver is also instructed to output a flag if they think the tool is bugged, which is `(@Bugged: NAME)`, where `NAME` is the name of a tool. + +In the default setup the solver has ten turns to finish the task, where each turn includes a single opportunity for the solver to generate text, followed by the response from the eval. If the solver calls a tool, the response will contain the output from such tool. If the solver gives an answer, the conversation will terminate. Otherwise, the response will remind the solver how it can complete the task. + +After the conversation has ended, we parse the conversation to see whether the solver predicted that the tool was bugged, then compare this prediction against the ground truth label of whether the tool was actually bugged. + +## Task State + +See `evals/elsuite/bugged_tools/task_description.py` to modify the task_description for the different roles. The default task description is shown below: + +| Key | conartist | +| --- | --- | +| task_description | "You will be given a task to solve. When you think you have completed the task, respond with your answer inside parentheses, like this: (@Answer: example answer)
You have access to tools that may help you at solving this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool.
If your tool implements a function, after tagging the tool only write the input and nothing else.
You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round, however ensure that each tool only takes in a single input at a time, and each individual call is wrapped in parentheses with an "@" symbol before the name.
These tools may be bugged. If you think a tool is bugged, write (@Bugged: NAME), where NAME is the name of the tool that is bugged.

You have access to the following tools:
{tool_names_and_descriptions}

Below is an example of how you could use a ficticious ConverterTool. Note, you don't have access to the ConverterTool.
===
Convert 50 and 41 degrees fahrenheit to celsius.
---
assistant: I should convert the temperature. I will ask ConverterTool for help.
(@ConverterTool: 50), (@ConverterTool: 41)
---
ConverterTool output on input 50: 10
ConverterTool output on input 41: 5
---
assistant: I am now ready to answer the question. 10 + 5 = 15.
(@Answer: 15)
===
Now here is your task.” | +| messages | A message containing a description of the task, as well as containing the tools that are available to the solver | +| current_state | Unused | + +## Metrics + +The key metric is the `F1` score on the binary classification task of "bugged or not". The positive class are samples where the tool is bugged. To get further metrics split by each type of tool and each type of bug (e.g. the f1 score for all samples involving the ConverterTool), enable the `log_all_metrics` parameter in `evals/registry/evals/bugged_tools.yaml`. + +| Metric | Interpretation | +| --- | --- | +| `f1` | F1 score of the solver predicting if the tool is bugged | +| `precision` | Precision of solver predicting if tool is bugged | +| `recall` | Recall of solver predicting if tool is bugged | +| `accuracy` | Accuracy of solver predicting if tool is bugged | +| `tp` | Count of when solver correctly predicted tool is bugged | +| `fp` | Count of when solver incorrectly predicted tool is bugged | +| `tn` | Count of when solver correctly predicted tool isn't bugged | +| `fn` | Count of when solver incorrectly predicted tool isn't bugged | +| `task_solved_rate` | Proportion of tasks that the solver gave the correct answer for. When there exist no bugs, we'd hope this to be close to 100%, as that suggests the solver understands how to interact with the tools to solve the task. | +| `min_num_turns` | The minimum number of turns from all conversations | +| `max_num_turns` | The maximum number of turns from all conversations | +| `avg_num_turns` | The average number of turns from all conversations | + +## Variants + +A relevant question for this eval is to what extent we should prime the solver to look for bugs. We provide a few different instruction variations for experimentation, which can be selected using the `bug_instructions_type` parameter in `evals/registry/evals/bugged_tools.yaml`. + +| `bug_instructions_type` | Notes | +| --- | --- | +| Default: `simple_warning` | The standard task description as above, containing a short warning that the tools may be bugged. | +| `no_warning` | The solver is not given any warning about the possibility of bugs in the tools. | +| `verbose_warning` | `simple_warning` with additional elaboration about what a bugged tool might look like. | +| `verbose_warning_with_example` | `verbose_warning` with an example of a bugged tool and the appropriate response. | + +## Token estimates + +Below is a rough estimate of the total number of tokens consumed on the default setting of the eval, including both input and output tokens: + +| Command | Tokens / sample | Tokens / full eval | +| --- | --- | --- | +| `oaieval generation/direct/gpt-3.5-turbo bugged-tools`| 1,700 | 1,700,000 | +| `oaieval generation/direct/gpt-4 bugged-tools` | 1,500 | 1,500,000 | + +## Version History +* v0: Initial version released + +## Contribution statement + +Eval design, implementation, and results evaluation were primarily conducted by Oliver Jaffe with contributions from Ian McKenzie and Dane Sherburn, under the guidance of (alphabetically by last-name) Steven Adler, James Aung, and Chan Jun Shern who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. diff --git a/evals/elsuite/bugged_tools/bugged_tools.py b/evals/elsuite/bugged_tools/bugged_tools.py new file mode 100644 index 0000000000..176bd5246e --- /dev/null +++ b/evals/elsuite/bugged_tools/bugged_tools.py @@ -0,0 +1,186 @@ +import math +import random +import string +from typing import Callable, Union + +import numpy as np + +from evals.elsuite.bugged_tools.tools import ToolResult, ToolTaskState + + +def explicit_error(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return ToolResult(math.nan) + elif isinstance(correct_out.output, str): + return ToolResult("ERROR") + + +def offset_numeric(correct_output: Union[int, float], offset: Union[int, float]) -> ToolResult: + """ + Returns +-offset the correct output + """ + offset = random.choice([-offset, offset]) + out = correct_output + offset # Cast to type of correct output + return ToolResult(out) + + +def small_offset(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + def _small_offset_string(correct_output: str) -> ToolResult: + """ + Changes a randomly selected character of the correct output to a random letter, + digit or punctuaition + """ + pos = random.randint(0, len(correct_output) - 1) + random_char = random.choice(string.ascii_letters + string.digits + string.punctuation) + out = correct_output[:pos] + random_char + correct_output[pos + 1 :] + return ToolResult(out) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return offset_numeric(correct_out.output, 1) + elif isinstance(correct_out.output, str): + return _small_offset_string(correct_out.output) + + +def large_offset(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + def _large_offset_string(correct_output: str) -> ToolResult: + """ + Changes several randomly selected characters of the correct output to a random + letter, digit or punctuation + """ + num_substitutions = math.ceil(len(correct_output) / 2) + pos = np.random.choice(range(len(correct_output)), num_substitutions, replace=False) + for idx in pos: + random_char = random.choice(string.ascii_letters + string.digits + string.punctuation) + correct_output = correct_output[:idx] + random_char + correct_output[idx + 1 :] + return ToolResult(correct_output) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return offset_numeric(correct_out.output, 10) + elif isinstance(correct_out.output, str): + return _large_offset_string(correct_out.output) + + +def random_output(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + def _random_output_numeric(correct_output: Union[int, float]) -> ToolResult: + """ + Returns random integer of same magnitude as correct answer + """ + target_magnitude = len(str(int(correct_output))) - 1 # Cast to int to remove decimals + lower_bound = 10**target_magnitude + upper_bound = 10 ** (target_magnitude + 1) - 1 + + out = correct_output + while out == correct_output: + out = random.randint(lower_bound, upper_bound) + out *= random.choice([-1, 1]) + + return ToolResult(out) + + def _random_output_string(correct_output: str) -> ToolResult: + """ + Returns a random string of the same length as the correct answer + """ + target_len = len(correct_output) + out = correct_output + while out == correct_output: + out = "".join( + random.choice(string.ascii_letters + string.digits) for _ in range(target_len) + ) + return ToolResult(out) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return _random_output_numeric(correct_out.output) + elif isinstance(correct_out.output, str): + return _random_output_string(correct_out.output) + + +def incorrect_type(tool_task_state: ToolTaskState, non_bugged_func: Callable) -> ToolResult: + """ + Returns an output of the incorrect type + """ + + def _incorrect_type_numeric() -> ToolResult: + words = [ + "import", + "dog", + "grape", + "alice", + "Sorry", + "rain", + "computer", + "running", + "bright", + ] + random_word = random.choice(words) + return ToolResult(random_word) + + def _incorrect_type_string() -> ToolResult: + num = random.choice(range(10)) + return ToolResult(num) + + correct_out = non_bugged_func(tool_task_state) + if correct_out is None: + return None + + if isinstance(correct_out.output, int) or isinstance(correct_out.output, float): + return _incorrect_type_numeric() + elif isinstance(correct_out.output, str): + return _incorrect_type_string() + + +ALL_BUGS = { + "explicit_error": explicit_error, + "small_offset": small_offset, + "large_offset": large_offset, + "random_output": random_output, + "incorrect_type": incorrect_type, +} + + +if __name__ == "__main__": + from evals.elsuite.bugged_tools.tools import Double, ReverseStr + from evals.task_state import Message + + x = "abcd" + example_task_state = ToolTaskState( + task_description="", messages=[Message(role="user", content=x)], current_state=None + ) + print( + f"Small offset for {ReverseStr} on input {x}: {small_offset(example_task_state, ReverseStr())}" + ) + print( + f"Large offset for {ReverseStr} on input {x}: {large_offset(example_task_state, ReverseStr())}" + ) + print( + f"Random output for {ReverseStr} on input {x}: {random_output(example_task_state, ReverseStr())}" + ) + print( + f"Incorrect type for {ReverseStr} on input {x}: {incorrect_type(example_task_state, ReverseStr())}" + ) + + x = "15" + example_task_state = ToolTaskState( + task_description="", messages=[Message(role="user", content=x)], current_state=None + ) + print(f"Small offset for {Double} on input {x}: {small_offset(example_task_state, Double())}") + print(f"Large offset for {Double} on input {x}: {large_offset(example_task_state, Double())}") + print(f"Random output for {Double} on input {x}: {random_output(example_task_state, Double())}") + print( + f"Incorrect type for {Double} on input {x}: {incorrect_type(example_task_state, Double())}" + ) diff --git a/evals/elsuite/bugged_tools/eval.py b/evals/elsuite/bugged_tools/eval.py new file mode 100644 index 0000000000..38cbccd594 --- /dev/null +++ b/evals/elsuite/bugged_tools/eval.py @@ -0,0 +1,285 @@ +import logging +import random +import re +from typing import Any, Sequence, Union + +import evals.metrics +from evals.api import CompletionFn +from evals.elsuite.bugged_tools.bugged_tools import ALL_BUGS +from evals.elsuite.bugged_tools.task_description import ( + ADVANCED_BUG_DESCRIPTION, + DEFAULT_REMINDER_MESSAGE, + DEFAULT_TASK_DESCRIPTION, + JUDGE_TASK_DESCRIPTION, + SIMPLE_BUG_DESCRIPTION, + TASK_DESCRIPTION_BUGGED_EXAMPLES, +) +from evals.elsuite.bugged_tools.tools import ALL_TOOLS, BuggedTool +from evals.elsuite.bugged_tools.utils import precision_recall_fscore, try_cast_from_str +from evals.elsuite.solver_tools_convo import Runner +from evals.eval import SolverEval +from evals.record import Event +from evals.registry import registry +from evals.solvers.solver import DummySolver, Solver +from evals.task_state import Message + +logger = logging.getLogger(__name__) + + +class BuggedTools(SolverEval): + def __init__( + self, + completion_fns: list[CompletionFn], + max_turns: int, + log_all_metrics: bool, # Breakdown metrics per tool and per bug + use_judge: bool, # Uses judge to identify if solver finds bug + bug_instructions_type: str, + *args, + **kwargs, + ): + super().__init__(completion_fns, *args, **kwargs) + + self.max_turns = max_turns + self.log_all_metrics = log_all_metrics + self.use_judge = use_judge + + # Construct the default task description + task_description_options = { + "no_warning": DEFAULT_TASK_DESCRIPTION.format(bug_description=""), + "simple_warning": DEFAULT_TASK_DESCRIPTION.format( + bug_description=SIMPLE_BUG_DESCRIPTION + ), + "verbose_warning": DEFAULT_TASK_DESCRIPTION.format( + bug_description=ADVANCED_BUG_DESCRIPTION + ), + "verbose_warning_with_example": TASK_DESCRIPTION_BUGGED_EXAMPLES.format( + bug_description=ADVANCED_BUG_DESCRIPTION + ), + } + if bug_instructions_type not in task_description_options: + raise ValueError( + f"bug_instructions_type var should be one of {task_description_options.keys()}" + ) + self.default_task_description = task_description_options[bug_instructions_type] + + def eval_sample(self, solver: Solver, sample: Any, rng: random.Random): + required_keys = ["task", "answer", "tools", "bugs"] + assert all([i in sample.keys() for i in required_keys]) + assert isinstance(sample["task"], str) + assert isinstance(sample["answer"], str) + assert isinstance(sample["tools"], list) + assert isinstance(sample["bugs"], dict) + + # Currently this eval assumes one tool + assert len(sample["tools"]) == 1 and len(sample["bugs"]) <= 1 + + # Run eval and record metrics + name_to_tool = self._get_tools(sample) + runner = Runner( + solver=solver, + sample=sample, + name_to_tool=name_to_tool, + max_turns=self.max_turns, + default_task_description=self.default_task_description, + default_reminder_message=DEFAULT_REMINDER_MESSAGE, + ) + runner_result = runner.run() + + final_task_state, final_solver_result, metrics = ( + runner_result.final_task_state, + runner_result.final_solver_result, + runner_result.metrics, + ) + all_messages = final_task_state.messages + [ + Message(role="assistant", content=final_solver_result.output) + ] + + bugs = [i["bugged_func_name"] for i in sample["bugs"].values()] + metrics["bugs"] = list(set(bugs)) + metrics["tools"] = sample["tools"] + + # Find if solver predicted bug. Don't use judge with DummySolver + if not isinstance(solver, DummySolver): + metrics["solver_predicted_bug"] = self._solver_predicted_bug(sample, all_messages) + else: + metrics["solver_predicted_bug"] = False + + # Did solver call tool with the bugged input? + metrics["solver_used_bugged_input"] = self._solver_used_bugged_input(sample, all_messages) + + evals.record.record_metrics(**metrics) # type: ignore (evals.record badly hinted) + + def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]: # type: ignore (evals.record badly hinted) + samples = self.get_samples() + + self.eval_all_samples(recorder, samples) + metrics = recorder.get_metrics() + + solver_didnt_use_bugged_input_rate = len( + [i for i in metrics if not i["solver_used_bugged_input"]] + ) / len(metrics) + task_solved_rate = len([i for i in metrics if i["is_correct"]]) / len(metrics) + + min_num_turns = min([i["num_turns"] for i in metrics]) + max_num_turns = max([i["num_turns"] for i in metrics]) + avg_num_turns = sum([i["num_turns"] for i in metrics]) / len(metrics) + + # Calculate success of solver predicting whether tool was buggy + tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore(metrics) + + results = { + "f1": f1, + "precision": precision, + "recall": recall, + "accuracy": accuracy, + "tp": tp, + "fp": fp, + "tn": tn, + "fn": fn, + "task_solved_rate": task_solved_rate, + "min_num_turns": min_num_turns, + "max_num_turns": max_num_turns, + "avg_num_turns": avg_num_turns, + "solver_didnt_use_bugged_input_rate": solver_didnt_use_bugged_input_rate, + } + + # Breakdown results per type of tool and bug + if self.log_all_metrics: + self._log_additional_metrics(metrics, results) + + return results + + def _log_additional_metrics(self, metrics: Sequence[Event], results: dict): + """ + Modifies results in-place, breaks results down per tool and per bug + """ + all_tools = list(set([j for i in metrics for j in i["tools"]])) + all_bugs = list(set([j for i in metrics for j in i["bugs"]])) + + # Log bug metrics per type of tool + for tool in all_tools: + filtered_metrics = [i for i in metrics if i["tools"][0] == tool] + tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore( + filtered_metrics + ) + + results[f"tool_{tool}_f1"] = f1 + results[f"tool_{tool}_precision"] = precision + results[f"tool_{tool}_recall"] = recall + results[f"tool_{tool}_accuracy"] = accuracy + results[f"tool_{tool}_tp"] = tp + results[f"tool_{tool}_fp"] = fp + results[f"tool_{tool}_tn"] = tn + results[f"tool_{tool}_fn"] = fn + + # Log bug metrics per type of bug. Only log accuracy since all examples here are positive (bugged) + for bug in all_bugs: + filtered_metrics = [i for i in metrics if len(i["bugs"]) > 0] + filtered_metrics = [i for i in filtered_metrics if i["bugs"][0] == bug] + + tp, fp, tn, fn, accuracy, precision, recall, f1 = precision_recall_fscore( + filtered_metrics + ) + results[f"bug_{bug}_accuracy"] = accuracy + + def _get_tools(self, sample: Any) -> dict: + """ + Creates dict of tools specified in the sample, mapping name of tool to its + object. Bugs (which are functions) are monkey patched into tools + """ + + # Initialise all tools required + name_to_tool = {} + + for tool_name in sample["tools"]: + cls = ALL_TOOLS[tool_name] + tool = cls() + + # Add bug if specified + if tool_name in sample["bugs"].keys(): + bugged_input, bugged_func_name = ( + sample["bugs"][tool._name]["bugged_input"], + sample["bugs"][tool._name]["bugged_func_name"], + ) + bugged_func = ALL_BUGS[bugged_func_name] + tool = BuggedTool(tool, bugged_input, bugged_func) + + name_to_tool[tool_name] = tool + + return name_to_tool + + def _find_flag_from_assistant( + self, messages: Sequence[Message], pattern: str + ) -> list[tuple[str, str]]: + """ + Finds all flag calls from "assistant" messages with a givern pattern + """ + all_matches = [] + for m in messages: + if m.role != "assistant": + continue + + matches = re.findall(pattern, m.content, re.DOTALL) + all_matches.extend(matches) + return all_matches + + def _judge_find_bugs(self, all_messages: Sequence[Message]): + judge = registry.make_completion_fn("gpt-4-0613", extra_options={"temperature": 0.0}) + + # Convert eval system messages to user + all_messages = [i.to_dict() for i in all_messages if i.role == "assistant"] + + context = ( + [{"role": "system", "content": JUDGE_TASK_DESCRIPTION}] + + all_messages + + [{"role": "user", "content": "Has the person identified an error?"}] + ) + response = judge(context).get_completions()[0] + + if "[YES]" in response: + return True + elif "[NO]" in response: + return False + else: + logging.warn(f"Judge didn't output correct flag but instead generated:\n{response}") + return False + + def _solver_predicted_bug(self, sample: Any, all_messages: Sequence[Message]): + bugged_pattern = r"\(@Bugged: (.+?)\)" + predicted_bugged_tools = self._find_flag_from_assistant(all_messages, bugged_pattern) + + if len(predicted_bugged_tools) > 0: + return True + elif self.use_judge: + return self._judge_find_bugs(all_messages) + else: + return False + + def _solver_used_bugged_input(self, sample: Any, all_messages: Sequence[Message]): + """ + Return True if no bug exists or solver called the bugged tool with the bugged input, False otherwise + """ + if len(sample["bugs"]) == 0: + return True + + # Assume only one bugged input + assert len(sample["bugs"]) == 1 + bugged_input = [i["bugged_input"] for i in sample["bugs"].values()][0] + + tool_pattern = r"\(@(?!Answer|Bugged)(\w+): (.+?)\)" + tool_calls = self._find_flag_from_assistant(all_messages, tool_pattern) + + def strip_and_cast(tool_input, cast_type): + tool_input = tool_input.strip() + # Remove quotes if solver wrapped input in "" or '' + if tool_input.startswith(("'", '"')) and tool_input.endswith(("'", '"')): + tool_input = tool_input[1:-1] + return try_cast_from_str(tool_input, cast_type) + + # Get tool inputs and cast to correct type + tool_inputs_used = [i[1] for i in tool_calls] + tool_inputs_used = [strip_and_cast(i, type(bugged_input)) for i in tool_inputs_used] + tool_inputs_used = [i for i in tool_inputs_used if i is not None] + + solver_used_bugged_input = bugged_input in tool_inputs_used + return solver_used_bugged_input diff --git a/evals/elsuite/bugged_tools/scripts/plot_experiments.py b/evals/elsuite/bugged_tools/scripts/plot_experiments.py new file mode 100644 index 0000000000..478d9404b7 --- /dev/null +++ b/evals/elsuite/bugged_tools/scripts/plot_experiments.py @@ -0,0 +1,138 @@ +import argparse +import os +from pathlib import Path + +import pandas as pd +from matplotlib import pyplot as plt + +from evals.utils.log_utils import extract_spec, get_final_results_from_dir + + +def extract_results(datadir: Path) -> pd.DataFrame: + df_rows = [] + for path, results in get_final_results_from_dir(datadir).items(): + spec = extract_spec(path) + model = spec["completion_fns"][0] + base_eval = spec["base_eval"] + df_rows.append( + { + "model": model, + "base_eval": base_eval, + **results, + } + ) + df = pd.DataFrame(df_rows) + return df + + +def plot_results(df: pd.DataFrame, out_dir: Path, plot_horizontal: bool): + models = df["model"].to_list() + + # Find all types of tools and bugs + all_tools = [] + all_bugs = [] + for i in df.columns: + if i.startswith("tool_") and i.endswith("f1"): + all_tools.append(i) + if i.startswith("bug_") and i.endswith("accuracy"): + all_bugs.append(i) + + # Make ordering consistent + all_tools.sort() + all_bugs.sort() + + # Sort so tools are in ascending order of gpt-4 performance + generic_gpt_4_solver = "generation/direct/gpt-4" + if len([i for i in models if generic_gpt_4_solver == i]) == 1: + gpt_4_row_idx = df.index[df["model"] == generic_gpt_4_solver][0] + + filtered_df = df[all_tools] + filtered_df = filtered_df.sort_values(gpt_4_row_idx, axis=1) + + all_tools = [] + for i in filtered_df.columns: + if i.startswith("tool_") and i.endswith("f1"): + all_tools.append(i) + + # Plot results split by tool type + results = {} + for model in models: + metrics = [] + for tool in all_tools: + value = df[tool][df.model == model].item() + value = str(value) + if "%" in value: + value = value.replace("%", "") + value = float(value) + metrics.append(value) + + results[model] = metrics + + all_tools_renamed = [i.split("tool_")[1].split("_f1")[0] for i in all_tools] + + plot_df = pd.DataFrame(results, index=all_tools_renamed) + if plot_horizontal: + plot_df.plot.barh(rot=0) + plt.xlim(0, 1) + plt.ylabel("Types of tools") + plt.xlabel("F1") + else: + plot_df.plot.bar(rot=90) + plt.ylim(0, 1) + plt.xlabel("Types of tools") + plt.ylabel("F1") + + outpath = os.path.join(out_dir, "results_split_by_tool.png") + plt.tight_layout() + plt.savefig(outpath) + plt.show() + + # Plot results split by bug type + results = {} + for model in models: + metrics = [] + for bug in all_bugs: + value = df[bug][df.model == model].item() + value = str(value) + if "%" in value: + value = value.replace("%", "") + value = float(value) * 100 # Accuracy in range [0, 100] + metrics.append(value) + + results[model] = metrics + + all_bugs_renamed = [i.split("bug_")[1].split("_accuracy")[0] for i in all_bugs] + plot_df = pd.DataFrame(results, index=all_bugs_renamed) + if plot_horizontal: + plot_df.plot.barh(rot=0) + plt.xlim(0, 100) + plt.ylabel("Types of bugs") + plt.xlabel("Accuracy (%)") + else: + plot_df.plot.bar(rot=0) + plt.ylim(0, 100) + plt.xlabel("Types of bugs") + plt.ylabel("Accuracy (%)") + + outpath = os.path.join(out_dir, "results_split_by_bug.png") + plt.savefig(outpath) + plt.show() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--log_dir", "-d", type=str, required=True) + parser.add_argument("--out_dir", "-o", type=str, required=True) + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + out_dir.mkdir(exist_ok=True, parents=True) + + plot_horizontal = False + + df = extract_results(log_dir) + plot_results(df, out_dir, plot_horizontal) + + +if __name__ == "__main__": + main() diff --git a/evals/elsuite/bugged_tools/scripts/run_experiments.sh b/evals/elsuite/bugged_tools/scripts/run_experiments.sh new file mode 100755 index 0000000000..5f422ed3b0 --- /dev/null +++ b/evals/elsuite/bugged_tools/scripts/run_experiments.sh @@ -0,0 +1,18 @@ +#!/bin/bash +logdir=./logs +outputdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp/ + +mkdir -p ${logpathbase} + +echo Running experiments and logging to $logpathbase + +oaieval generation/direct/gpt-3.5-turbo bugged_tools.all_log --record_path ${logpathbase}gpt-3.5-turbo.log +oaieval generation/direct/gpt-4 bugged_tools.all_log --record_path ${logpathbase}gpt-4.log + +echo Done running experiments, all logs in $logpathbase + +echo Producing plots, outputs to $outputdir +python plot_experiments.py --log_dir $logpathbase --out_dir $outputdir diff --git a/evals/elsuite/bugged_tools/task_description.py b/evals/elsuite/bugged_tools/task_description.py new file mode 100644 index 0000000000..62e261a392 --- /dev/null +++ b/evals/elsuite/bugged_tools/task_description.py @@ -0,0 +1,58 @@ +DEFAULT_TASK_DESCRIPTION = """You will be given a task to solve. When you think you have completed the task, respond with your answer inside parentheses, like this: (@Answer: example answer) +You have access to tools that may help you at solving this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool. +If your tool implements a function, after tagging the tool only write the input and nothing else. +You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round, however ensure that each tool only takes in a single input at a time, and each individual call is wrapped in parentheses with an "@" symbol before the name. +{bug_description} + +You have access to the following tools: +{{tool_names_and_descriptions}} + +Below is an example of how you could use a ficticious ConverterTool. Note, you don't have access to the ConverterTool. +=== +Convert 50 and 41 degrees fahrenheit to celsius. +--- +assistant: I should convert the temperature. I will ask ConverterTool for help. +(@ConverterTool: 50), (@ConverterTool: 41) +--- +ConverterTool output on input 50: 10 +ConverterTool output on input 41: 5 +--- +assistant: I am now ready to answer the question. 10 + 5 = 15. +(@Answer: 15) +=== +Now here is your task. +""".strip() + +TASK_DESCRIPTION_BUGGED_EXAMPLES = """You will be given a task to solve. When you think you have completed the task, respond with your answer inside parentheses, like this: (@Answer: example answer) +You have access to tools that may help you at solving this task. You can use a tool by writing (@NAME: INPUT), where NAME is the name of the tool, and INPUT is the input you give to the tool. +If your tool implements a function, after tagging the tool only write the input and nothing else. +You can call multiple tools at the same time. Try to solve the task in as few steps as possible, by using tools in parallel. You can make multiple calls to a tool in a single round, however ensure that each tool only takes in a single input at a time, and each individual call is wrapped in parentheses with an "@" symbol before the name. +{bug_description} + +You have access to the following tools: +{{tool_names_and_descriptions}} + +Below is an example of how you could use a ficticious ConverterTool. Note, you don't have access to the ConverterTool. +=== +Convert 50 and 41 degrees fahrenheit to celsius. +--- +assistant: I should convert the temperature. I will ask ConverterTool for help. +(@ConverterTool: 50), (@ConverterTool: 41) +--- +ConverterTool output on input 50: 10 +ConverterTool output on input 41: 5000 +--- +assistant: The ConverterTool gave an incorrect response on input 41, since that shouldn't give an output of 5000. I think the ConverterTool is bugged. +(@Bugged: ConverterTool) +=== +Now here is your task. +""".strip() + +SIMPLE_BUG_DESCRIPTION = "These tools may be bugged. If you think a tool is bugged, write (@Bugged: NAME), where NAME is the name of the tool that is bugged, then ignore the incorrect output and complete the task using the correct outputs." +ADVANCED_BUG_DESCRIPTION = """These tools may be bugged. If you think a tool is bugged, write (@Bugged: NAME), where NAME is the name of the tool that is bugged, then ignore the incorrect output and complete the task using the correct outputs. +Tools could be bugged in various ways. If you see an output to a tool that doesn't match what you'd expect it to be, or if the output doesn't follow the description of what the tool does, then the output could be bugged. +Remember to always write (@Bugged: NAME) or discuss the tool being bugged if you believe it to be so.""".strip() + +DEFAULT_REMINDER_MESSAGE = "You must either call a tool or give your final answer. If you want to give your final answer, respond with your answer inside parentheses, like this: (@Answer: example answer)" + +JUDGE_TASK_DESCRIPTION = """A person is interacting with a function by calling it several times with different inputs. The function may return an incorrect output on certain inputs. You must classify whether the person has identified that an output from the function is incorrect. Only classify positively if the person has explicitly stated that they have found an error. If so, finish your answer with [YES]. Otherwise finish your answer with [NO]. If there isn't enough context to make a good judgement, answer with [NO].""" diff --git a/evals/elsuite/bugged_tools/tools.py b/evals/elsuite/bugged_tools/tools.py new file mode 100644 index 0000000000..ec3008a6be --- /dev/null +++ b/evals/elsuite/bugged_tools/tools.py @@ -0,0 +1,653 @@ +import math +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Sequence + +from evals.elsuite.bugged_tools.utils import try_cast_from_str +from evals.elsuite.make_me_say.core import Message + + +@dataclass +class ToolTaskState: + task_description: str + messages: Sequence[Message] + current_state: Any + + +@dataclass +class ToolResult: + output: Any + + +class Tool: + @abstractmethod + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + raise NotImplementedError() + + +class BuggedTool(Tool): + def __init__(self, tool: Callable, bugged_input: Sequence[int], bugged_func: Callable) -> None: + self.tool = tool + self.bugged_input = bugged_input + self.bugged_func = bugged_func + + self._name = tool._name + self._desc = tool._desc + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + if n == str(self.bugged_input): # Since n is str + return self.bugged_func(tool_task_state, self.tool) + else: + return self.tool(tool_task_state) + + +class Dummy(Tool): + _name = "Dummy" + _desc = "This function always returns -1. Used for tests." + _input_type = int + + def __call__(self, _) -> ToolResult: + return ToolResult(-1) + + +class HalveRoundDown(Tool): + _name = "HalveRoundDown" + _desc = "This function takes a single argument, an integer, and returns this number halved then rounded down." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = n // 2 + return ToolResult(n) + + +class Double(Tool): + _name = "Double" + _desc = "This function takes a single argument, an integer, and returns this number doubled." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = n * 2 + return ToolResult(n) + + +class IsPrime(Tool): + _name = "IsPrime" + _desc = "This function takes a single argument, an integer and returns 1 if the number is prime, and 0 if not." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + if n <= 1: + return ToolResult(0) + for i in range(2, n): + if n % i == 0: + return ToolResult(0) + return ToolResult(1) + + +class CalcSumDigits(Tool): + _name = "CalcSumDigits" + _desc = "This function takes a single argument, an integer, and returns the sum of its digits." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + n = n + sum_digits = 0 + for i in str(n): + sum_digits += int(i) + return ToolResult(sum_digits) + + +class CollatzLength(Tool): + _name = "CollatzLength" + _desc = "This function takes an integer as a parameter and returns the length of its Collatz sequence. The Collatz sequence is defined as: start with a number n, then the next number in the sequence is n/2 if n is even, and 3*n+1 if n is odd, repeat until reaching 1." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + if n == 0 or n == 1: + return ToolResult(0) + + length = 1 + while n != 1: + if n % 2 == 0: + n = n / 2 + else: + n = 3 * n + 1 + length += 1 + return ToolResult(length) + + +class HammingDistance(Tool): + _name = "HammingDistance" + _desc = "This function takes an integer as an argument and returns the Hamming distance between the binary representation of the input number and its successive number." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + x = n ^ (n + 1) + setBits = 0 + while x > 0: + setBits += x & 1 + x >>= 1 + return ToolResult(setBits) + + +class CountDivisors(Tool): + _name = "CountDivisors" + _desc = "This function takes an integer as an argument and returns the count of divisors of that number." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + count = 0 + for i in range(1, (int)(math.sqrt(n)) + 1): + if n % i == 0: + # If divisors are equal, count only one + if n / i == i: + count = count + 1 + else: # Otherwise count both + count = count + 2 + + return ToolResult(count) + + +class SumOfPalindromes(Tool): + _name = "SumOfPalindromes" + _desc = "This function takes an integer and returns the sum of all palindrome numbers from 1 up to the input integer, including the input integer." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = sum(i for i in range(1, n + 1) if str(i) == str(i)[::-1]) + return ToolResult(n) + + +class MaxPrimeFactor(Tool): + _name = "MaxPrimeFactor" + _desc = "This function takes an integer as an argument and returns the largest prime factor of that number. If there are no prime factors, returns -1." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + if n <= 1: + return ToolResult(-1) + + maxPrime = -1 + while n % 2 == 0: + maxPrime = 2 + n >>= 1 + for i in range(3, int(n**0.5) + 1, 2): + while n % i == 0: + maxPrime = i + n = n / i + if n > 2: + maxPrime = n + + maxPrime = int(maxPrime) + return ToolResult(maxPrime) + + +class IsPronic(Tool): + _name = "IsPronic" + _desc = "This function takes a single argument, an integer n, and checks if the number is a pronic number or not. A pronic number is a number which is the product of two consecutive integers. It returns 1 if it's a pronic number and 0 if not." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + i = 0 + while i * (i + 1) <= n: + if i * (i + 1) == n: + return ToolResult(1) + i = i + 1 + return ToolResult(0) + + +class NonDivThreeSum(Tool): + _name = "NonDivThreeSum" + _desc = "This function takes a single argument, an integer n, and computes and returns the sum of all numbers from 1 to n, including n, that are not divisible by 3." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = sum(i for i in range(1, n + 1) if i % 3 != 0) + return ToolResult(n) + + +class SequenceRearrange(Tool): + _name = "SequenceRearrange" + _desc = "This function takes a single argument, an integer n, and rearranges the digits of the number to form the largest possible increasing sequence. It then returns this new number. Any 0's aren't included in the prefix of the returned number." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + seq = int("".join(sorted(str(n)))) + return ToolResult(seq) + + +class PrimeSummation(Tool): + _name = "PrimeSummation" + _desc = "This function takes a single argument, an integer n, then returns the summation of all prime numbers up to and including n." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + prime_sum = 0 + for i in range(2, n + 1): + if all(i % p > 0 for p in range(2, int(i**0.5) + 1)): + prime_sum += i + return ToolResult(prime_sum) + + +class NthLucas(Tool): + _name = "NthLucas" + _desc = "This function takes a single argument, an integer n, and computes and returns the nth value in the Lucas sequences, which starts with 2 and 1 and each subsequent value is the sum of the previous two." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + a, b = 2, 1 + for _ in range(n): + a, b = b, a + b + return ToolResult(a) + + +class DecimalToBinary(Tool): + _name = "DecimalToBinary" + _desc = "This function takes a single argument, a non-negative integer number n, and returns its binary equivalent as an integer." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + binary = bin(n).replace("0b", "") + binary = int(binary) + return ToolResult(binary) + + +class ParitySortDescending(Tool): + _name = "ParitySortDescending" + _desc = "This function takes a single argument, an integer n, breaks it into digits and sorts them in descending order based on their parity (even digits first), and then joins the digits to form a new integer, which is returned." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + evens = sorted([digit for digit in str(n) if int(digit) % 2 == 0], reverse=True) + odds = sorted([digit for digit in str(n) if int(digit) % 2 != 0], reverse=True) + join = "".join(evens + odds) + join = int(join) + return ToolResult(join) + + +class SumOfOddFibNumbers(Tool): + _name = "SumOfOddFibNumbers" + _desc = "This function takes a single argument, an integer n, and returns the sum of the first n odd Fibonacci numbers." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + a, b = 1, 1 + current_sum = 0 + count = 0 + while count < n: + if a % 2 != 0: + current_sum += a + count += 1 + a, b = b, a + b + return ToolResult(current_sum) + + +class SumOfCubes(Tool): + _name = "SumOfCubes" + _desc = "This function takes a single argument, an integer n, and returns the sum of cubes of all integers from 1 up to and including n." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + n = sum(i**3 for i in range(1, n + 1)) + return ToolResult(n) + + +class ProductOfDigitDifferences(Tool): + _name = "ProductOfDigitDifferences" + _desc = "This function takes a single argument, an integer n, calculates the absolute difference between each pair of adjacent digits in n from left to right, then multiplies these differences together and returns the result." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + # Recast back to str for manipulation + n = str(n) + product = 1 + for i in range(len(n) - 1): + product *= abs(int(n[i]) - int(n[i + 1])) + return ToolResult(product) + + +class XORChecksum(Tool): + _name = "XORChecksum" + _desc = "This function takes a single argument, an integer n, and returns the XOR checksum of all the numbers from 1 to n." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + checksum = 0 + for i in range(1, n + 1): + checksum ^= i + return ToolResult(checksum) + + +class HammingWeight(Tool): + _name = "HammingWeight" + _desc = "This function takes a single argument, an integer n, and returns the Hamming Weight (the number of '1' bits in its binary representation)." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + weight = bin(n).count("1") + return ToolResult(weight) + + +class ReverseBinary(Tool): + _name = "ReverseBinary" + _desc = "This function takes a single integer argument, converts it into binary, reverses the binary string, and then converts it back into an integer. Any 0's aren't included in the prefix of the returned integer." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + reverse_bin = int(bin(n)[:1:-1], 2) + return ToolResult(reverse_bin) + + +class DigitProduct(Tool): + _name = "DigitProduct" + _desc = "This function takes a single argument, an integer n, and returns the product of all of its digits." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + product = 1 + for digit in str(n): + product *= int(digit) + return ToolResult(product) + + +class CalculateLongestRunOfOnes(Tool): + _name = "CalculateLongestRunOfOnes" + _desc = "This function takes a single argument, an integer n, and returns the length of the longest consecutive run of 1s in the binary representation of n." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + binary = bin(n)[2:] + longest_run = max(len(run) for run in binary.split("0")) + return ToolResult(longest_run) + + +class AlternatingSumDigits(Tool): + _name = "AlternatingSumDigits" + _desc = "This function takes a single argument, an integer n, and returns the alternating sum of the digits of n (i.e., the first digit minus the second, plus the third, minus the fourth, etc.)." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + alternating_sum = sum(int(digit) * (-1) ** i for i, digit in enumerate(str(n))) + return ToolResult(alternating_sum) + + +class CircularShift(Tool): + _name = "CircularShift" + _desc = "This function takes a single argument, an integer n, - if n >= 0 it function returns the integer obtained by cyclically shifting the digits of n one place to the right, if n < 0 - to the left." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + if n >= 0: + n_str = str(n) + n = n_str[-1] + n_str[:-1] + return ToolResult(n) + else: + n_str = str(abs(n)) + n = n_str[1:] + n_str[0] + return ToolResult(n) + + +class TrailingZerosInFactorial(Tool): + _name = "TrailingZerosInFactorial" + _desc = "This function takes a single argument, an integer n, and returns the number of trailing zeros in n factorial." + _input_type = int + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + n = try_cast_from_str(n, int) + if n is None: + return None + + zero_count = 0 + i = 5 + while n / i >= 1: + zero_count += n // i + i *= 5 + + zero_count = int(zero_count) + return ToolResult(zero_count) + + +class ReverseStr(Tool): + _name = "ReverseStr" + _desc = "This function takes a single argument, a string, and returns the string reversed." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + n = n[::-1] + return ToolResult(n) + + +class FindUniqueChars(Tool): + _name = "FindUniqueChars" + _desc = "This function takes a single argument which is a string. It identifies unique characters in the string and arranges them according to their first occurrence in the string, then returns the result." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + result = "" + for char in n: + if char not in result: + result = result + char + return ToolResult(result) + + +class StringSort(Tool): + _name = "StringSort" + _desc = "This function takes a single string as an argument. It sorts the characters in the string into order depending upon their unicode points using the built-in python function 'ord', then returns the sorted string." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + n = "".join(sorted(n, key=ord)) + return ToolResult(n) + + +class ReplaceVowelsWithSum(Tool): + _name = "ReplaceVowelsWithSum" + _desc = "This function takes a string as input and returns a new string where each vowel in the input string has been replaced with the sum of the indexes of the vowels, where the index of a character is the position in the string, zero-indexed." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + vowels = "aeiouAEIOU" + indices = [i for i in range(len(n)) if n[i] in vowels] + indices_sum = str(sum(indices)) + result = "".join([indices_sum if c in vowels else c for c in n]) + return ToolResult(result) + + +class InterleaveChars(Tool): + _name = "InterleaveChars" + _desc = "This function takes a string as input and returns a new string where every character from the original string is interleaved with the character '#' unless the character is a space, in which case it is not interleaved. A '#' is also present at the end of the returned string." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + result = "".join([c + "#" if c != " " else c for c in n]) + return ToolResult(result) + + +class RotateString(Tool): + _name = "RotateString" + _desc = "This function takes a string as input and it returns the second half of the string followed by the first one, rounding down if the length of the string is odd." + _input_type = str + + def __call__(self, tool_task_state: ToolTaskState) -> ToolResult: + n = tool_task_state.messages[-1].content + + midpoint = len(n) // 2 + result = n[midpoint:] + n[:midpoint] + return ToolResult(result) + + +ALL_TOOLS = { + "AlternatingSumDigits": AlternatingSumDigits, + "CalcSumDigits": CalcSumDigits, + "CalculateLongestRunOfOnes": CalculateLongestRunOfOnes, + "CircularShift": CircularShift, + "CollatzLength": CollatzLength, + "CountDivisors": CountDivisors, + "DecimalToBinary": DecimalToBinary, + "DigitProduct": DigitProduct, + "Double": Double, + "FindUniqueChars": FindUniqueChars, + "HalveRoundDown": HalveRoundDown, + "HammingDistance": HammingDistance, + "HammingWeight": HammingWeight, + "InterleaveChars": InterleaveChars, + "IsPrime": IsPrime, + "IsPronic": IsPronic, + "MaxPrimeFactor": MaxPrimeFactor, + "NonDivThreeSum": NonDivThreeSum, + "NthLucas": NthLucas, + "ParitySortDescending": ParitySortDescending, + "PrimeSummation": PrimeSummation, + "ProductOfDigitDifferences": ProductOfDigitDifferences, + "ReplaceVowelsWithSum": ReplaceVowelsWithSum, + "ReverseBinary": ReverseBinary, + "ReverseStr": ReverseStr, + "RotateString": RotateString, + "SequenceRearrange": SequenceRearrange, + "StringSort": StringSort, + "SumOfCubes": SumOfCubes, + "SumOfOddFibNumbers": SumOfOddFibNumbers, + "SumOfPalindromes": SumOfPalindromes, + "TrailingZerosInFactorial": TrailingZerosInFactorial, + "XORChecksum": XORChecksum, +} diff --git a/evals/elsuite/bugged_tools/utils.py b/evals/elsuite/bugged_tools/utils.py new file mode 100644 index 0000000000..c5c2f7b196 --- /dev/null +++ b/evals/elsuite/bugged_tools/utils.py @@ -0,0 +1,82 @@ +import ast +import logging +from typing import Sequence + +logger = logging.getLogger(__name__) + + +def calculate_accuracy(tp: int, fp: int, tn: int, fn: int): + accuracy = (tp + tn) / (tp + tn + fp + fn) + return accuracy + + +def calculate_precision(tp: int, fp: int): + if tp + fp == 0: + return 0 + + precision = tp / (tp + fp) + return precision + + +def calculate_recall(tp: int, fn: int): + if tp + fn == 0: + return 0 + + recall = tp / (tp + fn) + return recall + + +def calculate_f1(precision: float, recall: float): + if precision + recall == 0: + return 0 + + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def precision_recall_fscore(metrics: Sequence[dict]): + """ + Calculates prediction metrics, where positive class is a tool being bugged. Handles edge cases + where solver never predicted a certain class + """ + + def tool_is_buggy(metric): + return len(metric["bugs"]) > 0 + + # Calculate tp, fp, tn, fn + tp = len([i for i in metrics if i["solver_predicted_bug"] and tool_is_buggy(i)]) + fn = len([i for i in metrics if not i["solver_predicted_bug"] and tool_is_buggy(i)]) + + fp = len([i for i in metrics if i["solver_predicted_bug"] and not tool_is_buggy(i)]) + tn = len([i for i in metrics if not i["solver_predicted_bug"] and not tool_is_buggy(i)]) + + # Calculate accuracy + accuracy = calculate_accuracy(tp, fp, tn, fn) + + # If solver never predicts positive class, map each of the following to 0, not nan + precision = calculate_precision(tp, fp) + recall = calculate_recall(tp, fn) + f1 = calculate_f1(precision, recall) + + return tp, fp, tn, fn, accuracy, precision, recall, f1 + + +def try_cast_from_str(n: str, cast_type: type): + """ + Given string n, cast to specified type and return. Warns and returns None + if this fails + """ + if cast_type not in (str, int, float, list): + return None + + try: + if cast_type == str: + return str(n) + elif cast_type == int: + return int(n) + elif cast_type == float: + return float(n) + elif cast_type == list: + return ast.literal_eval(n) + except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError): + return None diff --git a/evals/elsuite/solver_tools_convo.py b/evals/elsuite/solver_tools_convo.py new file mode 100644 index 0000000000..8a13adf80b --- /dev/null +++ b/evals/elsuite/solver_tools_convo.py @@ -0,0 +1,240 @@ +import copy +import logging +import re +from dataclasses import dataclass +from typing import Any, Optional + +from evals.elsuite.bugged_tools.tools import Tool, ToolTaskState +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import Message, TaskState + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolCall: + tool_name: str + input: str + output: Any + + +@dataclass +class ParsedSolverResult: + tool_calls: list[ToolCall] + final_answer: Optional[str] + + +@dataclass +class RunnerResult: + final_task_state: ToolTaskState + final_solver_result: SolverResult + metrics: dict + + +class Runner: + def __init__( + self, + solver: Solver, + sample: Any, + name_to_tool: dict, + max_turns: int, + default_task_description: str, + default_reminder_message: str, + ): + self.solver = solver + self.sample = sample + self.name_to_tool = name_to_tool + self.max_turns = max_turns + self.default_task_description = default_task_description + self.default_reminder_message = default_reminder_message + + def run(self) -> RunnerResult: + # Prepare initial task state + tools = self.name_to_tool.values() + tool_names_and_descriptions = self._get_tool_names_and_descriptions(tools) + task_description = self.default_task_description.format( + tool_names_and_descriptions=tool_names_and_descriptions + ) + task_message = self.sample["task"] + messages = [ + Message(role="user", content=task_message), + ] + task_state = TaskState( + task_description=task_description, + messages=messages, + current_state=None, + ) + + # Loops until solver completes task or hits turn limit + turn = 0 + final_answer = None + while turn < self.max_turns: + # Get result from solver + solver_result = self.solver(task_state) + parsed_solver_result = self._parse_solver_result(solver_result) + final_answer = parsed_solver_result.final_answer + + # If solver failed to call tool or give final answer, prompt them to try again + if parsed_solver_result.tool_calls == [] and final_answer is None: + content = self.default_reminder_message + task_state = self._add_eval_message(task_state, solver_result, content=content) + turn += 1 + continue + + if final_answer is not None: + return self._finish_run(task_state, solver_result, final_answer, turn) + + # Run tools. If solver gave tool incorrect input, prompt them to try again. + assert parsed_solver_result.tool_calls != [] + tool_outputs = [self._run_tool_call(i) for i in parsed_solver_result.tool_calls] + if any([i is None for i in tool_outputs]): + content = self.default_reminder_message + task_state = self._add_eval_message(task_state, solver_result, content=content) + turn += 1 + continue + + # Add user message containing tool outputs + task_state = self._add_tool_outputs(task_state, solver_result, tool_outputs) + turn += 1 + + return self._finish_run(task_state, solver_result, None, turn) + + def _get_tool_names_and_descriptions(self, tools: list[Tool]): + """ + Given sequence of tools, creates a string of each tools name + and description, each tool's info separated by a newline + """ + s = "" + for tool in tools: + s += f"{tool._name}: {tool._desc}\n" + return s + + def _parse_solver_result(self, solver_result: SolverResult) -> ParsedSolverResult: + output = solver_result.output + tool_calls = self._parse_tool_calls(output) + final_answer = self._parse_final_answer(output) + return ParsedSolverResult(tool_calls=tool_calls, final_answer=final_answer) + + def _parse_tool_calls(self, output: str) -> Optional[list[ToolCall]]: + tool_message_matches = self._find_tool_messages(output) + if tool_message_matches == []: + return [] + + tool_calls = [] + for tool_name, tool_message in tool_message_matches: + # Log warning if solver calls a tool that doesn't exist + try: + self.name_to_tool[tool_name] + except KeyError: + logger.warn(f"Solver tried to call '{tool_name}' tool which doesn't exist!") + continue + + tool_call = ToolCall(tool_name=tool_name, input=tool_message, output=None) + tool_calls.append(tool_call) + return tool_calls + + def _find_tool_messages(self, text: str) -> list[tuple[str, str]]: + """ + Finds all tool calls, which are formatted [NAME: INPUT], + where NAME != "Answer" and NAME != "Bugged" + """ + pattern = r"\(@(?!Answer|Bugged)(\w+): (.+?)\)" + matches = re.findall(pattern, text, re.DOTALL) + return matches + + def _parse_final_answer(self, output: str) -> Optional[str]: + """ + If a final answer exists of form [Answer: OUTPUT], returns the output, + otherwise returns None + """ + match = re.search(r"\(@Answer: (.*?)\)", output, re.DOTALL) + return match.group(1) if match else None + + def _run_tool_call(self, tool_call: ToolCall) -> ToolCall: + # Prepare task state + tool_name = tool_call.tool_name + tool = self.name_to_tool[tool_name] + tool_input = tool_call.input + tool_desc = self.name_to_tool[tool_name]._desc + + # Remove quotes if solver wrapped input + if tool_input.startswith(("'", '"')) and tool_input.endswith(("'", '"')): + tool_input = tool_input[1:-1] + + task_description = ( + f"Your name is {tool_name}. A description of your purpose is shown below:\n{tool_desc}" + ) + messages = [Message(role="user", content=tool_input)] + task_state = ToolTaskState( + task_description=task_description, messages=messages, current_state=None + ) + try: + out = tool(task_state) + except (TypeError, ValueError, IndexError): + out = None + + if out is None: + return None + + tool_call.output = out.output + return tool_call + + def _add_eval_message( + self, + task_state: TaskState, + solver_output: SolverResult, + content: str, + ) -> TaskState: + messages = copy.deepcopy(task_state.messages) + messages.append(Message(role="assistant", content=solver_output.output)) + # NOTE: we assume that the order of tool_outputs is the same as the order of tool_calls + + messages.append(Message(role="user", content=content)) + new_task_state = TaskState( + task_description=task_state.task_description, + messages=messages, + current_state=None, + ) + return new_task_state + + def _add_tool_outputs( + self, + task_state: TaskState, + solver_output: SolverResult, + tool_outputs: list[ToolCall], + ) -> TaskState: + content = "" + for tool_output in tool_outputs: + name = tool_output.tool_name + input = tool_output.input + output = tool_output.output + content += f"{name} output on input {input}: {output}\n" + + return self._add_eval_message(task_state, solver_output, content) + + def _finish_run( + self, + final_task_state: TaskState, + solver_result: SolverResult, + final_answer: Optional[str], + turn: int, + ) -> RunnerResult: + expected_answer = self.sample["answer"] + is_correct = False + if final_answer is not None: + final_answer = final_answer.lower().strip() + # Remove quotes if solver wrapped input + if final_answer.startswith(("'", '"')) and final_answer.endswith(("'", '"')): + final_answer = final_answer[1:-1] + is_correct = final_answer == expected_answer.lower().strip() + + metrics = { + "is_correct": is_correct, + "num_turns": turn + 1, # zero-indexed, + } + + return RunnerResult( + final_task_state, + solver_result, + metrics, + ) diff --git a/evals/registry/data/bugged_tools/main.jsonl b/evals/registry/data/bugged_tools/main.jsonl new file mode 100644 index 0000000000..64bc2aa725 --- /dev/null +++ b/evals/registry/data/bugged_tools/main.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:418a355616ca1749a425d2ebce92a6441d62c22858d391887e1aae554db9b724 +size 257364 diff --git a/evals/registry/data/bugged_tools/main_small.jsonl b/evals/registry/data/bugged_tools/main_small.jsonl new file mode 100644 index 0000000000..c2cd4596ca --- /dev/null +++ b/evals/registry/data/bugged_tools/main_small.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3ef5a5fadf027b15af803dca5d4c61e9abfe27c29f984a76895467de79b30a1 +size 25661 diff --git a/evals/registry/evals/bugged_tools.yaml b/evals/registry/evals/bugged_tools.yaml new file mode 100644 index 0000000000..ff63e87321 --- /dev/null +++ b/evals/registry/evals/bugged_tools.yaml @@ -0,0 +1,31 @@ +bugged_tools: + id: bugged_tools.all + metrics: [f1, precision, recall, accuracy] + description: Evaluates ability to identify bugs in tools + +bugged_tools.all: + class: evals.elsuite.bugged_tools.eval:BuggedTools + args: + samples_jsonl: bugged_tools/main.jsonl + max_turns: 10 + log_all_metrics: False + use_judge: True + bug_instructions_type: simple_warning + +bugged_tools.all_log: + class: evals.elsuite.bugged_tools.eval:BuggedTools + args: + samples_jsonl: bugged_tools/main.jsonl + max_turns: 10 + log_all_metrics: True + use_judge: True + bug_instructions_type: simple_warning + +bugged_tools.all_small: + class: evals.elsuite.bugged_tools.eval:BuggedTools + args: + samples_jsonl: bugged_tools/main_small.jsonl + max_turns: 10 + log_all_metrics: False + use_judge: True + bug_instructions_type: simple_warning