Skip to content

Commit

Permalink
attempt to add tests for exclusions
Browse files Browse the repository at this point in the history
  • Loading branch information
sjadler2004 committed Oct 18, 2023
1 parent 4aacc1c commit f1054bf
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions evals/elsuite/basic/includes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,44 @@


@mark.parametrize(
    "completion, ideal, expected_match, ignore_case, use_exclusions, excluded_terms",
    [
        # Baseline cases: exclusions disabled, so only `ideal` and
        # `ignore_case` drive the expected outcome.
        ("world", "world", True, False, False, None),
        ("world", "wOrLd", True, True, False, None),
        ("world", ["world"], True, False, False, None),
        ("world", ["foo", "bar"], False, False, False, None),
        ("world", ["worldfoo", "worldbar"], False, False, False, None),
        # Exclusions: the completion contains an excluded term, so what would
        # otherwise be a match is expected to be recorded as a non-match.
        ("world exclusion", "world", False, False, True, ["exclusion", "excluded"]),
        # Control case: same setup but the second word is not an excluded
        # term, confirming the match is still recorded as True.
        ("world okay", "world", True, False, True, ["exclusion", "excluded"]),
    ],
)
def test_eval_sample(
    completion: str,
    ideal: Union[str, list[str]],
    expected_match: bool,
    ignore_case: bool,
    use_exclusions: bool,
    excluded_terms: Union[str, list[str], None],
):
    """Check that ``Includes.eval_sample`` records the expected match result.

    Every argument is supplied by ``@mark.parametrize``. None of them may
    carry a default value: pytest refuses to parametrize an argument that
    already has a default and fails at collection time.

    Args:
        completion: Text returned by the stubbed completion function.
        ideal: Reference answer(s) placed in the sample under ``"ideal"``.
        expected_match: Boolean that ``record_match`` must be called with.
        ignore_case: Forwarded to the ``Includes`` constructor.
        use_exclusions: Forwarded to the ``Includes`` constructor.
        excluded_terms: Forwarded to the ``Includes`` constructor and also
            placed in the sample under ``"exclude"``; ``None`` in the cases
            that do not exercise exclusions.
    """
    # NOTE(review): the exact exclusion semantics live in `Includes`, which is
    # not visible here; the cases above assume a completion containing an
    # excluded term is scored as a non-match -- confirm against the eval.
    includes_eval = Includes(
        completion_fns=[TestCompletionFn(completion)],
        samples_jsonl="",
        eval_registry_path=Path("."),
        ignore_case=ignore_case,
        use_exclusions=use_exclusions,
        excluded_terms=excluded_terms,
    )

    recorder = DummyRecorder(None)
    # Wrap (rather than replace) record_match so the real recorder still runs
    # while the call arguments can be asserted on.
    with recorder.as_default_recorder("x"), patch.object(
        recorder, "record_match", wraps=recorder.record_match
    ) as record_match:
        sample_dict = dict(input="Hello", ideal=ideal, exclude=excluded_terms)
        includes_eval.eval_sample(sample_dict, None)
        record_match.assert_called_once_with(
            expected_match, expected=ideal, picked=completion, sampled=completion
        )


Expand Down

0 comments on commit f1054bf

Please sign in to comment.