-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
1,077 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Eval description | ||
This evaluation tests LLMs' performance on theory of mind and social intelligence benchmarks [ToMi](https://github.com/facebookresearch/ToMi) and [SocialIQA](https://allenai.org/data/socialiqa). | ||
|
||
The `ToMi` test set contains 5,993 question-answer pairs. These are instances of the [Sally-Anne test](https://en.wikipedia.org/wiki/Sally%E2%80%93Anne_test), which assesses the ability of a person to infer false beliefs in others. The original setting involves two people, Sally and Anne, who are together in a room. Sally places a marble in a box. Then, Anne leaves the room, and while she is away, Sally moves the marble to a basket elsewhere in the room. When Anne returns to the room, where will she search for the marble? If the person responding “has” theory-of-mind they’ll respond that Anne searches for the marble in the box, where she had last seen it. If they do not, they ascribe their own, accurate belief regarding the location to Anne, and say that she looks for it in the basket. | ||
|
||
The `SocialIQA` test set contains 2,224 question-answer pairs covering a variety of social scenarios. These are multiple-choice, with 3 options of which only one is correct. The questions cover a person’s wants, needs, motivations, and reactions, as well as the effects of an action (on self or others), and how that action reflects on the person carrying it out (e.g. how others would perceive them after having carried out the action). | ||
|
||
Two "light" versions of the datasets are also provided, containing 1/10th of the data points. These are useful for iterating on prompts and developing other scaffolding. | ||
|
||
# Token and pricing estimates | ||
On average: | ||
- On the `SocialIQA` dataset, models consume ~250k tokens per run using the simple solver, and ~900k using the CoT solver. | ||
- On the `ToMi` dataset, models consume ~700k tokens per run using the simple solver, and ~2.4m using the CoT solver. | ||
|
||
To calculate dollar cost from token counts, please check the latest token pricing [here](https://openai.com/pricing). Note that we count both input and output tokens together, so a lower and upper estimate of the cost of each variant can be predicted. | ||
|
||
# Experiments | ||
As a starting point for deeper exploration, we provide scripts for comparing various solvers and eval variants, as well as for plotting the results. To run these: | ||
``` | ||
cd scripts/ | ||
bash run_experiments.sh | ||
``` | ||
|
||
# Contribution statement | ||
Eval design was primarily conducted by Andrei Alexandru, under the guidance of (alphabetically by last-name) Steven Adler, James Aung, Rosie Campbell and Jade Leung who provided research input, report revisions, and project management support. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import json | ||
|
||
# %% | ||
filepath = "/evals/registry/data/theory_of_mind/tomi/train.txt" | ||
|
||
lines, datapoints = [], [] | ||
with open(filepath, "r") as f: | ||
for line in f: | ||
line_index = line.split(" ")[0] | ||
if int(line_index) == 1: | ||
if len(lines) == 0: | ||
lines.append(line) | ||
else: | ||
target = lines[-1].split("\t")[-2] | ||
last_line = lines[-1].split("\t")[0] | ||
lines = [" ".join(line.replace("\n", "").split(" ")[1:]) for line in lines[:-1]] | ||
context = " ".join(lines) + " " + " ".join(last_line.split(" ")[1:]) | ||
datapoints.append({"context": context, "target": target}) | ||
lines = [line] | ||
else: | ||
lines.append(line) | ||
# %% | ||
def convert_datapoints_to_eval_dataset(datapoints: list) -> list: | ||
system_prompt = "You will read a number of sentences describing a situation involving several people, as well as a question regarding the situation. Your task is to answer the question based on the information in the sentences." | ||
eval_dataset = [] | ||
for datapoint in datapoints: | ||
context = datapoint["context"] | ||
target = datapoint["target"] | ||
eval_datapoint = { | ||
"input": [ | ||
{"role": "system", "content": system_prompt}, | ||
{"role": "user", "content": context}, | ||
], | ||
"ideal": target, | ||
} | ||
eval_dataset += [eval_datapoint] | ||
return eval_dataset | ||
|
||
|
||
# %% | ||
eval_dataset = convert_datapoints_to_eval_dataset(datapoints) | ||
# %% | ||
output_file = "tomi_train.jsonl" | ||
|
||
with open(output_file, "w") as out: | ||
for datapoint in eval_dataset: | ||
out.write(json.dumps(datapoint) + "\n") | ||
# %% | ||
filepath = "/evals/registry/data/theory_of_mind/socialiqa/test.jsonl" | ||
system_prompt = "You will read a number of sentences describing a situation, followed by a question regarding the situation. Your task is to answer the question based on the information in the sentences by choosing from one of three answers A, B or C." | ||
|
||
dataset = [] | ||
with open(filepath, "r") as f: | ||
for line in f: | ||
entry = json.loads(line) | ||
template = f"{entry['context']} {entry['question']} A: {entry['answerA']}; B: {entry['answerB']}; C: {entry['answerC']}." | ||
dataset.append( | ||
{ | ||
"input": [ | ||
{"role": "system", "content": system_prompt}, | ||
{"role": "user", "content": template}, | ||
], | ||
"ideal": entry["correct"], | ||
} | ||
) | ||
# %% | ||
output_file = "socialiqa_test.jsonl" | ||
with open(output_file, "w") as out: | ||
for datapoint in dataset: | ||
out.write(json.dumps(datapoint) + "\n") | ||
|
||
# %% | ||
|
||
filepath = "evals/registry/data/theory_of_mind/socialiqa/test.jsonl" | ||
outpath = "evals/registry/data/theory_of_mind/socialiqa/newtest.jsonl" | ||
|
||
dataset = [] | ||
with open(filepath, "r") as f, open(outpath, "w") as out: | ||
for line in f: | ||
entry = json.loads(line) | ||
entry["input"] = [entry["input"][1]] | ||
out.write(json.dumps(entry) + "\n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
"""Take results from recent experiments and make a bar plot""" | ||
import argparse | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
import seaborn as sns | ||
|
||
from evals.utils import log_utils | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--log_dir", type=str, required=True) | ||
parser.add_argument("--out_dir", type=str, required=True) | ||
args = parser.parse_args() | ||
|
||
log_dir = args.log_dir | ||
out_dir = args.out_dir | ||
df = load_tom_results_from_dir(log_dir) | ||
make_plot(df, out_dir=Path(out_dir)) | ||
|
||
|
||
def load_tom_results_from_dir(log_dir: Union[str, Path]) -> pd.DataFrame: | ||
rows = [] | ||
final_results_dict = log_utils.get_final_results_from_dir(log_dir) | ||
|
||
for path, final_results in final_results_dict.items(): | ||
spec = log_utils.extract_spec(path) | ||
dataset, prompt_type, model = parse_spec(spec) | ||
rows.append( | ||
{ | ||
"model": model, | ||
"dataset": dataset, | ||
"prompt_type": prompt_type, | ||
"accuracy": final_results["accuracy"], | ||
"bootstrap_std": final_results["bootstrap_std"], | ||
} | ||
) | ||
return pd.DataFrame(rows) | ||
|
||
|
||
def parse_spec(spec: dict) -> tuple[str, bool, int]: | ||
"""parse the spec from a MMP run""" | ||
completion_fn = spec["completion_fns"][0] | ||
dataset, prompt_type, model = completion_fn.split("/") | ||
prompt_type = prompt_type.split("_")[0] | ||
|
||
return (dataset, prompt_type, model) | ||
|
||
|
||
def make_plot(df, out_dir): | ||
sns.set_theme(style="whitegrid") | ||
sns.set_palette("dark") | ||
# Define the order of models | ||
model_order = ["gpt-3.5-turbo", "gpt-4-base", "gpt-4"] | ||
datasets = df["dataset"].unique() | ||
|
||
for dataset in datasets: | ||
ds = df[df["dataset"] == dataset.lower()] | ||
|
||
# Ensure the model column is a categorical type with the specified order | ||
ds["model"] = pd.Categorical(ds["model"], categories=model_order, ordered=True) | ||
ds = ds.sort_values("model") # Sort according to the categorical order | ||
|
||
# Unique models | ||
xs = ds["model"].unique() | ||
# Get the accuracy values for both prompt types | ||
simple_acc = ds[ds["prompt_type"] == "simple"]["accuracy"].values | ||
cot_acc = ds[ds["prompt_type"] == "cot"]["accuracy"].values | ||
|
||
# Get the corresponding error values from the "bootstrap_std" field | ||
simple_std = ds[ds["prompt_type"] == "simple"]["bootstrap_std"].values | ||
cot_std = ds[ds["prompt_type"] == "cot"]["bootstrap_std"].values | ||
|
||
# Define the width of a bar | ||
bar_width = 0.35 | ||
# Set the positions of the bars | ||
x_indices = np.arange(len(xs)) | ||
x_indices2 = [x + bar_width for x in x_indices] | ||
|
||
fig, ax1 = plt.subplots() | ||
fig.suptitle(f"Accuracy on {dataset} dataset") | ||
|
||
ax1.set_xlabel("Model") | ||
ax1.set_ylabel("Accuracy") | ||
|
||
# Plot the bars for 'simple' and 'cot' | ||
ax1.bar( | ||
x_indices, | ||
simple_acc, | ||
width=bar_width, | ||
color=sns.color_palette("pastel")[0], | ||
yerr=simple_std, | ||
label="simple", | ||
) | ||
ax1.bar( | ||
x_indices2, | ||
cot_acc, | ||
width=bar_width, | ||
color=sns.color_palette("pastel")[1], | ||
yerr=cot_std, | ||
label="chain-of-thought", | ||
) | ||
|
||
if dataset == "socialiqa": | ||
# Draw the horizontal line for the human baseline | ||
human_baseline = 0.881 | ||
ax1.axhline(y=human_baseline, color="gray", linestyle="--", linewidth=1) | ||
# Add the text label for the human baseline | ||
ax1.text( | ||
0.01, human_baseline, "human baseline", va="center", ha="left", backgroundcolor="w" | ||
) | ||
|
||
# Set the x-axis ticks to be in the middle of the two bars | ||
ax1.set_xticks([r + bar_width / 2 for r in range(len(xs))]) | ||
ax1.set_xticklabels(xs, rotation=45) # Rotate the x-axis labels if needed | ||
|
||
ax1.set_ylim(0, 1) | ||
|
||
# Add legend | ||
ax1.legend(loc="upper right", bbox_to_anchor=(1, 1)) | ||
|
||
# Save the figure | ||
plt.savefig(out_dir / f"accuracy_{dataset.lower()}.png", bbox_inches="tight") | ||
plt.tight_layout() # Adjust the plot to ensure everything fits without overlapping | ||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
logdir=./logs | ||
outputdir=./outputs | ||
timestamp=$(date +%Y%m%d_%H%M%S) | ||
logpathbase="$logdir/$timestamp/" | ||
|
||
echo Running experiments and logging to $logpathbase | ||
|
||
DATASETS="tomi socialiqa hitom" | ||
MODELS="gpt-3.5-turbo gpt-4 gpt-4-base" | ||
SOLVER_TYPES="simple_solver cot_solver" | ||
|
||
for dataset in $DATASETS | ||
do | ||
for model in $MODELS | ||
do | ||
for solver in $SOLVER_TYPES | ||
do | ||
oaieval $dataset/$solver/$model "theory_of_mind."$dataset --record_path "$logpathbase/$model-$variant.log" | ||
done | ||
done | ||
done | ||
|
||
echo Done running experiments, all logs in $logpathbase | ||
|
||
echo Producing plots, outputs to $outputdir | ||
python3 make_plots.py --log_dir $logpathbase --out_dir $outputdir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
tomi/test.jsonl filter=lfs diff=lfs merge=lfs -text | ||
tomi/test_light.jsonl filter=lfs diff=lfs merge=lfs -text | ||
socialiqa/test.jsonl filter=lfs diff=lfs merge=lfs -text | ||
socialiqa/test_light.jsonl filter=lfs diff=lfs merge=lfs -text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
ToMi: | ||
License: Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC 4.0) https://creativecommons.org/licenses/by-nc/4.0/legalcode.en | ||
Source: https://github.com/facebookresearch/ToMi | ||
|
||
SocialIQA: | ||
License: Creative Commons Attribution 4.0 International (CC-BY 4.0) https://creativecommons.org/licenses/by/4.0/legalcode.en | ||
Source: https://allenai.org/data/socialiqa |
Oops, something went wrong.