diff --git a/evals/elsuite/identifying_variables/.gitattributes b/evals/elsuite/identifying_variables/.gitattributes new file mode 100644 index 0000000000..e256da66cb --- /dev/null +++ b/evals/elsuite/identifying_variables/.gitattributes @@ -0,0 +1 @@ +images/*.png filter=lfs diff=lfs merge=lfs -text diff --git a/evals/elsuite/identifying_variables/README.md b/evals/elsuite/identifying_variables/README.md new file mode 100644 index 0000000000..59912f0b27 --- /dev/null +++ b/evals/elsuite/identifying_variables/README.md @@ -0,0 +1,177 @@ +# Identifying Variables + +This eval tests how well models can determine what should be treated as the +independent, dependent, and control variables for an experiment that tests a +particular hypothesis, given some observational context. + +## Usage + +Run with: + +```bash +oaieval identifying_variables +``` + +We have found that `generation/cot/gpt-4-1106-preview` works well on this eval. For more examples of tested solvers, see [`./scripts/run_experiments.sh`](./scripts/run_experiments.sh). + +## Evaluation Process + +The evaluation process is as follows for a given sample from our dataset: + +1. The `TASK_DESCRIPTION` prompt is shown to the solver. +2. The sample is passed through a _renderer_ that processes the samples and + renders an observation of the interactions of variables, which is placed in + the `SAMPLE_MESSAGE` prompt template. +3. The solver answers in the form: `[@ANSWER valid_hyp: ; independent: ; dependent: ; control: ]`. The answer is parsed and evaluated by the eval. If the answer cannot be parsed, we mark this as a violation and the sample is treated as incorrect. + +## Prompts + +We refer readers to the [`./prompts.py`](./prompts.py) file for the +`TASK_DESCRIPTION` and `SAMPLE_MESSAGE` prompts used in the eval. + +## Metrics + + +| **Metric** | **Notes** | +|---|---| +| `ctrl_nDCG` | A modified version of the [normalized discounted cumulative gains (nDCG)](https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG) metric, which rewards listing the correct control variables first and penalizes naming irrelevant variables. | +| `ctrl_recall` | Number of variables correctly marked as control variables / total number of variables to control according to the gold label | +| `ctrl_recall` | Number of variables incorrectly marked as control variables / total number of variables not to control according to the gold label | +| `hyp_valid_acc` | Target hypothesis plausibility validation accuracy (correct/incorrect) | +| `ind_acc` | Independent variable determination accuracy (correct/incorrect) | +| `dep_acc` | Dependent variable determination accuracy (correct/incorrect) | +| `violation_rate` | Number of samples with violations (model failed to answer in correct format) / total number of samples | + + +## Variants + +We support variations on the eval along two dimensions, `renderer` and `dataset`: + +```bash +oaieval identifying_variables.. +``` + +The eval defaults to `identifying_variables.language-corrset.balanced-ctrl`. + +### Dataset + +We provide 4 dataset variants: + +| `dataset` | Notes | +| --- | --- | +| `balanced-ctrl` | 500 samples balanced across number of control variables (from 0 to 8). | +| `balanced-ctrl-large` | As `balanced-ctrl`, but with 5,000 samples. | +| `balanced-hypotheses` | 500 samples balanced across target hypotheses being implausible/plausible. | +| `balanced-hypotheses-large` | As `balanced-hypotheses`, but with 5,000 samples. | + +### Renderers + +We have 6 different renderers, implemented in [`./renderers/`](./renderers/). + +The default renderer is `language-corrset`. Here is an example render from this type: +``` +The following is a description of the observations made about a set of variables. + +In general, there were cases where some variables changed in tandem with each other, while others did not. +For example, changes in x_5075 were observed to reflect changes in x_3314 and viceversa. +Changes in x_9549 were not observed to reflect any changes in previously mentioned variables. +Changes in x_1808 were not observed to reflect any changes in previously mentioned variables. +Likewise, changes in x_9726 were observed to reflect changes in x_1808 and viceversa. +``` + +### Show Tree + +We provide an additional variant of the eval where the decision tree implementing +the reasoning for scoring a perfect score is shown to the model. This variant +can be run by passing the `show_tree=True` flag to eval, e.g. + +```bash +oaieval identifying_variables --extra_eval_params show_tree=True +``` + +## Custom Solvers + +We implement two custom programmatic solvers to serve as baselines. + +1. `identifying_variables/random`: a solver that randomly selects whether the + hypothesis is plausible with probability 0.5, and if so randomly samples the + independent, dependent and control variables. We view this baseline as + equivalent to randomly guessing. +2. `identifying_variables/noctrl`: this is a solver that always outputs an empty + list for the variables to control, essentially eliminating any chance of + false positives. This can provide stronger performance than the random + baseline, since it avoids any penalization for returning incorrect variables, + and can even achieve a perfect score on samples that indeed do not have any + variables to control + +We refer to [`./solvers.py`](./solvers.py) for the implementation of these +solvers. + +## Token Usage Estimates + +We estimated per-run token usage on the default dataset size (500 samples) +for the least and most token-intensive configurations for each model type +(respectively, direct models on `identifying_variables.corrset` with +`show_tree=False`; and CoT models on `identifying_variables.language-tabular` +with `show_tree=True`). + + +| | **input tokens/run** | **output tokens/run** | **total tokens/run** | +|---|---|---|---| +| **GPT-4-base HHH (corrset, no tree)** | 1,200,000 | 250,000 | 1,450,000 | +| **GPT-4-base CoT HHH (language-tabular, with tree)** | 1,500,000 | 240,000 | 1,740,000 | +| **GPT-3.5-turbo Direct (corrset, no tree)** | 430,000 | 88,000 | 518,000 | +| **GPT-3.5-turbo CoT (language-tabular, with tree)** | 780,000 | 14,000 | 794,000 | +| **GPT-4-1106-preview Direct (corrset, no tree)** | 430,000 | 53,000 | 483,000 | +| **GPT-4-1106-preview CoT (language-tabular, with tree)** | 860,000 | 14,000 | 874,000 | + + +These estimates were taken using the `balanced-hypotheses` dataset but should +roughly apply to the `-balanced-ctrl` datasets. For `-large` datasets (5000 +samples), multiply the above numbers by 10. + +## Future modifications + +- Revisit the definition of the headline `ctrl_nDCG` metric +- Devise additional auxiliary metrics to paint a more complete picture +- What if we show the decision trees described in natural language rather than + pseudocode? +- How can we extend this eval to multi-variable dependencies? + +## Version History + +- v0: Initial version released + +## Contribution Statement + +Eval design, implementation, and results evaluation and writeup were primarily +conducted by Giulio Starace. James Aung was of enormous assistance in report +writing, and additionally provided general feedback and project management +throughout the eval. Oliver Jaffe and Jan Betley were responsible for code +reviews throughout the implementation process, along with fine-grained feedback +on the project in general. Additional guidance was provided by (alphabetically +by last-name) Steven Adler and Chan Jun Shern, who helped with brainstorming, +gave research input and report revisions. + +## Appendix + +### Perfect output decision trees + +The following decision trees are used to determine the perfect output (aka "gold +label") for a given sample. + +--- + + + +**Figure A1**: Decision tree for determining whether a given variable should be +controlled. + +--- + + + +**Figure A2**: Decision tree for determining a hypothesis is valid and if so +what the independent and dependent variables are. + +--- diff --git a/evals/elsuite/identifying_variables/constants.py b/evals/elsuite/identifying_variables/constants.py new file mode 100644 index 0000000000..60729828c7 --- /dev/null +++ b/evals/elsuite/identifying_variables/constants.py @@ -0,0 +1,19 @@ +# variables that have at least this amount of sparsity are considered to be unobserved +SPARSITY_FOR_UNOBS = 0.8 +# num of variables in a given sample +MIN_VARS = 2 +MAX_VARS = 10 +# num of hypotheses in a given sample +MIN_HYPS = 1 +MAX_HYPS = 3 +# sparse var rate: percentage of variables to sparsify +MIN_SPARSE_VAR_RATE = 0 +MAX_SPARSE_VAR_RATE = 1 +# sparsity: percentage of NaNs in a sparsified variable +MIN_SPARSITY = 0.2 +MAX_SPARSITY = 1 + +# specific to tabular renderers ------------ + +# num of observations +NUM_OBS = 20 diff --git a/evals/elsuite/identifying_variables/eval.py b/evals/elsuite/identifying_variables/eval.py new file mode 100644 index 0000000000..31b3b743e0 --- /dev/null +++ b/evals/elsuite/identifying_variables/eval.py @@ -0,0 +1,292 @@ +""" +Implementation logic for Identifying Variables eval +""" +import logging +import random +from dataclasses import asdict +from typing import Dict, List, Optional, Tuple + +import networkx as nx +import numpy as np + +from evals.elsuite.identifying_variables import constants, graph_utils, prompts +from evals.elsuite.identifying_variables.metrics import ( + compute_fallout, + compute_nDCG, + compute_recall, +) +from evals.elsuite.identifying_variables.renderers import RENDERER_MAP +from evals.elsuite.identifying_variables.scripts.gen_data import gen_samples +from evals.elsuite.identifying_variables.structs import Answer, Sample +from evals.elsuite.identifying_variables.utils import json_to_sample, parse_solver_preds +from evals.eval import SolverEval +from evals.record import RecorderBase, record_metrics +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import Message, TaskState + +logging.getLogger("httpx").setLevel(logging.WARNING) + + +class IdentifyingVariables(SolverEval): + def __init__( + self, + renderer: str, + n_samples: Optional[int] = None, + show_tree: bool = False, + group_metrics: bool = False, + debug: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.rng: random.Random = random.Random(self.seed) + self.np_rng: np.random.Generator = np.random.default_rng(self.seed) + self.renderer = RENDERER_MAP[renderer](rng=self.rng, np_rng=self.np_rng) + self.renderer_variant = renderer + self.n_samples = n_samples + self.show_tree = show_tree + self.task_description = self._build_task_description() + self.group_metrics = group_metrics + self.debug = debug + + def _build_task_description(self) -> str: + decision_tree_section = "" + if self.show_tree: + decision_tree_section = prompts.DECISION_TREE_SECTION + return prompts.TASK_DESCRIPTION.format( + optional_decision_tree_section=decision_tree_section, + ).strip() + + def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random) -> None: + message: Message = self._build_message(sample) + + task_state = TaskState( + task_description=self.task_description, + messages=[message], + # to be used by the Random baseline solver only + current_state={"variables": [var for var in sample.causal_graph.nodes]}, + ) + + solver_result: SolverResult = solver(task_state) + + try: + preds = parse_solver_preds(solver_result) + except ValueError: # in case of invalid solver output + preds = None + gold, num_not_ctrl = sample.gold_label, sample.num_not_ctrl + + metrics: Dict[str, float] = self._evaluate_sample(preds, gold, num_not_ctrl) + + record_metrics( + **metrics, + # hack: logviz doesn't support custom log fields, so logging as metric + causal_graph=nx.to_dict_of_lists(sample.causal_graph), + gold_answer=asdict(gold), + n_hyps=sample.hypotheses.number_of_edges(), + valid_hyp=gold.valid_hypothesis, + num_not_ctrl=num_not_ctrl, + ) + + def run(self, recorder: RecorderBase) -> Dict[str, float]: + samples: List[Dict] = self._get_samples() + self.rng.shuffle(samples) + self.eval_all_samples(recorder, samples) + metrics: List[Dict] = recorder.get_metrics() + + return self._compute_agg_metrics(metrics) + + def _compute_agg_metrics(self, metrics: List[Dict]) -> Dict[str, float]: + """ + Computes aggregate metrics across all samples + """ + main_metrics = { + "hyp_valid_acc": np.mean([x["hyp_valid_correct"] for x in metrics]), + "violation_count": np.sum([x["violation"] for x in metrics]), + "violation_rate": np.mean([x["violation"] for x in metrics]), + # Some samples may be NaN for cases where the target hypothesis is invalid + "ctrl_nDCG": np.nanmean([x["ctrl_nDCG"] for x in metrics]), + "ctrl_recall": np.nanmean([x["ctrl_recall"] for x in metrics]), + "ctrl_fallout": np.nanmean([x["ctrl_fallout"] for x in metrics]), + "ind_acc": np.nanmean([x["ind_correct"] for x in metrics]), + "dep_acc": np.nanmean([x["dep_correct"] for x in metrics]), + "n_valid_hyp": np.sum([x["valid_hyp"] for x in metrics]), + } + if self.group_metrics: + grouped_metrics = self._compute_grouped_metrics(metrics) + else: + grouped_metrics = {} + + total_metrics = {**main_metrics, **grouped_metrics} + total_metrics = {k: float(v) for k, v in total_metrics.items()} + return total_metrics + + def _compute_grouped_metrics(self, metrics: List[Dict]) -> Dict[str, float]: + """ + Computes metrics aggregated across samples grouped by + - number of variables + - number of roots in random forest + - number of control variables + - number of hypotheses + - max correlation depth + """ + metric_to_agg_func = { + "hyp_valid_acc": np.mean, + "violation_count": np.sum, + "violation_rate": np.mean, + "ctrl_nDCG": np.nanmean, + "ctrl_recall": np.nanmean, + "ctrl_fallout": np.nanmean, + "ind_acc": np.nanmean, + "dep_acc": np.nanmean, + } + raw_metric_names = [ + "hyp_valid_correct", + "violation", + "violation", + "ctrl_nDCG", + "ctrl_recall", + "ctrl_fallout", + "ind_correct", + "dep_correct", + ] + group_to_bins = { + "n_vars": np.arange(constants.MIN_VARS, constants.MAX_VARS + 1), + "n_roots": np.arange(1, constants.MAX_VARS + 1), + "n_ctrl_vars": np.arange(0, (constants.MAX_VARS - 2) + 1), + "n_hyps": np.arange(constants.MIN_HYPS, constants.MAX_HYPS + 1), + "max_corr_depth": np.arange(1, constants.MAX_VARS), + } + grouped_metrics = { + f"{metric}-{group}-{g_bin}": [] + for metric in metric_to_agg_func.keys() + for group in group_to_bins.keys() + for g_bin in group_to_bins[group] + } + for log_entry in metrics: + causal_graph = nx.from_dict_of_lists(log_entry["causal_graph"], create_using=nx.DiGraph) + ctrl_vars = log_entry["gold_answer"]["ctrl_vars"] + dep_var = log_entry["gold_answer"]["dep_var"] + group_to_bin = { + "n_vars": causal_graph.number_of_nodes(), + "n_roots": len(graph_utils.find_graph_roots(causal_graph)), + "n_ctrl_vars": len(ctrl_vars) if ctrl_vars is not None else None, + "n_hyps": log_entry["n_hyps"], + "max_corr_depth": graph_utils.find_farthest_node(causal_graph, dep_var)[1] + if dep_var is not None + else None, + } + for group, g_bin in group_to_bin.items(): + if g_bin is not None: + for metric, raw_metric in zip(metric_to_agg_func.keys(), raw_metric_names): + grouped_metrics[f"{metric}-{group}-{g_bin}"].append(log_entry[raw_metric]) + + # aggregate + grouped_metrics = { + k: metric_to_agg_func[k.split("-")[0]](v) + # signal empty groups with np.nan + if len(v) > 0 else np.nan + for k, v in grouped_metrics.items() + } + return grouped_metrics + + def _evaluate_sample(self, preds: Optional[Answer], gold: Answer, num_not_ctrl: int) -> Dict: + """ + If the gold hypothesis is invalid, then all other metrics are skipped, and we + only evaluate whether the solver correctly identified the hypothesis as invalid. + + Mistakes are propagated: If the solver incorrectly identifies a hypothesis as + invalid, then its missing answers for the remaining tasks are counted as wrong. + + In case of violations, the worst possible metrics are recorded, accounting for + the gold hypothesis validity caveat above (e.g. if the gold hypothesis is + invalid, then the worst case ctrl_nDCG is NaN since we'd skip this anyway, + whereas if the gold hypothesis were valid, then the worst case ctrl_nDCG would + be 0.0) + """ + hyp_valid_correct = preds.valid_hypothesis == gold.valid_hypothesis if preds else False + + if gold.valid_hypothesis: + ind_correct = preds.ind_var == gold.ind_var if preds else False + dep_correct = preds.dep_var == gold.dep_var if preds else False + ctrl_nDCG = ( + self._ctrl_vars_nDCG(preds.ctrl_vars, gold.ctrl_vars, num_not_ctrl) + if preds and preds.ctrl_vars is not None + else 0.0 + ) + ctrl_recall = ( + self._ctrl_vars_recall(preds.ctrl_vars, gold.ctrl_vars) + if preds and preds.ctrl_vars is not None + else 0.0 + ) + # not in final report, since experiments had already been run + ctrl_fallout = ( + self._ctrl_vars_fallout(preds.ctrl_vars, gold.ctrl_vars, num_not_ctrl) + if preds and preds.ctrl_vars is not None + else 1.0 + ) + + else: + ctrl_nDCG = np.nan + ctrl_recall = np.nan + ctrl_fallout = np.nan + ind_correct = np.nan + dep_correct = np.nan + + return { + "ctrl_nDCG": ctrl_nDCG, + "ctrl_recall": ctrl_recall, + "ctrl_fallout": ctrl_fallout, + "ind_correct": ind_correct, + "dep_correct": dep_correct, + "hyp_valid_correct": hyp_valid_correct, + "violation": preds is None, + } + + def _ctrl_vars_fallout(self, preds: List[str], gold: List[str], num_not_ctrl: int) -> float: + return compute_fallout(set(preds), set(gold), num_not_ctrl) + + def _ctrl_vars_recall(self, preds: List[str], gold: List[str]) -> float: + return compute_recall(set(preds), set(gold)) + + def _ctrl_vars_nDCG(self, preds: List[str], gold: List[str], num_not_ctrl: int) -> float: + best = [1.0] * len(gold) + ranking = [1.0 if var in gold else -1.0 for var in preds] + worst_case_ctrl = [-1.0] * num_not_ctrl + return compute_nDCG(ranking, best, worst_case_ctrl) + + def _build_message(self, sample: Sample) -> Message: + observations: str = self.renderer.render_obs(sample) + hypotheses: List[str] = self._render_hypotheses(sample.hypotheses) + target_hypothesis: str = self._render_hypothesis(sample.target_hypothesis) + + message_content = prompts.SAMPLE_MESSAGE.format( + observations=observations, + hypotheses=hypotheses, + target_hypothesis=target_hypothesis, + ).strip() + message = Message("user", content=message_content) + + return message + + def _render_hypotheses(self, hypotheses: nx.DiGraph) -> List[str]: + hyp_list = [(n, adj) for n in hypotheses for adj in hypotheses[n]] + return [self._render_hypothesis(h) for h in hyp_list] + + def _render_hypothesis(self, hypothesis: Tuple[str, str]) -> str: + hyp_template = self.rng.choice(prompts.hypothesis_templates) + rendered_hyp = hyp_template.format(ind=hypothesis[0], dep=hypothesis[1]) + return rendered_hyp + + def _get_samples(self) -> List[Sample]: + if self.debug: + return gen_samples(n_samples=1000, signal_noise_ratio=None, np_rng=self.np_rng) + + dict_samples = self.get_samples() + if self.n_samples is not None: + assert ( + len(dict_samples) >= self.n_samples + ), f"Can't get {self.n_samples} samples from a dataset with {len(dict_samples)} samples" + np.random.default_rng(seed=self.seed).shuffle(dict_samples) + dict_samples = dict_samples[: self.n_samples] + samples = [json_to_sample(dict_sample) for dict_sample in dict_samples] + return samples diff --git a/evals/elsuite/identifying_variables/graph_utils.py b/evals/elsuite/identifying_variables/graph_utils.py new file mode 100644 index 0000000000..815ab968cc --- /dev/null +++ b/evals/elsuite/identifying_variables/graph_utils.py @@ -0,0 +1,254 @@ +"""Utils for network graph related operations.""" +from typing import Any, List, Optional, Set, Tuple, Union + +import networkx as nx +import numpy as np + + +def val_and_count_roots( + nodes: List[str], + np_rng: np.random.Generator, + total_edges: Optional[int] = None, + min_roots: Optional[int] = None, +) -> int: + """ + Validates the parameters for the construction of a random forest via + `gen_random_forest` and determines the min number of roots to use. + + A random forest following the constraints of `gen_random_forest` with + N nodes will have + - R <= N roots + - E <= N - R edges + If min_roots is not specified, then E <= N - 1, since R >= 1. + """ + n_nodes = len(nodes) + if min_roots is not None: + assert min_roots <= n_nodes, "Total roots must be less than or equal to the number of nodes" + if total_edges is not None: + assert ( + 0 <= total_edges <= n_nodes - min_roots + ), "Total edges must be between 0 and the number of nodes minus the number of roots" + else: + if total_edges is None: + min_roots = np_rng.integers(1, n_nodes + 1) + else: + assert ( + 0 <= total_edges <= n_nodes - 1 + ), "Total edges must be between 0 and the number of nodes minus 1" + # if total edges is specified, then we have an upper bound on R, R <= N - E + max_roots = n_nodes - total_edges + min_roots = np_rng.integers(1, max_roots + 1) + + return min_roots + + +def gen_random_forest_tree_size( + nodes: List[str], + tree_size: int, + np_rng: Optional[np.random.Generator] = None, +) -> nx.DiGraph: + """ + Builds a random forest, i.e. a Directed Acyclic Graph (DAG) + with potentially more than one root. + + We enforce the following constraints for our purposes: + 1. No self connections + 2. No bi-directional connections + 3. No children with multiple parents + 4. At least one root node (no parents) + 5. No cycles + + We additionally allow the user to specify the size that at least one + of the trees in the forest should be. + + Args: + nodes: A list of node names to build the graph from + tree_size: The number of nodes that at least one of the trees in the forest + should have + np_rng: A numpy random number generator + """ + num_nodes = len(nodes) + assert tree_size <= num_nodes, "Tree size must be less than or equal to the number of nodes" + + max_number_roots = num_nodes - tree_size + 1 + min_number_roots = 1 # 1 root is always reserved to the tree of size tree_size + + np_rng = np_rng or np.random.default_rng() + + num_roots = np_rng.integers(min_number_roots, max_number_roots + 1) + roots = set(np_rng.choice(nodes, num_roots, replace=False).tolist()) + + size_controlled_root = np_rng.choice(list(roots)) + size_controlled_tree_nodes = {size_controlled_root} + + shuffled_nodes = np_rng.permutation(nodes) + + graph_children = set() + + graph = nx.DiGraph() + graph.add_nodes_from(shuffled_nodes) + + while len(size_controlled_tree_nodes) < tree_size: + possible_children = [ + n for n in nodes if n not in size_controlled_tree_nodes and n not in roots + ] + child = np_rng.choice(possible_children) + possible_parents = list(size_controlled_tree_nodes) + parent = np_rng.choice(possible_parents) + graph.add_edge(parent, child) + size_controlled_tree_nodes.add(child) + graph_children.add(child) + + remaining_nodes = set(nodes) - size_controlled_tree_nodes + + for node in remaining_nodes: + possible_children = [ + n + for n in remaining_nodes + # avoid self connections + if n != node and + # avoid cycles and bi-directional conns -> ancestors can't be children + n not in nx.ancestors(graph, node) and + # avoid children with multiple parents + n not in graph_children and + # roots can't be children + n not in roots + ] + num_edges = np_rng.integers(0, len(possible_children) + 1) + children = np_rng.choice(possible_children, num_edges, replace=False).tolist() + + for child in children: + graph.add_edge(node, child) + graph_children.update(children) + + return graph + + +def gen_random_forest( + nodes: List[str], + total_edges: Optional[int] = None, + min_roots: Optional[int] = None, + np_rng: Optional[np.random.Generator] = None, +) -> nx.DiGraph: + """ + Builds a random forest, i.e. a Directed Acyclic Graph (DAG) + with potentially more than one root. + + We enforce the following constraints for our purposes: + 1. No self connections + 2. No bi-directional connections + 3. No children with multiple parents + 4. At least one root node (no parents) + 5. No cycles + + Args: + nodes: A list of node names to build the graph from + total_edges: The total number of edges in the graph. If None, will be random. + min_roots: The minimum number of roots in the graph. If None, will be random. + """ + np_rng = np_rng or np.random.default_rng() + graph = nx.DiGraph() + graph.add_nodes_from(nodes) + + min_roots = val_and_count_roots(nodes, np_rng, total_edges, min_roots) + + # the minimal set of roots, there may be more as we create the graph + roots = set(np_rng.choice(nodes, min_roots, replace=False).tolist()) + + graph_children = set() + edge_count = 0 + + shuffled_nodes = np_rng.permutation(nodes) + + for node in shuffled_nodes: + possible_children = [ + n + for n in nodes + # avoid self connections + if n != node and + # avoid cycles and bi-directional conns -> ancestors can't be children + n not in nx.ancestors(graph, node) and + # avoid children with multiple parents + n not in graph_children and + # roots can't be children + n not in roots + ] + + if len(possible_children) == 0: + continue + + if total_edges is not None: + remaining_edges = total_edges - edge_count + if remaining_edges <= 0: + break + num_edges = np_rng.integers(0, min(remaining_edges, len(possible_children)) + 1) + else: + num_edges = np_rng.integers(0, len(possible_children) + 1) + + children = np_rng.choice(possible_children, num_edges, replace=False).tolist() + + for child in children: + graph.add_edge(node, child) + graph_children.update(children) + edge_count += num_edges + + if total_edges is not None and edge_count < total_edges: + # If we didn't reach the total number of edges, try again + return gen_random_forest(nodes, total_edges, min_roots, np_rng) + + return graph + + +def find_farthest_node(graph: nx.DiGraph, source: str) -> Tuple[str, int]: + """ + Performs Breadth-First Search (BFS) to find the farthest node from the source node + and the distance to that node. Distance is defined as the number of edges between + the source node and the farthest node. + """ + graph = graph.to_undirected() + + # Compute shortest path lengths from source to all other nodes + path_lengths = nx.single_source_shortest_path_length(graph, source) + + # Find the farthest node + farthest_node = max(path_lengths, key=path_lengths.get) + max_distance = path_lengths[farthest_node] + + return farthest_node, max_distance + + +def find_graph_roots(graph: nx.DiGraph) -> Set[str]: + """ + Finds the root nodes of a graph + """ + return set([n for n, d in graph.in_degree() if d == 0]) + + +def find_graph_trees(graph: nx.DiGraph) -> List[Set[str]]: + """ + Finds the trees of a graph + """ + return [{root, *nx.descendants(graph, root)} for root in find_graph_roots(graph)] + + +def find_connected_nodes_pair( + graph: nx.DiGraph, np_rng: np.random.Generator +) -> Union[Tuple[Any, Any], None]: + """ + Finds a pair of connected nodes in a graph + If no such pair exists, returns None + """ + connected_pair = tuple(np_rng.choice(list(graph.edges))) if graph.edges else None + return connected_pair + + +def find_unconnected_nodes_pair(graph: nx.DiGraph) -> Union[Tuple[Any, Any], None]: + """ + Finds a pair of unconnected nodes in a graph + If no such pair exists, returns None + """ + components = list(nx.connected_components(graph.to_undirected())) + + if len(components) > 1: + return next(iter(components[0])), next(iter(components[1])) + return None diff --git a/evals/elsuite/identifying_variables/images/control_var_tree.png b/evals/elsuite/identifying_variables/images/control_var_tree.png new file mode 100755 index 0000000000..59de243e29 --- /dev/null +++ b/evals/elsuite/identifying_variables/images/control_var_tree.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60bbedac103bae669c4cec1037faaa18b87df63ab5d2c61734f2c60211240fd6 +size 273556 diff --git a/evals/elsuite/identifying_variables/images/valid_hyp_tree.png b/evals/elsuite/identifying_variables/images/valid_hyp_tree.png new file mode 100644 index 0000000000..d005e47b47 --- /dev/null +++ b/evals/elsuite/identifying_variables/images/valid_hyp_tree.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:758a23f6b4bd7676852af320b28f8b6af61c404d22835eda99f2b8dc89a0277b +size 69394 diff --git a/evals/elsuite/identifying_variables/latent_funcs.py b/evals/elsuite/identifying_variables/latent_funcs.py new file mode 100644 index 0000000000..6f66a1c44e --- /dev/null +++ b/evals/elsuite/identifying_variables/latent_funcs.py @@ -0,0 +1,43 @@ +"""Latent functions for the project.""" +import numpy as np + + +def linear(x: np.ndarray, grad: float, bias: float) -> np.ndarray: + return grad * x + bias + + +def quadratic(x: np.ndarray, grad: float, bias: float) -> np.ndarray: + return grad * x**2 + bias + + +def random_uniform(num_samples, min_v, max_v, rng: np.random.Generator) -> np.ndarray: + return rng.uniform(min_v, max_v, num_samples) + + +def random_ints(num_samples, min_v, max_v, rng: np.random.Generator) -> np.ndarray: + return rng.integers(min_v, max_v, num_samples) + + +LATENT_FUNC_MAP = { + "linear": linear, + "quadratic": quadratic, +} +LATENT_FUNC_KWARG_MAP = { + "linear": { + "grad": {"min_v": -10, "max_v": 10}, + "bias": {"min_v": -100, "max_v": 100}, + }, + "quadratic": { + "grad": {"min_v": -10, "max_v": 10}, + "bias": {"min_v": -100, "max_v": 100}, + }, +} + +DISTRIBUTIONS = { + # "random_uniform": random_uniform, + "random_ints": random_ints, +} +DISTRIBUTIONS_KWARG_MAP = { + "random_uniform": {"min_v": -1, "max_v": 1}, + "random_ints": {"min_v": -100, "max_v": 100}, +} diff --git a/evals/elsuite/identifying_variables/metrics.py b/evals/elsuite/identifying_variables/metrics.py new file mode 100644 index 0000000000..501ec3b1a9 --- /dev/null +++ b/evals/elsuite/identifying_variables/metrics.py @@ -0,0 +1,105 @@ +from typing import Dict, List, Set + +import numpy as np + +from evals.elsuite.identifying_variables.utils import parse_solver_preds +from evals.solvers.solver import SolverResult + + +def compute_DCG(ranking: List[float], ceil_negs: bool = False) -> float: + """ + Computes the DCG of a ranking + """ + dcg = 0 + for i, rel in enumerate(ranking, start=1): + if ceil_negs: + rel = max(rel, 0) + dcg += rel / np.log2(i + 1) # (i+1) to avoid log_2(1) which = 0 + return dcg + + +def compute_nDCG(ranking: List[float], best: List[float], worst: List[float]) -> float: + """ + Computes nDCG, allowing for negative scores, based on the nDCG variant + from Gienapp et al. (2020) (https://dl.acm.org/doi/10.1145/3340531.3412123) + """ + idcg = compute_DCG(best) + min_dcg = compute_DCG(worst) + dcg = compute_DCG(ranking) + return (dcg - min_dcg) / (idcg - min_dcg) + + +def compute_metric_posthoc( + metric: str, metric_entries: List[Dict], sampling_entries: List[Dict] +) -> float: + """ + Computes a metric that was not logged by the eval, post-hoc, i.e. + after the eval has run, by reading the log file. + """ + metric_to_func = { + "ctrl_recall": compute_ctrl_recall_posthoc, + } + if metric not in metric_to_func.keys(): + raise ValueError(f"Metric {metric} not supported") + return metric_to_func[metric](metric_entries, sampling_entries) + + +def compute_ctrl_recall_posthoc(metric_entries: List[Dict], sampling_entries: List[Dict]) -> float: + """ + Computes the average recall for identified control variables + + i.e. the no. of correctly identified control variables / no. gold control variables + Averaged across the samples. + + - We skip any samples where the gold hypothesis is invalid + - And we skip any samples where there are no control variables in the gold label, + since recall is undefined in this case + """ + recalls = [] + for metric_entry, sampling_entry in zip(metric_entries, sampling_entries): + try: + preds = parse_solver_preds(SolverResult(output=sampling_entry["sampled"][0])) + except ValueError: # in case of invalid solver output (violation) + preds = None + + if metric_entry["gold_answer"]["valid_hypothesis"]: + if preds and preds.ctrl_vars is not None: + recall = compute_recall( + set(preds.ctrl_vars), set(metric_entry["gold_answer"]["ctrl_vars"]) + ) + else: + # worst case scenario in case of violation or incorrect hyp validation + recall = 0 + else: + recall = np.nan + recalls.append(recall) + return np.nanmean(recalls).astype(float) + + +def compute_fallout(retrieved: Set[str], gold_relevants: Set[str], num_irrelevant: int) -> float: + """ + Computes fallout for a sample + Number of retrieved irrelevant items / number of irrelevant items + + This value is undefined when there are no irrelevant items in the gold label + """ + if num_irrelevant == 0: + # undefined + return np.nan + retrieved_irrel_count = len([r for r in retrieved if r not in gold_relevants]) + return retrieved_irrel_count / num_irrelevant + + +def compute_recall(retrieved: Set[str], gold_relevants: Set[str]): + """ + Computes recall for a sample + Number of retrieved relevant items / number of relevant items + + This value is undefined when there are no relevant items in the gold label + """ + num_relevant = len(gold_relevants) + if num_relevant == 0: + # undefined + return np.nan + retrieved_rel_count = len([r for r in retrieved if r in gold_relevants]) + return retrieved_rel_count / num_relevant diff --git a/evals/elsuite/identifying_variables/prompts.py b/evals/elsuite/identifying_variables/prompts.py new file mode 100644 index 0000000000..4c2eecefda --- /dev/null +++ b/evals/elsuite/identifying_variables/prompts.py @@ -0,0 +1,103 @@ +TASK_DESCRIPTION = """\ +Your task is to determine what the correct independent, dependent\ + and control variables should be when designing an experiment for validating a specific\ + hypothesis, given some observations. + +The observations are organized into N different variables describing some environment.\ + The rules of this environment are unknown. The order in which the variables are\ + presented is arbitrary, and so are their names. In some cases, some if not all\ + observations for a variable may be missing, which simply means that those observations\ + have not been collected for that variable. + +You will also be presented with a list of hypotheses regarding the effect on some\ + variable by some other variable(s). From this list, a "target hypothesis" is chosen for\ + you to focus your choice of variables on. + +You should determine the independent, dependent and control variables based on the above. + +If there is more than one control variable, please separate them with commas, and order\ + them from most important to least important. If there are no control variables, output\ + "None" for the control variable field. Note that you are budget-constrained, so\ + you should only list control variables which would be quite risky not to control, given\ + what is known. + +If the target hypothesis is unlikely to be worth investigating, it is considered\ + invalid, and you should determine this as part of your answer. + +Your answer should be in the following format, including the square brackets: + +[@ANSWER valid_hyp: ; independent: ; dependent: ; control: ] + +For example, \ + [@ANSWER valid_hyp: true; independent: x_1; dependent: x_2; control: x_3, x_4] + +In case of invalid hypothesis, you may ignore the other fields\ + e.g. [@ANSWER valid_hyp: false; independent: NA; dependent: NA; control: NA] or \ + [@ANSWER valid_hyp: false] + +Only perform reasoning if requested. + +{optional_decision_tree_section} +""" + + +DECISION_TREE_SECTION = """\ +To determine the validity and thereby the independent and dependent variables of a\ + hypothesis stating that x_1 is a function of x_2, you can refer to the following\ + pseudocode + +```pseudocode +if correlated(x_1, x_2) is not None and not correlated(x_1, x_2): + "invalid hypothesis" +else: + "independent: x_2; dependent: x_1" +``` + +where `correlated` returns `True` if its arguments are correlated `False` if not,\ + and `None` if it is unknown. + +To determine whether a given variable x_n should be a control variable in an experiment\ + testing a (valid) hypothesis stating that x_1 is a function of x_2, you can refer to the\ + following pseudocode: + +```pseudocode +if x_n in {x_1, x_2}: + "do not control for x_n" +else: + if correlated(x_1, x_n) or correlated(x_2, x_n): + "control for x_n" + else: + if correlated(x_1, x_n) is not None: + "do not control for x_n" + else: + if hypothesized(ind=x_n, dep=x_1, allow_indirect=True): + "control for x_n" + else: + "do not control for x_n" +``` + +where `hypothesized` returns whether `ind` is hypothesized to be a cause of `dep`,\ + even indirectly through chains of hypotheses. +""" + + +SAMPLE_MESSAGE = """\ +Observations: + +{observations} + +Hypotheses: + +{hypotheses} + +Target Hypothesis: + +{target_hypothesis} +""" + + +hypothesis_templates = [ + "{dep} is a function of {ind}", + "{ind} affects {dep} through some function", + "{dep} is affected by {ind} through some function", +] diff --git a/evals/elsuite/identifying_variables/renderers/__init__.py b/evals/elsuite/identifying_variables/renderers/__init__.py new file mode 100644 index 0000000000..c155624761 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/__init__.py @@ -0,0 +1,11 @@ +from . import tabular +from . import corrset + +RENDERER_MAP = { + "markdown": tabular.MarkdownTableRenderer, + "csv": tabular.CSVTableRenderer, + "json": tabular.JSONTableRenderer, + "language-tabular": tabular.LanguageTableRenderer, + "language-corrset": corrset.LanguageCorrSetRenderer, + "corrset": corrset.PureCorrSetRenderer, +} diff --git a/evals/elsuite/identifying_variables/renderers/base.py b/evals/elsuite/identifying_variables/renderers/base.py new file mode 100644 index 0000000000..90c1d27ae5 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/base.py @@ -0,0 +1,16 @@ +import abc +import random + +import numpy as np + +from evals.elsuite.identifying_variables.structs import Sample + + +class RendererBase(abc.ABC): + def __init__(self, rng: random.Random, np_rng: np.random.Generator) -> None: + self.rng = rng + self.np_rng = np_rng + + @abc.abstractmethod + def render_obs(self, sample: Sample) -> str: + raise NotImplementedError diff --git a/evals/elsuite/identifying_variables/renderers/corrset.py b/evals/elsuite/identifying_variables/renderers/corrset.py new file mode 100644 index 0000000000..39563527a6 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/corrset.py @@ -0,0 +1,346 @@ +from typing import List, Set, Tuple + +from evals.elsuite.identifying_variables.structs import Sample +from evals.elsuite.identifying_variables.renderers.base import RendererBase +import evals.elsuite.identifying_variables.graph_utils as graph_utils +import evals.elsuite.identifying_variables.renderers.templates as templates +from evals.elsuite.identifying_variables.constants import SPARSITY_FOR_UNOBS + + +class CorrSetRenderer(RendererBase): + """ + Describes the correlation structure of variables + """ + + def determine_sample_type(self, sample: Sample) -> Tuple[str, List[Set[str]]]: + """ + Determines the type of sample we have, returning the correlation sets in + the process. Accounts for unobserved variables by removing them from + the correlation sets. + + Returns: + str: The type of causal graph we have, ignoring unobserved variables. + Either + - "many_correl_sets": there are at least two correlation sets, at least + one of which has at least two variables. + - "single_correl_set": there is only one correlation set. + - "only_ind": there are at least two correlation sets, all of which + have exactly one variable. + List[Set[str]]: The list of correlation sets. A correlation set is the + set of observed variables in a tree from the causal graph + """ + causal_graph = sample.causal_graph + graph_trees = graph_utils.find_graph_trees(causal_graph) + correl_sets = [] + unobserved_vars = set( + var + for var in sample.variable_metadata + if sample.variable_metadata[var]["extra"]["sparsity_rate"] + > SPARSITY_FOR_UNOBS + ) + for tree in graph_trees: + correl_set = set(tree) + for var in tree: + if var in unobserved_vars: + # correlations to unobserved variables are, well, unobserved + correl_set.remove(var) + correl_sets.append(correl_set) + # need to check for empty sets, since we removed unobserved variables + correl_sets = [correl_set for correl_set in correl_sets if len(correl_set) > 0] + if len(correl_sets) == 1: + return "single_correl_set", correl_sets + else: + for correl_set in correl_sets: + if len(correl_set) > 1: + # at least one set with more than one observed var + return "many_correl_sets", correl_sets + # all sets have only one node + return "only_ind", correl_sets + + def _get_hypd_unobserved_vars(self, sample: Sample) -> List[str]: + vars_to_mention = [] + hypotheses = sample.hypotheses + + hypothesized_vars = set( + var + for var in hypotheses + if hypotheses.in_degree(var) > 0 or hypotheses.out_degree(var) > 0 + ) + vars_to_mention = [ + var + for var in hypothesized_vars + if sample.variable_metadata[var]["extra"]["sparsity_rate"] + > SPARSITY_FOR_UNOBS + ] + return vars_to_mention + + +class PureCorrSetRenderer(CorrSetRenderer): + def render_obs(self, sample: Sample) -> str: + _, observed_sets = self.determine_sample_type(sample) + + render_string = ( + "The following correlation sets were observed. Variables in the" + " same correlation set are correlated with each other, but not with variables in" + " other correlation sets." + ) + render_string += "\n\n" + self._render_observed_sets(observed_sets) + render_string += "\n\n" + self._render_unobserved_vars(sample) + + return render_string + + def _render_observed_sets(self, observed_sets: List[Set[str]]) -> str: + """ + Renders the observed sets. + """ + render_string = "" + for idx, correl_set in enumerate(observed_sets, start=1): + render_string += f"\nCorrelation set {idx}: {{{', '.join(correl_set)}}}." + return render_string.strip() + + def _render_unobserved_vars(self, sample: Sample) -> str: + """ + Renders the unobserved variables. + """ + unobserved_variables = self._get_hypd_unobserved_vars(sample) + if len(unobserved_variables) == 0: + render_string = "There were no unobserved variables." + else: + render_string = f"Unobserved variables: [{', '.join(unobserved_variables)}]." + return render_string.strip() + + +class LanguageCorrSetRenderer(CorrSetRenderer): + """ + Describes the correlation structure of variables in natural language. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.type_to_renderer = { + "many_correl_sets": self.render_many_sets, + "single_correl_set": self.render_single_set, + "only_ind": self.render_only_ind, + } + + def render_obs(self, sample: Sample) -> str: + """ + Describes the interactions between variables in the sample. + + The description looks like + ``` + {opening statement} + + {description of the interactions} + + {optional mention of unobserved variables that were hypothesized about} + ``` + + The description of the interactions depends on the type of causal graph. + """ + sample_type, observed_sets = self.determine_sample_type(sample) + + opening_statement = templates.OPENING_STATEMENT + main_observation = self.type_to_renderer[sample_type](observed_sets) + unobserved_variables = self.mention_unobserved_vars(sample) + return "\n\n".join([opening_statement, main_observation, unobserved_variables]) + + def render_many_sets(self, correl_sets: List[Set[str]]): + """ + Renders a causal graph where we have at least two correlation + sets, one of which has at least two variables. + The description looks like: + ``` + In general, there were cases where some variables changed in tandem with each + other, while others did not. + {example of two variables that changed in tandem} + {interleaved mentions of remaining variables, specifying which other already + mentioned variables they changed in tandem with, if any} + ``` + """ + # Sort the sets by size, largest first + correl_sets = sorted(correl_sets, key=lambda x: len(x), reverse=True) + variables = [var for correl_set in correl_sets for var in correl_set] + + correl_set_idx_to_already_mentioned_vars = [set() for _ in correl_sets] + var_to_correl_set_idx = { + var: idx for idx, correl_set in enumerate(correl_sets) for var in correl_set + } + return_string = templates.MANY_CORREL_SETS_MAIN + + # hard-code mention first two variables, from first (largest) set + current_set_idx = 0 + return_string += "\n" + templates.CORREL_VARS_EXAMPLE.format( + optional_transition="For example, ", + # the first set is guaranteed to have at least two variables + var_1=variables[0], + var_2=variables[1], + ) + correl_set_idx_to_already_mentioned_vars[0].update([variables[0], variables[1]]) + + # go through remaining variables, randomly + variables = variables[2:] + self.rng.shuffle(variables) + + for var in variables: + correl_set_idx = var_to_correl_set_idx[var] + if correl_set_idx == current_set_idx: + transition_word = self.rng.choice(["Similarly", "Likewise"]) + transition_phrase = f"{transition_word}, " + else: + transition_phrase = "" + current_set_idx = correl_set_idx + + mentioned_vars_from_set = correl_set_idx_to_already_mentioned_vars[ + correl_set_idx + ] + if len(mentioned_vars_from_set) == 0: # first time mentioning this set + mention_string = templates.IND_VARS_EXAMPLE.format( + optional_transition=transition_phrase, + var_1=var, + var_2="previously mentioned variables", + ) + else: # variables from this set have been mentioned + mention_string = templates.CORREL_VARS_EXAMPLE.format( + optional_transition=transition_phrase, + var_1=var, + var_2=templates.list_to_nl_list(list(mentioned_vars_from_set)), + ) + return_string += "\n" + mention_string.capitalize() + # we have now mentioned this variable + correl_set_idx_to_already_mentioned_vars[correl_set_idx].add(var) + + return return_string + + def render_single_set(self, correl_sets: List[Set[str]]) -> str: + """ + Renders a causal graph where we have only one correlation set. + By definition, this set has at least two variables. + The description looks like: + ``` + In general, all of the variables seemed to change in tandem with each other. + For example, changes in {var_1} were observed to reflect changes in {var_2} and + viceversa. + {optional example of other pair} + {optional concluding statement that this holds for all pairs} + ``` + """ + correl_set = correl_sets[0] + # we won't use more than 3 variables in the examples. + exemplar_vars = list(correl_set)[:3] + remaining_vars = correl_set - set(exemplar_vars) + # always have at least 2 vars + example_1 = templates.CORREL_VARS_EXAMPLE.format( + optional_transition="", + var_1=exemplar_vars[0], + var_2=exemplar_vars[1], + ) + example_2 = "" + concluding_statement = "" + if len(exemplar_vars) == 3: + example_2 = templates.CORREL_VARS_EXAMPLE.format( + optional_transition="Additionally, ", + var_1=exemplar_vars[2], + var_2=templates.list_to_nl_list(exemplar_vars[:2]), + ) + if len(remaining_vars) > 0: + concluding_statement = templates.SPECIFIC_CONCL_STATEMENT.format( + already_mentioned=templates.list_to_nl_list(exemplar_vars), + remaining_vars=templates.list_to_nl_list(list(remaining_vars)), + ) + return templates.SINGLE_CORREL_SET_MAIN.format( + example_1=example_1, + optional_example_2=example_2, + optional_concluding_statement=concluding_statement, + ) + + def render_only_ind(self, correl_sets: List[Set[str]]) -> str: + """ + Describes a causal graph where we have at least two correlation + sets, all of which have only one variable, i.e. each variable + in the causal graph is independent of all other variables. The + description looks like: + ``` + In general, no discernible patterns were noticed between the variables. + For example, changes in {var_1} were not observed to reflect any changes in + {var_2}. + {optional example of other pair} + {optional concluding statement that this holds for all pairs} + ``` + """ + variables = [var for correl_set in correl_sets for var in correl_set] + num_vars = len(variables) # equal to the number of sets + # there's always at least 2 variables. + example_1 = templates.IND_VARS_EXAMPLE.format( + optional_transition="", + var_1=variables[0], + var_2=variables[1], + ) + example_2 = "" + concluding_statement = "" + if num_vars > 2: + example_2 = templates.IND_VARS_EXAMPLE.format( + optional_transition="Similarly, ", + var_1=variables[0], + var_2=variables[2], + ) + if num_vars > 3: + concluding_statement = templates.SPECIFIC_CONCL_STATEMENT.format( + already_mentioned=templates.list_to_nl_list(variables[:3]), + remaining_vars=templates.list_to_nl_list(variables[3:]), + ) + else: + concluding_statement = templates.GENERIC_CONCL_STATEMENT + + return templates.ONLY_IND_MAIN.format( + example_1=example_1, + optional_example_2=example_2, + optional_concluding_statement=concluding_statement, + ) + + def mention_unobserved_vars(self, sample: Sample) -> str: + """ + Mentions any unobserved variables that also hypothesized about. + """ + vars_to_mention = self._get_hypd_unobserved_vars(sample) + + n_vars_to_mention = len(vars_to_mention) + if n_vars_to_mention == 0: + return_string = "" + else: + be_plurality = {"singular": "is", "plural": "are"} + be_string = be_plurality["plural" if n_vars_to_mention > 1 else "singular"] + return_string = templates.UNOBS_BUT_HYP_VARS.format( + unobs_but_hyp_vars=templates.list_to_nl_list(vars_to_mention), + be_string=be_string, + ) + return return_string + + +if __name__ == "__main__": + import random + import numpy as np + + list_of_lists = [ + [{"x_1004"}, {"x_1005", "x_1006", "x_1007", "x_1008", "x_1009"}], + [{"x_1007", "x_1008", "x_1009"}, {"x_1010"}], + [{"x_1011"}, {"x_1012", "x_1013"}, {"x_1014"}], # 3 elements + [{"x_1022"}, {"x_1023", "x_1024"}, {"x_1025", "x_1026"}], + [{"x_1030"}, {"x_1031", "x_1032", "x_1033"}, {"x_1034"}, {"x_1035"}], + ] + + np_rng = np.random.default_rng(0) + renderer = PureCorrSetRenderer(random.Random(0), np_rng) + + from evals.elsuite.identifying_variables.scripts.gen_data import gen_samples + import networkx as nx + from pprint import pprint + + samples = gen_samples(10, None, np_rng) + + for sample in samples: + print("causal graph", nx.to_dict_of_lists(sample.causal_graph)) + print("hypotheses", list(sample.hypotheses.edges)) + pprint(sample.variable_metadata) + print(renderer.render_obs(sample)) + print("================") diff --git a/evals/elsuite/identifying_variables/renderers/tabular.py b/evals/elsuite/identifying_variables/renderers/tabular.py new file mode 100644 index 0000000000..0feb8b38fe --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/tabular.py @@ -0,0 +1,200 @@ +from typing import Optional, Tuple, Union, List +import json +import random + +import networkx as nx +import numpy as np +import pandas as pd + +from evals.elsuite.identifying_variables.structs import Sample +from evals.elsuite.identifying_variables.renderers.base import RendererBase +from evals.elsuite.identifying_variables.latent_funcs import ( + DISTRIBUTIONS, + LATENT_FUNC_MAP, +) +from evals.elsuite.identifying_variables.constants import NUM_OBS + + +def apply_noise( + data_df: pd.DataFrame, np_rng: np.random.Generator, snr: Optional[float] = None +) -> pd.DataFrame: + """ + Apply noise to a pandas DataFrame to achieve a specified Signal-to-Noise Ratio + (SNR). + + Args: + data_df (pd.DataFrame): The DataFrame containing the original data. + snr (float): The desired Signal-to-Noise Ratio in decibels (dB). + If None, no noise is applied. + """ + if snr is None: + return data_df + + desired_snr_linear = 10 ** (snr / 10) + + signal_powers = data_df.var() + noise_powers = signal_powers / desired_snr_linear + + noise = pd.DataFrame( + np_rng.normal(0, np.sqrt(noise_powers), data_df.shape), + columns=data_df.columns, + ) + noisy_df = data_df + noise + + return noisy_df + + +def sparsify_data( + data_df: pd.DataFrame, variable_metadata: dict, np_rng: np.random.Generator +) -> pd.DataFrame: + total_obs = data_df.shape[0] + for var in variable_metadata.keys(): + sparsity_rate = variable_metadata[var]["extra"]["sparsity_rate"] + num_missing_obs = int(sparsity_rate * total_obs) + missing_obs_indices = np_rng.choice(total_obs, num_missing_obs, replace=False) + data_df.loc[missing_obs_indices, var] = np.nan + return data_df + + +class TabularRenderer(RendererBase): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.num_obs = NUM_OBS + + def _render_table(self, sample: Sample) -> pd.DataFrame: + variable_metadata = sample.variable_metadata + sample_metadata = sample.sample_metadata + n_obs_samples = self.num_obs + causal_graph = sample.causal_graph + + # "topological sort" from least to most ancestors (i.e. least to most dependent) + sorted_vars = nx.topological_sort(causal_graph) + # necessary so that we can generate data in the correct order + + data_dict = {} + for var in sorted_vars: + gen_method = variable_metadata[var]["gen_method"]["name"] + if "input_x" not in variable_metadata[var]["gen_method"]: + distr = DISTRIBUTIONS[gen_method] + distr_kwargs = variable_metadata[var]["gen_method"]["kwargs"] + data_dict[var] = distr( + num_samples=n_obs_samples, **distr_kwargs, rng=self.np_rng + ) + else: + latent_func = LATENT_FUNC_MAP[gen_method] + latent_func_kwargs = variable_metadata[var]["gen_method"]["kwargs"] + input_x = variable_metadata[var]["gen_method"]["input_x"] + data_dict[var] = latent_func(x=data_dict[input_x], **latent_func_kwargs) + + data_df = pd.DataFrame(data_dict) + + # apply noise after generating data + data_df = apply_noise(data_df, self.np_rng, sample_metadata["snr"]) + # apply sparsification after generating and noise + data_df = sparsify_data(data_df, variable_metadata, self.np_rng) + + # round to 3 decimal places + data_df = data_df.round(3) + + return data_df + + +class MarkdownTableRenderer(TabularRenderer): + """ + Renders tabular data as a markdown table with variable names as column names. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + return data_df.to_markdown(index=False) + + +class CSVTableRenderer(TabularRenderer): + """ + Renders tabular data as a comma-separated-values (CSV) file with variable names as + column names. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + return data_df.to_csv(index=False) + + +class JSONTableRenderer(TabularRenderer): + """ + Renders tabular data as a JSON object with variable names as keys and lists of + values as values. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + return json.dumps(data_df.to_dict(orient="list")) + + +class LanguageTableRenderer(TabularRenderer): + """ + Renders tabular data as a natural language description of the data. + Describing the data row by row. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.num_obs = 10 # set it to 10 + # realistically no one would read more than 10 rows of data one by one + + def render_obs(self, sample: Sample) -> str: + data_df = self._render_table(sample) + variables = list(data_df.columns) + rendered_obs = "" + current_step = "first" + for row in data_df.itertuples(index=False, name=None): + rendered_obs += self._render_row(row, variables, current_step) + "\n" + current_step = "next" + return rendered_obs + + def _render_row( + self, row: Tuple[Union[int, float]], variables: List[str], current_step: str + ) -> str: + string = f"On the {current_step} step, " + past_participle_verb = self.rng.choice(["measured", "recorded", "reported"]) + for value, var in zip(row, variables): + if np.isnan(value): + string += f"{var} was not {past_participle_verb}. " + else: + string += ( + f"{var} was {past_participle_verb} to be {format_number(value)}. " + ) + return string + + +def format_number(number: Union[int, float]): + """Get's rid of trailing .0's""" + if float(number).is_integer(): + return int(number) + else: + return number + + +if __name__ == "__main__": + # just for quick testing + np_rng = np.random.default_rng(0) + renderer = LanguageTableRenderer(random.Random(0), np_rng) + + from evals.elsuite.identifying_variables.scripts.gen_data import gen_samples + + samples = gen_samples(10, None, np_rng) + + for sample in samples: + print(nx.to_dict_of_lists(sample.causal_graph)) + print(sample.variable_metadata) + print(renderer.render_obs(sample)) + print("================") diff --git a/evals/elsuite/identifying_variables/renderers/templates.py b/evals/elsuite/identifying_variables/renderers/templates.py new file mode 100644 index 0000000000..c7a9000072 --- /dev/null +++ b/evals/elsuite/identifying_variables/renderers/templates.py @@ -0,0 +1,56 @@ +from typing import List + + +def list_to_nl_list(list_of_words: List[str]) -> str: + """ + Converts a list of words into a natural language list. + """ + if len(list_of_words) == 1: + return list_of_words[0] + elif len(list_of_words) == 2: + return f"{list_of_words[0]} and {list_of_words[1]}" + else: + return f"{', '.join(list_of_words[:-1])} and {list_of_words[-1]}" + + +OPENING_STATEMENT = """\ +The following is a description of the observations made about a set of variables. +""".strip() + +MANY_CORREL_SETS_MAIN = """\ +In general, there were cases where some variables changed in tandem with each other,\ + while others did not. +""".strip() + +SINGLE_CORREL_SET_MAIN = """\ +In general, all of the variables seemed to change in tandem with each other. +For example, {example_1} {optional_example_2} {optional_concluding_statement} +""".strip() + +ONLY_IND_MAIN = """\ +In general, no discernible patterns were noticed between the variables. +For example, {example_1} {optional_example_2} {optional_concluding_statement} +""".strip() + +CORREL_VARS_EXAMPLE = """\ +{optional_transition}changes in {var_1} were observed to reflect changes in {var_2} and\ + viceversa. +""".strip() + +IND_VARS_EXAMPLE = """\ +{optional_transition}changes in {var_1} were not observed to reflect any changes in\ + {var_2}. +""".strip() + +SPECIFIC_CONCL_STATEMENT = """\ +Similar observations were made for all other pairings within and across\ + {already_mentioned} and {remaining_vars}. +""".strip() + +GENERIC_CONCL_STATEMENT = """\ +Similar observations were made for all other pairings of the observed variables. +""".strip() + +UNOBS_BUT_HYP_VARS = """\ +{unobs_but_hyp_vars} {be_string} not observed but {be_string} hypothesized about. +""".strip() diff --git a/evals/elsuite/identifying_variables/scripts/data.sh b/evals/elsuite/identifying_variables/scripts/data.sh new file mode 100755 index 0000000000..418ebe3fef --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/data.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# generate datasets of size 500 and 5000 +echo "Generating default dataset: 500 samples" +python gen_data.py --n_samples 500 --jsonl_dir ../../../registry/data/identifying_variables/ +echo "Generating large dataset: 5000 samples" +python gen_data.py --n_samples 5000 --jsonl_dir ../../../registry/data/identifying_variables/ +echo "Generating default dataset: 500 samples (balanced ctrl vars)" +python gen_data.py --balanced_ctrl_vars --n_samples 500 --jsonl_dir ../../../registry/data/identifying_variables/ +echo "Generating large dataset: 5000 samples (balanced ctrl vars)" +python gen_data.py --balanced_ctrl_vars --n_samples 5000 --jsonl_dir ../../../registry/data/identifying_variables/ + +echo "Done." diff --git a/evals/elsuite/identifying_variables/scripts/gen_data.py b/evals/elsuite/identifying_variables/scripts/gen_data.py new file mode 100644 index 0000000000..14c5f78e28 --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/gen_data.py @@ -0,0 +1,467 @@ +""" +Code for generating .jsonl dataset for identifying variables eval + +Use default argparse args to replicate the dataset used for the report +""" + +from dataclasses import asdict +import os +import argparse +from typing import Dict, List, Optional, Set, Tuple, Any +import json +import copy + +from tqdm.auto import tqdm +import networkx as nx +import numpy as np + +import evals.elsuite.identifying_variables.latent_funcs as latent_funcs +from evals.elsuite.identifying_variables.graph_utils import ( + gen_random_forest, + gen_random_forest_tree_size, + find_graph_roots, + find_unconnected_nodes_pair, + find_connected_nodes_pair, +) +from evals.elsuite.identifying_variables.utils import sample_serializer +from evals.elsuite.identifying_variables.structs import Sample, Answer +import evals.elsuite.identifying_variables.constants as constants + + +def write_to_jsonl( + samples: List[Sample], + jsonl_path: str, +): + with open(jsonl_path, "w") as f: + for sample in samples: + f.write(json.dumps(asdict(sample), default=sample_serializer) + "\n") + + +def random_latent_func_meta( + np_rng: np.random.Generator, input_x: Optional[str] = None +) -> Dict: + """ + Generates random metadata for defining a latent function + + Args: + input_x (Optional[str]): Name of input variable. If None, then + the latent function is a distribution, not dependent on any input. + """ + if input_x is None: + latent_func_name = np_rng.choice(list(latent_funcs.DISTRIBUTIONS.keys())) + predefined_kwargs = latent_funcs.DISTRIBUTIONS_KWARG_MAP[latent_func_name] + kwargs = {**predefined_kwargs} + return {"name": latent_func_name, "kwargs": kwargs} + else: + latent_func_name = np_rng.choice(list(latent_funcs.LATENT_FUNC_MAP.keys())) + predefined_kwargs = latent_funcs.LATENT_FUNC_KWARG_MAP[latent_func_name] + kwargs = {} + for kwarg, min_max in predefined_kwargs.items(): + kwarg_value = np_rng.integers(min_max["min_v"], min_max["max_v"]) + while kwarg == "grad" and kwarg_value == 0: + # dont allow 0 gradient + kwarg_value = np_rng.integers(min_max["min_v"], min_max["max_v"]) + kwargs[kwarg] = kwarg_value + return {"name": latent_func_name, "input_x": input_x, "kwargs": kwargs} + + +def build_var_metadata( + causal_graph: nx.DiGraph, + sparse_var_rate: float, + np_rng: np.random.Generator, +) -> Dict: + """ + Builds the variable metadata for a sample, containing + information on how each variable is generated and which variables + it is correlated with. + + Args: + causal_graph (nx.DiGraph): Causal graph of the sample. + sparse_var_rate (float): Percentage of variables that should be sparsified. + max_sparsity (float): Maximum sparsity rate for sparse variables. + np_rng (np.random.Generator): Random number generator to be used. + """ + var_metadata = {} + + roots = find_graph_roots(causal_graph) + root_to_descendants = {r: nx.descendants(causal_graph, r) for r in roots} + node_to_root = { + n: root + for root, descendants in root_to_descendants.items() + for n in descendants + } + + for var in causal_graph: + if var in roots: + latent_func_meta = random_latent_func_meta(np_rng, input_x=None) + var_root = var + else: + parent = next(causal_graph.predecessors(var)) + latent_func_meta = random_latent_func_meta(np_rng, input_x=parent) + var_root = node_to_root[var] + # variables with a common root are correlated. Need to copy to avoid mutation + corrs: Set[str] = set(root_to_descendants[var_root]) + if var_root != var: + # remove self-correlation, add correlation to root itself + corrs.remove(var) + corrs.add(var_root) + + var_metadata[var] = { + "gen_method": latent_func_meta, + "corrs": corrs, + "extra": {"sparsity_rate": 0}, + } + + # add sparsity + var_metadata = sparsify_data(var_metadata, sparse_var_rate, np_rng) + + return var_metadata + + +def sparsify_data(var_metadata, sparse_var_rate, np_rng): + num_observed_vars = 0 + orig_var_metadata = copy.deepcopy(var_metadata) + for var in var_metadata.keys(): + if np_rng.uniform(0, 1) < sparse_var_rate: + sparsity_rate = np_rng.uniform( + low=constants.MIN_SPARSITY, high=constants.MAX_SPARSITY + ) + var_metadata[var]["extra"]["sparsity_rate"] = sparsity_rate + if sparsity_rate > constants.SPARSITY_FOR_UNOBS: + # remove unobserved variables from correlations + for corr_var in var_metadata[var]["corrs"]: + var_metadata[corr_var]["corrs"].remove(var) + var_metadata[var]["corrs"] = set() + else: + num_observed_vars += 1 + else: + num_observed_vars += 1 + + # if less than 2 observed variables, sparsification was too much, try again + if num_observed_vars < 2: + var_metadata = sparsify_data(orig_var_metadata, sparse_var_rate, np_rng) + + return var_metadata + + +def gen_sample_balanced_ctrl_vars( + signal_noise_ratio: Optional[float], np_rng: np.random.Generator +) -> Sample: + """ + Generates a sample for the dataset, containing information on how a set + of variables are interlinked, and which hypotheses are currently held. + + This differs from gen_sample in the following ways: + + To simplify: + - The total number of variables in a given sample is fixed to MAX_VARS + - The hypothesis is always valid + + The number of control variables is sampled uniformly between 0 and MAX_VARS-2 + (we subtract 2 since two variables are involved in the hypothesis) + """ + sample_metadata = {"snr": signal_noise_ratio} + + n_vars = constants.MAX_VARS + + sparse_var_rate = np_rng.uniform( + low=constants.MIN_SPARSE_VAR_RATE, high=constants.MAX_SPARSE_VAR_RATE + ) # perc of variables to sparsify + + var_ids = np_rng.choice(np.arange(1000, 10000), size=n_vars, replace=False).astype( + str + ) + var_names = [f"x_{var_id}" for var_id in var_ids] + + num_ctrl_vars = np_rng.integers(low=0, high=n_vars - 1) # high is exclusive + + causal_graph = gen_random_forest_tree_size( + nodes=var_names, tree_size=num_ctrl_vars + 2, np_rng=np_rng + ) + + variable_metadata = build_var_metadata(causal_graph, sparse_var_rate, np_rng) + + target_hypothesis = find_connected_nodes_pair(causal_graph, np_rng) + target_hyp_is_valid = ( + parse_target_hyp(target_hypothesis, variable_metadata)[0] + if target_hypothesis is not None + else None + ) + # try again if the sparsification caused the hypothesis to be invalid + if target_hypothesis is None or not target_hyp_is_valid: + return gen_sample_balanced_ctrl_vars(signal_noise_ratio, np_rng) + + n_hypotheses = np_rng.integers( + low=constants.MIN_HYPS, + high=min(constants.MAX_HYPS, n_vars - 1) + 1, + ) + hypotheses = gen_random_forest(var_names, total_edges=n_hypotheses, np_rng=np_rng) + + hypotheses = integrate_target_hyp(target_hypothesis, hypotheses, np_rng) + + gold_label, num_not_ctrl = determine_gold_label( + target_hypothesis, variable_metadata, hypotheses + ) + + return Sample( + variable_metadata=variable_metadata, + hypotheses=hypotheses, + target_hypothesis=target_hypothesis, + sample_metadata=sample_metadata, + # keep track of underlying ground truth in case want more in depth analysis + causal_graph=causal_graph, + gold_label=gold_label, + num_not_ctrl=num_not_ctrl, + ) + + +def gen_sample( + signal_noise_ratio: Optional[float], + np_rng: np.random.Generator, + valid_hyp_requested: Optional[bool] = None, +) -> Sample: + """ + Generates a sample for the dataset, containing information on how a set + of variables are interlinked, and which hypotheses are currently held. + + Args: + signal_noise_ratio (float): Signal-to-noise ratio to be applied to the + observations. If None, no noise is applied. + np_rng (np.random.Generator): Random number generator to be used. + valid_hyp_requested (Optional[bool]): Whether the target hypothesis should be + valid. If None, will be randomly chosen. + + Returns: + Sample: A sample as defined by the `Sample` dataclass. + """ + sample_metadata = {"snr": signal_noise_ratio} + + n_vars = np_rng.integers(low=constants.MIN_VARS, high=constants.MAX_VARS + 1) + sparse_var_rate = np_rng.uniform( + low=constants.MIN_SPARSE_VAR_RATE, high=constants.MAX_SPARSE_VAR_RATE + ) # perc of variables to sparsify + + var_ids = np_rng.choice(np.arange(1000, 10000), size=n_vars, replace=False).astype( + str + ) + var_names = [f"x_{var_id}" for var_id in var_ids] + + causal_graph = gen_random_forest(var_names, np_rng=np_rng) + + variable_metadata = build_var_metadata(causal_graph, sparse_var_rate, np_rng) + + n_hypotheses = np_rng.integers( + low=constants.MIN_HYPS, + high=min(constants.MAX_HYPS, n_vars - 1) + 1, + ) + hypotheses = gen_random_forest(var_names, total_edges=n_hypotheses, np_rng=np_rng) + + if valid_hyp_requested is None: + # 0.5 chance of valid hypothesis + valid_hyp_requested = np_rng.uniform(0, 1) < 0.5 + + if valid_hyp_requested: + target_hypothesis = find_connected_nodes_pair(causal_graph, np_rng) + else: + target_hypothesis = find_unconnected_nodes_pair(causal_graph) + + target_hyp_is_valid = ( + parse_target_hyp(target_hypothesis, variable_metadata)[0] + if target_hypothesis is not None + else None + ) + if target_hypothesis is None or target_hyp_is_valid != valid_hyp_requested: + return gen_sample(signal_noise_ratio, np_rng, valid_hyp_requested) + + hypotheses = integrate_target_hyp(target_hypothesis, hypotheses, np_rng) + + gold_label, num_not_ctrl = determine_gold_label( + target_hypothesis, variable_metadata, hypotheses + ) + + return Sample( + variable_metadata=variable_metadata, + hypotheses=hypotheses, + target_hypothesis=target_hypothesis, + sample_metadata=sample_metadata, + # keep track of underlying ground truth in case want more in depth analysis + causal_graph=causal_graph, + gold_label=gold_label, + num_not_ctrl=num_not_ctrl, + ) + + +def determine_gold_label( + target_hyp, variable_metadata, hypotheses +) -> Tuple[Answer, Optional[int]]: + """ + Determines the ideal `Answer` for a given sample. Additionally returns + the number of variables not controlled for, if the hypothesis is valid, + necessary for nDCG calculation. + """ + valid_hypothesis, ind_var, dep_var = parse_target_hyp(target_hyp, variable_metadata) + if not valid_hypothesis: + ctrl_vars, not_ctrls = None, None + num_not_ctrl = None + else: + ctrl_vars, not_ctrls = determine_ctrl_vars( + variable_metadata, ind_var, dep_var, hypotheses + ) + # worst case ctrl: all vars that aren't meant to be ctrld are ctrld + num_not_ctrl = len(not_ctrls) + + return ( + Answer( + valid_hypothesis=valid_hypothesis, + ind_var=ind_var, + dep_var=dep_var, + ctrl_vars=ctrl_vars, + ), + num_not_ctrl, + ) + + +def parse_target_hyp( + target_hyp: Tuple[str, str], variable_metadata: Dict[str, Any] +) -> Tuple[bool, Optional[str], Optional[str]]: + """Implements decision tree in Figure 2 from eval spec""" + proposed_ind = target_hyp[0] + proposed_dep = target_hyp[1] + + ind_unobserved = ( + variable_metadata[proposed_ind]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + ) + dep_unobserved = ( + variable_metadata[proposed_dep]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + ) + + # if either are unobserved, we have no evidence that they are not correlated + if ind_unobserved or dep_unobserved: + return True, proposed_ind, proposed_dep + # evidence of lack of correlation + elif proposed_dep not in variable_metadata[proposed_ind]["corrs"]: + return False, None, None + # evidence of correlation + else: + return True, proposed_ind, proposed_dep + + +def determine_ctrl_vars( + variable_metadata: Dict[str, Any], + ind_var: str, + dep_var: str, + hypotheses: nx.DiGraph, +) -> Tuple[List[str], List[str]]: + """Implements decision tree in Figure 1 from eval spec""" + ctrl_vars = [] + not_ctrls = [] + for var in variable_metadata: + if var in {ind_var, dep_var}: + not_ctrls.append(var) + elif are_correlated(var, dep_var, variable_metadata) or are_correlated( + var, ind_var, variable_metadata + ): + ctrl_vars.append(var) + elif are_correlated(var, dep_var, variable_metadata) is not None: + # don't control vars which we have observed to be uncorrelated w/ dep + not_ctrls.append(var) + else: # when dep_var or var is unobserved, no evidence of lack of correlation + # control for any var which might influence the dependent variable + dep_var_ancestors = nx.ancestors(hypotheses, dep_var) + if var in dep_var_ancestors: + ctrl_vars.append(var) + else: + not_ctrls.append(var) + + return ctrl_vars, not_ctrls + + +def are_correlated(var_1, var_2, variable_metadata) -> Optional[bool]: + """ + Returns whether two variables are correlated. If there is no evidence + of correlation, returns None. + """ + if ( + variable_metadata[var_1]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + or variable_metadata[var_2]["extra"]["sparsity_rate"] + > constants.SPARSITY_FOR_UNOBS + ): + return None + return ( + var_2 in variable_metadata[var_1]["corrs"] + or var_1 in variable_metadata[var_2]["corrs"] + ) + + +def integrate_target_hyp( + target_hyp: Tuple[Any, Any], hyp_graph: nx.DiGraph, np_rng: np.random.Generator +): + """ + Integrates the target hypothesis into the hypotheses graph, respecting + the original edge count by removing a random edge if necessary. + """ + if not hyp_graph.has_edge(*target_hyp): + random_edge_to_remove = np_rng.choice(list(hyp_graph.edges)) + hyp_graph.remove_edge(*random_edge_to_remove) + hyp_graph.add_edge(*target_hyp) + return hyp_graph + + +def gen_samples( + n_samples: int, + signal_noise_ratio: Optional[float], + np_rng: np.random.Generator, + balanced_ctrl_vars: bool = False, +) -> List[Sample]: + samples = [] + if not balanced_ctrl_vars: + for _ in tqdm(range(n_samples)): + sample = gen_sample(signal_noise_ratio, np_rng) + samples.append(sample) + else: + for _ in tqdm(range(n_samples)): + sample = gen_sample_balanced_ctrl_vars(signal_noise_ratio, np_rng) + samples.append(sample) + + return samples + + +def main(args: argparse.Namespace): + np_rng = np.random.default_rng(args.seed) + samples = gen_samples(args.n_samples, args.snr, np_rng, args.balanced_ctrl_vars) + os.makedirs(args.jsonl_dir, exist_ok=True) + if not args.balanced_ctrl_vars: + jsonl_path = os.path.join(args.jsonl_dir, f"{args.n_samples}.jsonl") + else: + jsonl_path = os.path.join( + args.jsonl_dir, f"{args.n_samples}_balanced_ctrl_vars.jsonl" + ) + write_to_jsonl(samples, jsonl_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument("--n_samples", type=int, default=5000) + parser.add_argument( + "--snr", + type=float, + default=None, + help="signal-to-noise ratio. Default None (no noise is applied.)", + ) + parser.add_argument( + "--jsonl_dir", type=str, default="./evals/registry/data/identifying_variables/" + ) + parser.add_argument("--seed", type=int, default=20220722) + parser.add_argument( + "--balanced_ctrl_vars", + action="store_true", + help="Whether to generate samples with balanced control variables.", + default=False, + ) + args = parser.parse_args() + + main(args) diff --git a/evals/elsuite/identifying_variables/scripts/make_plots.py b/evals/elsuite/identifying_variables/scripts/make_plots.py new file mode 100644 index 0000000000..f29f781492 --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/make_plots.py @@ -0,0 +1,400 @@ +from pathlib import Path +from typing import Dict, Tuple + +import numpy as np +import pandas as pd +from tqdm.auto import tqdm + +from evals.elsuite.identifying_variables.metrics import compute_metric_posthoc +from evals.elsuite.identifying_variables.scripts.plotting_utils import ( + plot_difficulty_bars, + plot_solver_bars, +) +from evals.elsuite.identifying_variables.scripts.table_utils import ( + make_main_metric_table, +) +from evals.utils import log_utils + +NUM_REPEATS = 3 +MAIN_METRICS = [ + "ctrl_nDCG", + "ctrl_recall", + "hyp_valid_acc", + "ind_acc", + "dep_acc", + "violation_rate", +] + +SOLVERS = [ + "generation/direct/gpt-3.5-turbo", + "generation/cot/gpt-3.5-turbo", + "generation/hhh/gpt-4-base", + "generation/cot_hhh/gpt-4-base", + "generation/direct/gpt-4-1106-preview", + "generation/cot/gpt-4-1106-preview", + "generation/cot/mixtral-8x7b-instruct", + "generation/cot/llama-2-70b-chat", + "generation/cot/gemini-pro", + "identifying_variables/random", + "identifying_variables/noctrl", +] + + +RENDERERS = [ + "markdown", + "csv", + "json", + "language-tabular", + "language-corrset", + "corrset", +] + + +def initialize_default_results_dict(): + results_dict = { + metric: { + stat: { + solver: { + renderer: { + "with tree": ([] if stat == "raw" else 0), + "without tree": ([] if stat == "raw" else 0), + } + for renderer in RENDERERS + } + for solver in SOLVERS + } + for stat in ["raw", "mean", "sem"] + } + for metric in MAIN_METRICS + } + return results_dict + + +def handle_cot_double_sampling(sampling_entries, solver): + if "cot" in solver: + sampling_entries = [ + entry + for entry in sampling_entries + if ( + # for chat models we filter like this + isinstance(entry["prompt"], list) + and entry["prompt"][-1]["content"].startswith( + "Given the above reasoning" + ) + or ( + # for base models we need to filter like this + isinstance(entry["prompt"], str) + and "Given the above reasoning" in entry["prompt"] + ) + ) + ] + return sampling_entries + + +def handle_posthoc_metrics(final_results: Dict, log_path: Path, solver: str): + """ + Computes and includes missing metrics from log file if they are not present + """ + metric_entries = log_utils.extract_individual_results(log_path) + sampling_entries = log_utils.extract_individual_results(log_path, "sampling") + # filter out cot double samplings + sampling_entries = handle_cot_double_sampling(sampling_entries, solver) + # this is necessary because we originally didnt compute recall in the eval + for metric in MAIN_METRICS: + if metric not in final_results.keys(): + final_results[metric] = compute_metric_posthoc( + metric, metric_entries, sampling_entries + ) + + return final_results + + +def populate_default_results_dict(results_dict, results_dir): + for log in tqdm(results_dir.glob("*.log"), total=222): + spec = log_utils.extract_spec(log) + solver = spec["completion_fns"][0] + run_config = spec["run_config"] + renderer = run_config["eval_spec"]["args"]["renderer"] + show_tree = "show_tree=True" in run_config["command"] + tree_key = "with tree" if show_tree else "without tree" + if renderer not in RENDERERS and solver != "identifying_variables/random": + continue + if solver not in SOLVERS: + continue + + final_results = log_utils.extract_final_results(log) + final_results = handle_posthoc_metrics(final_results, log, solver) + + for metric, value in final_results.items(): + if metric in MAIN_METRICS: + results_dict[metric]["raw"][solver][renderer][tree_key].append(value) + raw = results_dict[metric]["raw"][solver][renderer][tree_key] + results_dict[metric]["mean"][solver][renderer][tree_key] = np.mean(raw) + results_dict[metric]["sem"][solver][renderer][tree_key] = np.std( + raw + ) / np.sqrt(NUM_REPEATS) + for metric in results_dict.keys(): + del results_dict[metric]["raw"] + return results_dict + + +def make_default_tables(results_dict: Dict, save_dir: Path): + for metric in tqdm(MAIN_METRICS): + make_main_metric_table(results_dict, metric, SOLVERS, RENDERERS, save_dir) + + +def extract_default_results_dict(results_dir: Path): + results_dict = initialize_default_results_dict() + results_dict = populate_default_results_dict(results_dict, results_dir) + + return results_dict + + +def make_default_plots(results_dict: Dict, save_dir: Path): + all_solvers = list(results_dict["ctrl_nDCG"]["mean"].keys()) + bar_solvers, baseline_solvers = all_solvers[:-2], all_solvers[-2:] + + metrics = ["ctrl_nDCG", "ctrl_recall"] + metric_labels = ["Control Variable Retrieval nDCG*", "Control Variable Recall"] + fig_heights = [6, 5] + + for metric, metric_label, fig_height in tqdm( + zip(metrics, metric_labels, fig_heights) + ): + plot_solver_bars( + bar_solvers, + baseline_solvers, + results_dict[metric], + metric_label, + fig_height, + save_dir / f"{metric}.png", + ) + + +def extract_large_results_dict(results_dir: Path) -> Dict: + ctrl_nDCG_bins = list(range(0, 9)) + results_dict = {} + for log in tqdm(results_dir.glob("*.log"), total=12): + spec = log_utils.extract_spec(log) + final_results = log_utils.extract_final_results(log) + solver = spec["completion_fns"][0] + renderer = spec["split"] + key = f"{solver};{renderer}" + if key not in results_dict: + results_dict[key] = { + bbin: {"raw": [], "mean": None, "sem": None} for bbin in ctrl_nDCG_bins + } + + for bbin in ctrl_nDCG_bins: + results_dict[key][bbin]["raw"].append( + final_results[f"ctrl_nDCG-n_ctrl_vars-{bbin}"] + ) + for key in results_dict.keys(): + for bbin in ctrl_nDCG_bins: + mean = np.mean(results_dict[key][bbin]["raw"]) + sem = np.std(results_dict[key][bbin]["raw"]) / 3 + results_dict[key][bbin]["mean"] = mean + results_dict[key][bbin]["sem"] = sem + del results_dict[key][bbin]["raw"] + + return results_dict + + +def make_large_plot(large_results_dir: Dict, save_dir: Path): + ctrl_vars_bins = list(range(0, 9)) + plot_difficulty_bars( + large_results_dir, ctrl_vars_bins, save_dir / "ctrl_nDCG_difficulty.png" + ) + + +def np_nan_if_none(input_num): + if input_num is None: + return np.nan + else: + return input_num + + +def zero_if_none(input_num): + if input_num is None: + return 0 + else: + return input_num + + +def round_if_not_nan(input_num): + if np.isnan(input_num): + return input_num + else: + return round(input_num) + + +def make_token_per_sample_df(solver_to_eval, solver_to_tokens) -> pd.DataFrame: + tokens_per_sample_df = pd.DataFrame( + index=solver_to_eval.keys(), + columns=[ + "input tokens/sample", + "output tokens/sample", + "total tokens/sample", + ], + ) + for solver in solver_to_tokens.keys(): + # print(solver_to_tokens[solver]) + input_mean = np.nanmean(solver_to_tokens[solver]["input"]) + output_mean = np.nanmean(solver_to_tokens[solver]["output"]) + total_mean = np.nanmean(solver_to_tokens[solver]["total"]) + # print([input_mean, output_mean, total_mean]) + tokens_per_sample_df.loc[solver] = [ + round_if_not_nan(input_mean), + round_if_not_nan(output_mean), + round_if_not_nan(total_mean), + ] + solver_to_index = { + "generation/hhh/gpt-4-base": "HHH GPT-4-base (corrset, no tree)", + "generation/direct/gpt-3.5-turbo": "Direct GPT-3.5-turbo (corrset, no tree)", + "generation/direct/gpt-4-1106-preview": "Direct GPT-4-1106-preview (corrset, no tree)", + "generation/cot_hhh/gpt-4-base": "CoT HHH GPT-4-base (language-tabular, with tree)", + "generation/cot/gpt-3.5-turbo": "CoT GPT-3.5-turbo (language-tabular, with tree)", + "generation/cot/gpt-4-1106-preview": "CoT GPT-4-1106-preview (language-tabular, with tree)", + } + tokens_per_sample_df = tokens_per_sample_df.rename(index=solver_to_index) + return tokens_per_sample_df + + +def count_tokens(results_dir: Path, total) -> Tuple[Dict, pd.DataFrame]: + eval_names = [ + "identifying_variables.corrset.default", + "identifying_variables.language-tabular.default", + ] + solver_names = [ + "generation/hhh/gpt-4-base", + "generation/direct/gpt-3.5-turbo", + "generation/direct/gpt-4-1106-preview", + "generation/cot_hhh/gpt-4-base", + "generation/cot/gpt-3.5-turbo", + "generation/cot/gpt-4-1106-preview", + ] + solver_to_eval = { + solver: eval_names[0] if "cot" not in solver else eval_names[1] + for solver in solver_names + } + solver_to_tree = { + solver: False if "cot" not in solver else True for solver in solver_names + } + solver_to_tokens = { + solver: {"input": [], "output": [], "total": []} for solver in solver_names + } + total_input = 0 + total_output = 0 + for log in tqdm(results_dir.glob("*.log"), total=total): + spec = log_utils.extract_spec(log) + solver = spec["completion_fns"][0] + if solver not in solver_names: + print(f"Skipping {solver}: token counting not supported.") + continue + eval_name = spec["eval_name"] + seed = spec["run_config"]["seed"] + tree = "show_tree=True" in spec["run_config"]["command"] + samplings = log_utils.extract_individual_results(log, "sampling") + samplings = handle_cot_double_sampling(samplings, solver) + for sampling in samplings: + usage = sampling["usage"] + if ( + solver in solver_to_eval + and eval_name == solver_to_eval[solver] + and seed == 1 + and tree != solver_to_tree[solver] + ): + solver_to_tokens[solver]["input"].append( + np_nan_if_none(usage["prompt_tokens"]) + ) + solver_to_tokens[solver]["output"].append( + np_nan_if_none(usage["completion_tokens"]) + ) + solver_to_tokens[solver]["total"].append( + np_nan_if_none(usage["total_tokens"]) + ) + total_input += zero_if_none(usage["prompt_tokens"]) + total_output += zero_if_none(usage["completion_tokens"]) + + total_tokens = {"input": total_input, "output": total_output} + tokens_per_sample_df = make_token_per_sample_df(solver_to_eval, solver_to_tokens) + + return total_tokens, tokens_per_sample_df + + +def make_total_tokens_table(default_total: Dict, large_total: Dict) -> pd.DataFrame: + """ + Makes a dataframe where the index is "default" "large" and the columns are + "input", "output"; showing the total number of input and output tokens for + our experiments on each dataset. + """ + total_tokens_df = pd.DataFrame( + { + "input": [default_total["input"], large_total["input"]], + "output": [default_total["output"], large_total["output"]], + }, + index=["default", "large"], + ) + return total_tokens_df + + +def make_token_count_tables( + default_results_dir: Path, large_results_dir: Path, save_dir: Path +): + default_total_tokens, default_per_sample_tokens_df = count_tokens( + default_results_dir, total=222 + ) + large_total_tokens, _ = count_tokens(large_results_dir, total=12) + + total_tokens_df = make_total_tokens_table(default_total_tokens, large_total_tokens) + + # save the tables + total_tokens_df.to_csv(save_dir / "total_tokens.csv") + default_per_sample_tokens_df.to_csv(save_dir / "per_sample_tokens.csv") + + +def main(default_results_dir: Path, large_results_dir: Path, save_dir: Path): + save_dir.mkdir(parents=True, exist_ok=True) + + print("Parsing default dataset results...") + default_results_dict = extract_default_results_dict(default_results_dir) + print("Making default dataset tables...") + make_default_tables(default_results_dict, save_dir) + print("Making default dataset plots...") + make_default_plots(default_results_dict, save_dir) + + print("Parsing large dataset results...") + large_results_dict = extract_large_results_dict(large_results_dir) + print("Making large dataset plot...") + make_large_plot(large_results_dict, save_dir) + + print("Making token count tables...") + make_token_count_tables(default_results_dir, large_results_dir, save_dir) + print("Done.") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Process results") + parser.add_argument( + "--default_results_dir", + type=str, + help="Path to directory containing .log files from experiments on default dataset", + ) + parser.add_argument( + "--large_results_dir", + type=str, + help="Path to directory containing .log files from experiments on large dataset", + ) + parser.add_argument( + "--save_dir", type=str, help="Path to directory to save plots and tables to" + ) + + args = parser.parse_args() + + main( + Path(args.default_results_dir), + Path(args.large_results_dir), + Path(args.save_dir), + ) diff --git a/evals/elsuite/identifying_variables/scripts/plotting_utils.py b/evals/elsuite/identifying_variables/scripts/plotting_utils.py new file mode 100644 index 0000000000..1c80aab042 --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/plotting_utils.py @@ -0,0 +1,163 @@ +from typing import Dict, Iterable, List +from pathlib import Path + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + + +renderers_of_interest = ["csv", "language-corrset"] + +renderer_to_label = { + "csv": "CSV observations", + "language-corrset": "Correlation set", +} + +cmap = plt.get_cmap("Paired") +colors = np.array([cmap(i) for i in range(len(renderers_of_interest))]) +renderer_to_color = {r: c for r, c in zip(renderers_of_interest, colors)} + +solver_to_label = { + "generation/direct/gpt-3.5-turbo": "Direct gpt-3.5-turbo", + "generation/cot/gpt-3.5-turbo": "CoT gpt-3.5-turbo", + "generation/hhh/gpt-4-base": "HHH gpt-4-base", + "generation/cot_hhh/gpt-4-base": "CoT HHH gpt-4-base", + "generation/direct/gpt-4-1106-preview": "Direct gpt-4-1106-preview", + "generation/cot/gpt-4-1106-preview": "CoT gpt-4-1106-preview", + "generation/cot/mixtral-8x7b-instruct": "CoT mixtral-8x7b-instruct\n(Correlation set only)", + "generation/cot/llama-2-70b-chat": "CoT llama-2-70b-chat\n(Correlation set only)", + "generation/cot/gemini-pro": "CoT gemini-pro-1.0\n(Correlation set only)", + "identifying_variables/random": "Random baseline", + "identifying_variables/noctrl": "NoCtrl baseline", +} + +baseline_to_linestyle = { + "identifying_variables/random": "--", + "identifying_variables/noctrl": "-.", +} + +cmap = plt.get_cmap("Set2") +bline_colors = np.array( + [cmap(i) for i in range(0, len(baseline_to_linestyle.keys()) + 0)] +) +baseline_to_color = { + key: color for key, color in zip(baseline_to_linestyle.keys(), bline_colors) +} + + +def plot_solver_bars( + bar_solvers: List[str], + baseline_solvers: List[str], + metric_results: Dict, + metric_label: str, + fig_height: int, + output_path: Path, +): + """ + Plots a side-by-side bar plot of the metric results, showing the + solvers on the x axis and the metric value on the y axis. + + Args: + bar_solvers: The names of solvers to plot. + baseline_solvers: The names of the baseline solvers to plot. + metric_results: A dictionary with k: v of format solver : {mean: value, sem: value} + metric_label: The label for the y axis + fig_height: the height of the figure in inches + output_path: the path to save the figure to + """ + sns.set_context("paper") + sns.set_style("whitegrid") + + bar_width = 0.3 + positions = np.arange(len(bar_solvers)) + + f, ax = plt.subplots(1, 1, dpi=300, figsize=(9, fig_height)) + + for i, renderer in enumerate(renderers_of_interest): + bars = [ + metric_results["mean"][solver][renderer]["without tree"] + for solver in bar_solvers + ] + errors = [ + metric_results["sem"][solver][renderer]["without tree"] + for solver in bar_solvers + ] + + ax.bar( + positions + bar_width * i, + bars, + bar_width, + yerr=errors, + label=renderer_to_label[renderer], + color=renderer_to_color[renderer], + ) + + for baseline_solver in baseline_solvers: + mean = metric_results["mean"][baseline_solver]["corrset"]["without tree"] + sem = metric_results["sem"][baseline_solver]["corrset"]["without tree"] + ax.axhline( + mean, + label=solver_to_label[baseline_solver], + color=baseline_to_color[baseline_solver], + linestyle=baseline_to_linestyle[baseline_solver], + ) + ax.axhspan( + mean - sem, mean + sem, alpha=0.1, color=baseline_to_color[baseline_solver] + ) + + ax.set_xticks( + positions + bar_width / 2, + [solver_to_label[s] for s in bar_solvers], + rotation=45, + ha="right", + ) + ax.tick_params( + axis="x", which="both", bottom=True + ) # Show both major and minor xticks + ax.set_ylabel(metric_label) + ax.set_ylim(-0.005, 1) + ax.xaxis.grid(False) + ax.legend() + f.set_tight_layout(True) + plt.savefig(output_path, dpi=300, bbox_inches="tight") + + +def plot_difficulty_bars(results_dict: Dict, bins: Iterable[int], output_path: Path): + sns.set_context("paper") + sns.set_style("whitegrid") + + f, ax = plt.subplots(1, 1, dpi=300, figsize=(7, 4)) + + positions = np.arange(len(bins)) + bar_width = 0.4 + + for i, key in enumerate(sorted(results_dict.keys())): + solver, renderer = key.split(";") + bars = [results_dict[key][bbin]["mean"] for bbin in bins] + errors = [results_dict[key][bbin]["sem"] for bbin in bins] + if solver == "generation/direct/gpt-4-1106-preview": + label = renderer_to_label[renderer] + color = renderer_to_color[renderer] + ax.bar( + positions + bar_width * i, + bars, + bar_width, + yerr=errors, + label=label, + color=color, + ) + + ax.set_xlabel("Number of necessary control variables") + ax.set_ylabel("Control Variable Retrieval nDCG*") + + ax.set_xlim(-0.3, 8.7) + ax.set_ylim(0, 1) + ax.xaxis.grid(False) + ax.legend() + ax.set_xticks(positions + bar_width / 2, bins) + f.set_tight_layout(True) + plt.savefig( + output_path, + dpi=300, + bbox_inches="tight", + ) diff --git a/evals/elsuite/identifying_variables/scripts/run_experiments.sh b/evals/elsuite/identifying_variables/scripts/run_experiments.sh new file mode 100755 index 0000000000..fae5ceb93b --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/run_experiments.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Function to display usage +usage() { + echo "Usage: $0 -s size -l logdir" + echo " -s size Specify the size of the experiments (options: 'balanced-hypotheses', 'balanced-ctrl', 'balanced-hypotheses-large', 'balanced-ctrl-large')" + echo " -l logdir Specify the directory for log files" + exit 1 +} + +# Check if no arguments were provided +if [ $# -eq 0 ]; then + usage + exit 1 +fi + +# Parse command-line options +while getopts 's:l:' flag; do + case "${flag}" in + s) size=${OPTARG} ;; + l) logdir=${OPTARG} ;; + *) usage ;; + esac +done + +# Check if mandatory arguments were provided +if [ -z "$size" ] || [ -z "$logdir" ]; then + usage + exit 1 +fi + +logdirbase=$logdir +NUM_REPEATS=3 + +# Function to run experiments +run_experiments() { + local size=$1 + local logpathbase="${logdirbase}/${size}" + local start_time=$SECONDS + + # Define RENDERERS and SOLVERS array based on size + declare -a RENDERERS + declare -a SOLVERS + if [ "$size" == "balanced-hypotheses" ]; then + RENDERERS=("markdown" "csv" "json" "language-tabular" "language-corrset" "corrset") + SOLVERS=("generation/direct/gpt-3.5-turbo" + "generation/cot/gpt-3.5-turbo" + "generation/hhh/gpt-4-base" + "generation/cot_hhh/gpt-4-base" + "generation/direct/gpt-4-1106-preview" + "generation/cot/gpt-4-1106-preview") + elif [ "$size" == "balanced-ctrl" ]; then + RENDERERS=("csv" "language-corrset") + SOLVERS=("generation/direct/gpt-3.5-turbo" + "generation/cot/gpt-3.5-turbo" + "generation/hhh/gpt-4-base" + "generation/cot_hhh/gpt-4-base" + "generation/direct/gpt-4-1106-preview" + "generation/cot/gpt-4-1106-preview") + else + RENDERERS=("csv" "language-corrset") + SOLVERS=("generation/direct/gpt-4-1106-preview") + fi + + # Main loop + for ((i = 1; i <= NUM_REPEATS; i++)); do + for solver in "${SOLVERS[@]}"; do + for renderer in "${RENDERERS[@]}"; do + run_solver $solver $renderer $size $i "$logpathbase" + done + done + run_solver "identifying_variables/random" "corrset" $size $i "$logpathbase" + run_solver "identifying_variables/noctrl" "corrset" $size $i "$logpathbase" + done + + local end_time=$SECONDS + echo "Done running experiments for $size size, all logs in $logpathbase" + echo "Total execution time: $((end_time - start_time)) seconds." +} + +# Function to run a single solver +run_solver() { + local solver=$1 + local renderer=$2 + local size=$3 + local seed=$4 + local logpathbase=$5 + local solver_dotted=${solver//\//.} + + local record_path="${logpathbase}/${solver_dotted}_${renderer}_${size}_${seed}" + echo "Running $solver with $renderer renderer and $size data size; seed $seed" + + local sub_start_time=$(date +%s) + oaieval "$solver" "identifying_variables.${renderer}.${size}" --record_path "$record_path.log" --seed $seed + local sub_end_time=$(date +%s) + echo "${solver_dotted}_${renderer}_${size} execution time: $((sub_end_time - sub_start_time)) seconds." + + skip_tree_solvers=("identifying_variables/random" "identifying_variables/noctrl") + if [[ ! "${skip_tree_solvers[@]}" =~ "$solver" ]] && [ "$size" == "balanced-hypotheses" ]; then + echo "Now repeating with show_tree=True" + oaieval "$solver" "identifying_variables.${renderer}.${size}" --extra_eval_params show_tree=True --record_path "${record_path}_tree.log" --seed $seed + fi +} + +run_experiments "${size}" diff --git a/evals/elsuite/identifying_variables/scripts/table_utils.py b/evals/elsuite/identifying_variables/scripts/table_utils.py new file mode 100644 index 0000000000..3991cd469b --- /dev/null +++ b/evals/elsuite/identifying_variables/scripts/table_utils.py @@ -0,0 +1,66 @@ +from typing import Dict, List +from pathlib import Path + +import numpy as np +import pandas as pd + + +def make_main_metric_table( + results_dict: Dict, + metric: str, + solvers: List[str], + renderers: List[str], + save_dir: Path, +): + """ + Makes and saves a table containing the information of performance of + each solver for each renderer for each variant of the eval on + a given metric. + - Table rows are solvers; they are multi-rows, so each row has two subrows: with + tree and without tree + - Table columns are renderers; they are multi-columns, so each column has two + subcolumns: mean and sem (standard error of the mean) + + Args: + results_dict: dictionary containing the results of the eval. See + `initialize_default_results_dict` and `populate_default_results_dict` in + `process_results.py`. + metric: the name of the metric we want to make the table for + solvers: list of solvers we want to include in the table + renderers: list of renderers we want to include in the table + save_dir: directory to save the table in (as a CSV file) + """ + + # only keep keep metric in results_dict + filtered_results_dict = results_dict[metric] + # flatten into tuples + data_tuples = [] + for stat, solver_data in filtered_results_dict.items(): + for solver, renderer_data in solver_data.items(): + for renderer, tree_data in renderer_data.items(): + for tree_type, value in tree_data.items(): + if value is not None: + data_tuples.append((solver, tree_type, renderer, stat, value)) + + df = pd.DataFrame( + data_tuples, columns=["Solver", "Tree", "Renderer", "Stat", "Value"] + ) + df = df.pivot_table( + index=["Solver", "Tree"], columns=["Renderer", "Stat"], values="Value" + ) + # sorting by solvers, renderers (for some reason ordering is lost in the above process) + new_index = [ + (solver, tree) for solver in solvers for tree in ["with tree", "without tree"] + ] + new_columns = pd.MultiIndex.from_product( + [renderers, df.columns.levels[1]], names=df.columns.names + ) + df = df.reindex(new_index, columns=new_columns) + + # delete the with tree rows for the treeless solvers + for solver in solvers[-2:]: + df.drop((solver, "with tree"), inplace=True) + + # save table + save_path = save_dir / f"{metric}_table.csv" + df.to_csv(save_path) diff --git a/evals/elsuite/identifying_variables/solvers.py b/evals/elsuite/identifying_variables/solvers.py new file mode 100644 index 0000000000..c6010c74da --- /dev/null +++ b/evals/elsuite/identifying_variables/solvers.py @@ -0,0 +1,48 @@ +import random + +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState + + +class RandomSolver(Solver): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _solve(self, task_state: TaskState) -> SolverResult: + valid_hyp = random.uniform(0, 1) < 0.5 + + variables = task_state.current_state["variables"] + n_vars_to_sample = random.randint(2, len(variables)) + ind_var, dep_var, *ctrl_vars = random.sample(variables, n_vars_to_sample) + if len(ctrl_vars) == 0: + ctrl_vars = "none" + else: + ctrl_vars = ", ".join(ctrl_vars) + + solver_string = f"[@ANSWER valid_hyp: {valid_hyp}; independent: {ind_var}; dependent: {dep_var}; control: {ctrl_vars}]" + + return SolverResult(output=solver_string) + + +class NoCtrl(Solver): + """ + Solver that always returns no control variables + (i.e. "none", interpreted as an empty list by the eval) + what it returns for the other variables is arbitrary + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _solve(self, task_state: TaskState) -> SolverResult: + # we don't care about valid_hyp and ind/dep vars for this solver + # it's only used for the ctrl variables subtask + valid_hyp = True + variables = task_state.current_state["variables"] + ind_var, dep_var = random.sample(variables, 2) + + # it just always returns no control variables + ctrl_vars = "none" + solver_string = f"[@ANSWER valid_hyp: {valid_hyp}; independent: {ind_var}; dependent: {dep_var}; control: {ctrl_vars}]" + + return SolverResult(output=solver_string) diff --git a/evals/elsuite/identifying_variables/structs.py b/evals/elsuite/identifying_variables/structs.py new file mode 100644 index 0000000000..90b47b96b0 --- /dev/null +++ b/evals/elsuite/identifying_variables/structs.py @@ -0,0 +1,49 @@ +"""Custom data structures for the eval""" +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import networkx as nx + + +@dataclass +class Answer: + valid_hypothesis: bool + ind_var: Optional[str] + dep_var: Optional[str] + ctrl_vars: Optional[List[str]] + + +@dataclass +class Sample: + """ + A sample of the dataset for the eval. + + Args: + variable_metadata (Dict) : A dictionary mapping each variable name to its metadata. + Each variable's metadata is a dictionary containing: + - 'gen_method': A dictionary specifying the generation method for the + variable, including: + - 'name': Name of the latent function or distribution. + - 'input_x': Name of the input variable, if applicable. + - 'kwargs': Additional arguments for the latent function. + - 'corrs': A set of variables correlated with this variable. + hypotheses (nx.DiGraph): A directed acyclic graph (DAG) representing the hypotheses. + target_hypothesis (Tuple[str, str]) A tuple (independent_variable, dependent_variable) + representing the hypothesis of interest. + sample_metadata (Dict): A dictionary with additional metadata, including: + - 'num_obs_samples': Number of observations generated per variable. + - 'snr': Signal-to-noise ratio applied to the observations. + causal_graph (nx.DiGraph): A randomly generated DAG representing the underlying + causal relationships among variables. Represented as nx.DiGraph. + gold_label (Answer): The gold label for the sample. + num_not_ctrl (Optional[int]): The number of variables not controlled for. None + if the hypothesis is invalid. + """ + + variable_metadata: Dict + hypotheses: nx.DiGraph + target_hypothesis: Tuple[str, str] + sample_metadata: Dict + causal_graph: nx.DiGraph + gold_label: Answer + num_not_ctrl: Optional[int] diff --git a/evals/elsuite/identifying_variables/utils.py b/evals/elsuite/identifying_variables/utils.py new file mode 100644 index 0000000000..6918926bdf --- /dev/null +++ b/evals/elsuite/identifying_variables/utils.py @@ -0,0 +1,91 @@ +import re +from typing import Dict + +import networkx as nx +import numpy as np + +from evals.elsuite.identifying_variables.structs import Answer, Sample +from evals.solvers.solver import SolverResult + + +def parse_solver_preds(solver_result: SolverResult) -> Answer: + solver_string = solver_result.output.strip().lower() + + pattern = ( + r"\[@answer " # Matches the beginning of the answer + r"valid_hyp: (true|false|True|False)" # valid hyp part + r"(?:; independent: ([^;]*))?" # Optionally matches the independent part + r"(?:; dependent: ([^;]*))?" # Optionally matches the dependent part + r"(?:; control: ([^\]]*))?" # Optionally matches the control part + r"\]" # Matches the end of the answer + ) + + match = re.search(pattern, solver_string) + + if match: + valid_hyp = match.group(1).lower() == "true" + if not valid_hyp: + return Answer( + valid_hypothesis=False, + ind_var=None, + dep_var=None, + ctrl_vars=None, + ) + ind_var = match.group(2) + ind_var = ind_var if ind_var is not None else "WRONG" + dep_var = match.group(3) + dep_var = dep_var if dep_var is not None else "WRONG" + ctrl_vars = match.group(4) + if ctrl_vars is not None: + ctrl_vars = ctrl_vars.split(",") + ctrl_vars = [var.strip() for var in ctrl_vars] + if ctrl_vars[0].lower().strip("\"'`«»<>") == "none": + ctrl_vars = [] + else: + ctrl_vars = ["WRONG"] + return Answer( + valid_hypothesis=True, + ind_var=ind_var, + dep_var=dep_var, + ctrl_vars=ctrl_vars, + ) + else: + raise ValueError("Invalid solver output") + + +def sample_serializer(obj): + """ + Custom serializer to pass to json.dumps when + saving a sample dictionary to jsonl + """ + if isinstance(obj, set): + return list(obj) + elif isinstance(obj, nx.DiGraph): + return nx.to_dict_of_lists(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + + +def json_to_sample(serialized_sample: Dict) -> Sample: + """Reads sample from jsonl into Sample dataclass""" + hypotheses = nx.from_dict_of_lists(serialized_sample["hypotheses"], create_using=nx.DiGraph) + causal_graph = nx.from_dict_of_lists(serialized_sample["causal_graph"], create_using=nx.DiGraph) + gold_label = Answer(**serialized_sample["gold_label"]) + + # convert corrs in variable_metadata from lists to sets + for var in serialized_sample["variable_metadata"]: + serialized_sample["variable_metadata"][var]["corrs"] = set( + serialized_sample["variable_metadata"][var]["corrs"] + ) + + return Sample( + variable_metadata=serialized_sample["variable_metadata"], + hypotheses=hypotheses, + target_hypothesis=serialized_sample["target_hypothesis"], + sample_metadata=serialized_sample["sample_metadata"], + causal_graph=causal_graph, + gold_label=gold_label, + num_not_ctrl=serialized_sample["num_not_ctrl"], + ) diff --git a/evals/registry/data/identifying_variables/balanced_ctrl_vars.jsonl b/evals/registry/data/identifying_variables/balanced_ctrl_vars.jsonl new file mode 100644 index 0000000000..c29a8ee65d --- /dev/null +++ b/evals/registry/data/identifying_variables/balanced_ctrl_vars.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9429fe712578ae4298e012cc374198bf83cf968115004dc00d24e42ebdc4f1d +size 12525123 diff --git a/evals/registry/data/identifying_variables/balanced_hypotheses.jsonl b/evals/registry/data/identifying_variables/balanced_hypotheses.jsonl new file mode 100644 index 0000000000..cb05f29c53 --- /dev/null +++ b/evals/registry/data/identifying_variables/balanced_hypotheses.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e92ee79ee832d7f6f40e55cad82fe26100ea3c1ca1faac2f606a046ef4a09b79 +size 7554989 diff --git a/evals/registry/evals/identifying_variables.yaml b/evals/registry/evals/identifying_variables.yaml new file mode 100644 index 0000000000..32f4ecbafa --- /dev/null +++ b/evals/registry/evals/identifying_variables.yaml @@ -0,0 +1,136 @@ +identifying_variables: + id: identifying_variables.language-corrset.balanced-ctrl + metrics: + [ + "ctrl_nDCG", + "ctrl_recall", + "ctrl_fallout", + "hyp_valid_acc", + "ind_acc", + "dep_acc", + "violation_rate", + ] + description: + "Evaluate the model's ability of identifying the right experimental + variables for testing a given hypothesis." + +# Balanced-hypotheses datasets + +identifying_variables.markdown.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: markdown +identifying_variables.markdown.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: markdown + group_metrics: true + +identifying_variables.csv.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: csv +identifying_variables.csv.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: csv + group_metrics: true + +identifying_variables.json.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: json +identifying_variables.json.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: json + group_metrics: true + +identifying_variables.language-tabular.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: language-tabular +identifying_variables.language-tabular.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: language-tabular + group_metrics: true + +identifying_variables.language-corrset.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: language-corrset +identifying_variables.language-corrset.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: language-corrset + group_metrics: true + +identifying_variables.corrset.balanced-hypotheses: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + n_samples: 500 + renderer: corrset +identifying_variables.corrset.balanced-hypotheses-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_hypotheses.jsonl + renderer: corrset + group_metrics: true + +# Balanced-control datasets + +identifying_variables.csv.balanced-ctrl: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + n_samples: 500 + renderer: csv +identifying_variables.csv.balanced-ctrl-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + renderer: csv + group_metrics: true + +identifying_variables.language-corrset.balanced-ctrl: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + n_samples: 500 + renderer: language-corrset +identifying_variables.language-corrset.balanced-ctrl-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + renderer: language-corrset + group_metrics: true + +identifying_variables.corrset.balanced-ctrl: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + n_samples: 500 + renderer: corrset +identifying_variables.corrset.balanced-ctrl-large: + class: evals.elsuite.identifying_variables.eval:IdentifyingVariables + args: + samples_jsonl: identifying_variables/balanced_ctrl_vars.jsonl + renderer: corrset + group_metrics: true diff --git a/evals/registry/solvers/identifying_variables.yaml b/evals/registry/solvers/identifying_variables.yaml new file mode 100644 index 0000000000..aa6108febc --- /dev/null +++ b/evals/registry/solvers/identifying_variables.yaml @@ -0,0 +1,5 @@ +identifying_variables/random: + class: evals.elsuite.identifying_variables.solvers:RandomSolver + +identifying_variables/noctrl: + class: evals.elsuite.identifying_variables.solvers:NoCtrl diff --git a/evals/utils/log_utils.py b/evals/utils/log_utils.py index d54a846f41..787b16d79a 100644 --- a/evals/utils/log_utils.py +++ b/evals/utils/log_utils.py @@ -31,7 +31,7 @@ def extract_final_results(path: Path) -> dict: raise ValueError(f"Could not find final_report in {path}") -def extract_individual_results(path: Path) -> list[dict]: +def extract_individual_results(path: Path, type_string: str = "metrics") -> list[dict]: """ Given a path to a log file, grab all the individual sample results. """ @@ -42,7 +42,7 @@ def extract_individual_results(path: Path) -> list[dict]: try: loaded_line = json.loads(line) if "type" in loaded_line: - if loaded_line["type"] == "metrics": + if loaded_line["type"] == type_string: all_data.append(loaded_line["data"]) except json.decoder.JSONDecodeError: print(f"Skipping line: {line}") diff --git a/pyproject.toml b/pyproject.toml index 4c4e6cbfa9..057f38b9eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "jiwer", "seaborn", "statsmodels", + "networkx", "chess", ]