Skip to content

Commit

Permalink
Add function deduction eval
Browse files Browse the repository at this point in the history
Co-authored-by: johny-b <[email protected]>
  • Loading branch information
james-aung and johny-b committed Mar 19, 2024
1 parent 11c30b2 commit f1f02f0
Show file tree
Hide file tree
Showing 13 changed files with 1,609 additions and 0 deletions.
91 changes: 91 additions & 0 deletions evals/elsuite/function_deduction/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Function Deduction

This eval evaluates how well a model can refine a hypothesis according to new evidence and how well it chooses to gather new information.

In Function Deduction:

- There is a secret mathematical function that maps an integer to another integer.
- The evaluated model interacts with the function by picking inputs to run through the function and observing black-box outputs.
- The model’s goal is to correctly predict outputs for a specified set of inputs, which is only possible by working out the underlying logic of the function.

![fd](https://github.com/openai/policy-research-evals/assets/129281094/6c41be74-8237-4bb3-b0fc-13454c20389c)

## Usage

Run with:

```
oaieval <solver> function_deduction
```

We suggest using `function_deduction/cot/gpt-4-32k` or `function_deduction/cot/gpt-3.5-turbo-16k` as default choices for `<solver>`

See `evals/registry/solvers/function_deduction.yaml` for a full list of recommended Solvers.

For more examples of running this eval, see `scripts/run_experiments.sh`

## Dataset

The dataset are 100 int -> int functions of varying difficulty. Outputs of the functions for inputs 0 through 100 (the allowable range) are pre-cached and in the dataset. During the eval run, the eval simply looks up the output for any input the solver is asking about.

## Evaluation Process

Samples from the dataset are evaluated one-by-one.

A model has up to 20 rounds to guess what the function would return for the three test input integers.

In each round the model can either:

- Ask what the function returns for a given integer (other than the test inputs) *or*
- Make a guess for the values that would be returned by the function for the test inputs.

If the model makes a guess and the guess is correct for all 3 test inputs, the model has solved this sample; otherwise, if it guesses wrong, the model is told whether it guessed any subset of the test inputs correctly.

After 20 rounds without a correct guess, the evaluation for that sample ends and the next sample starts. Rounds where the model answered in an incorrect format are still counted towards the 20 round total.

## Task State

| Key | Value |
| --- | --- |
| `task_description` | I have an important task for you. Instructions below.<br><br>I have a deterministic function `def calc(x: int) -> int:` that accepts an integer in range [0, 100] as an argument and returns an integer.<br><br>I want you to correctly determine the result of this function for each of the following inputs: {inputs}<br><br>At each turn you must select one of the following 2 options:<br>Option (A): Gather useful information. Ask me what would be returned for a given input to inform your hypothesis of what the function is. \<br>This must be an integer in range [0, 100], other than {inputs}.<br>Option (B): State your guess for the values returned by the function for inputs {inputs}. Your guess must include all three numbers, in this order. \<br>If your guess is correct, you win. If not, the task will continue.<br><br>You have {n_rounds} turns to determine the correct answer, else you lose.<br><br>Your goal is to determine the correct answer in as few turns as possible, so choose your actions carefully!<br>Your final score will be:<br>* If you make the correct guess: {n_rounds} - number_of_rounds_played<br>* If you fail to make the correct guess: -{n_rounds}<br>You should try to score as high as possible.<br><br>You should respond either with a single integer in range 0-100 if you decide to take option A or with three integers if you make the final guess (i.e. use option B).<br>Don't say anything more except for either an integer or 3 integers. |
| `current_state` | CurrentState object that tracks various data from the current dialog. |

## Metrics

The below are the key metrics of this eval:

| Metric | Interpretation |
| --- | --- |
| `adjusted_avg_score` | Combination metric of the below 2 metrics. The average number of rounds for solved samples, or 40 for not-solved samples. (lower is better) |
| `solved_ratio` | The percentage of solved samples (higher is better) |
| `avg_success_rounds` | The average number of rounds for solved samples (lower is better) |

## Variants

| Variant | Notes |
| --- | --- |
| Default: `function_deduction.easy` | Default setting as described above. 1 trial per sample |
| `function_deduction.easy.long` | 10 trials per sample |
| `function_deduction.easy.dev5` | Dev set with only 5 samples |
| `function_deduction.hard` | A hard variant where the model is only told ‘this guess is incorrect’ if its wrong, instead of being told which inputs it got right/wrong. |
| `function_deduction.hard.dev5` | Dev set with only 5 samples |

## Token Usage Estimates

Below is a rough estimate of the total number of tokens consumed by the default variant:

| Solver | Tokens |
| --- | --- |
| function_deduction/gpt-4-base | 3 840 000 |
| gpt-4-32k | 880 000 |
| gpt-3.5-turbo-16k | 1 560 000 |
| function_deduction/cot/gpt-4-32k | 12 400 000 |
| function_deduction/cot/gpt-3.5-turbo-16k | 13 230 000 |

## Version History

- v0: Initial version released

## Contribution statement

Eval design, implementation, and results evaluation were primarily conducted by Jan Betley with contributions from Andrei Alexandru. Report by James Aung. Work done under the guidance of (alphabetically by last-name) Steven Adler, and Chan Jun Shern, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation.
133 changes: 133 additions & 0 deletions evals/elsuite/function_deduction/baselines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import logging
import math
from collections import Counter
from pathlib import Path

import numpy as np
from scipy.stats import entropy

from evals.data import get_jsonl
from evals.elsuite.function_deduction.eval import CurrentState, Sample
from evals.registry import Registry
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import TaskState


class AverageBaseline(Solver):
"""
For given test inputs (x, y, z):
* Ask about values of (x-1, x+1, y-1, y+1, z-1, z+1)
* Make three guesses: round/floor/ceil of average values for neighboring numbers
If didn't succeed in 9 rounds (6x ask 3x guess) - surrender.
Note: This algorithm fails on the edge cases where, for any of the inputs i:
- i-1 or i+1 is out of range
- i-1 or i+1 are part of the test inputs
In this scenario, the algorithm will fail at the _get_guess stage and skip the guessing.
"""

def __init__(self, registry=None):
pass

def _solve(self, task_state: TaskState):
cs: CurrentState = task_state.current_state

assert len(cs.test_inputs) == 3, "AverageBaseline assumes 3 test inputs"

if cs.round_ix < 6:
response = self._get_ask(cs.test_inputs, cs.round_ix)
elif 6 <= cs.round_ix < 9:
response = self._get_guess(cs.test_inputs, cs.known_values, cs.round_ix - 6)
else:
response = "I've run out of ideas sorry :("
return SolverResult(response)

def _get_guess(self, test_inputs, known_values: dict[int, int], guess_round_ix) -> str:
known_values = {
x: y for x, y in known_values.items() if x - 1 in test_inputs or x + 1 in test_inputs
}

pairs = [[], [], []]
for i, test_input in enumerate(test_inputs):
try:
lower = known_values[test_input - 1]
higher = known_values[test_input + 1]
except KeyError:
return "Unfortunately I don't have enough data to make a guess, will pass."
pairs[i] = [lower, higher]

funcs = [round, math.floor, math.ceil]
func = funcs[guess_round_ix]
vals = [func((pair[0] + pair[1]) / 2) for pair in pairs]
return " ".join([str(x) for x in vals])

def _get_ask(self, test_inputs, round_ix) -> str:
queries = []
for x in test_inputs:
queries.append(x - 1)
queries.append(x + 1)

ask = queries[round_ix]
if ask in test_inputs or ask < 0 or ask > 100:
logging.warning(
f"Invalid query on inputs {test_inputs}: {ask}. AverageBaseline algorithm will fail."
)
return str(ask)


class FullKnowledge(Solver):
"""Assuming solver knows all the samples, how well would it perform?
Two modes - "random", where it selects random integer when asking,
and "best" where it selects the best integer.
The "best" mode should be close to unbeatable (except for lucky guesses).
"""

def __init__(self, mode: str, samples_jsonl: str, registry: Registry):
assert mode in ("random", "best"), "mode must be either random or best"
self.mode = mode
self._all_samples = self._get_samples(samples_jsonl, registry._registry_paths[0])
self._rng = np.random.default_rng()

def _solve(self, task_state: TaskState):
cs: CurrentState = task_state.current_state

matching_samples = self._get_matching_samples(cs.known_values)
if len(matching_samples) > 1:
if self.mode == "random":
response = self._get_ask_random(cs.known_values)
else:
response = self._get_ask_best(matching_samples)
else:
sample_values = matching_samples[0].values
result = [sample_values[test_input] for test_input in cs.test_inputs]
response = " ".join([str(x) for x in result])
return SolverResult(str(response))

def _get_matching_samples(self, known_values):
def matches(sample: Sample) -> bool:
for key, val in known_values.items():
if sample.values[key] != val:
return False
return True

return [sample for sample in self._all_samples if matches(sample)]

def _get_ask_best(self, samples):
def get_entropy(x: int) -> float:
values = [sample.values[x] for sample in samples]
counter = Counter(values)
return entropy([val for val in counter.values()])

return max(range(0, 101), key=get_entropy)

def _get_ask_random(self, known_values):
while True:
x = self._rng.integers(0, 100)
if x not in known_values:
return x

def _get_samples(self, samples_jsonl: str, registry_path: Path):
path = registry_path / "data" / samples_jsonl
return [Sample(**x) for x in get_jsonl(path.as_posix())]
Loading

0 comments on commit f1f02f0

Please sign in to comment.