Merge pull request #362 from EleutherAI/cleanup-for-release
Cleanup `README.md` and package deps
StellaAthena authored Dec 7, 2022
2 parents fdd3dbc + 1e5d55d commit 1d8107b
Showing 26 changed files with 434 additions and 440 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/python-app.yml
@@ -32,7 +32,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest pytest-cov
pip install -e .[dev]
pip install -e .[dev,multilingual]
# Install optional git dependencies
pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
453 changes: 64 additions & 389 deletions README.md

Large diffs are not rendered by default.

268 changes: 268 additions & 0 deletions docs/task_table.md

Large diffs are not rendered by default.

26 changes: 24 additions & 2 deletions lm_eval/tasks/translation.py
@@ -16,6 +16,20 @@
from lm_eval.base import Task, rf
from typing import List

try:
import nagisa

HAS_NAGISA = True
except ImportError:
HAS_NAGISA = False

try:
import jieba

HAS_JIEBA = True
except ImportError:
HAS_JIEBA = False


_CITATION = """
@inproceedings{post-2018-call,
@@ -63,14 +77,22 @@ def version_of(dataset, language_pair):

def zh_split(zh_text: List[str]) -> List[str]:
"""Chinese splitting"""
import jieba
if not HAS_JIEBA:
raise ImportError(
"Chinese text splitting requires the `jieba` package. "
"Please install it with:\npip install jieba"
)

return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]


def ja_split(ja_text: List[str]) -> List[str]:
"""Japanese splitting"""
import nagisa
if not HAS_NAGISA:
raise ImportError(
"Japanese text splitting requires the `nagisa` package. "
"Please install it with:\npip install nagisa"
)

return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]

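
The guarded imports above defer the hard failure from import time to call time: the module still imports cleanly without jieba or nagisa, and only raises (with an install hint) when Chinese or Japanese splitting is actually requested. A minimal, self-contained sketch of that pattern, assuming a hypothetical `require` helper and `zh_tokenize` function that are not part of the harness:

import importlib
from typing import List


def require(package: str, install_hint: str):
    """Import `package`, or raise ImportError with an install hint."""
    try:
        return importlib.import_module(package)
    except ImportError as exc:
        raise ImportError(
            f"This feature requires the `{package}` package. "
            f"Please install it with:\n{install_hint}"
        ) from exc


def zh_tokenize(texts: List[str]) -> List[str]:
    """Hypothetical helper: segment Chinese strings, importing jieba lazily."""
    jieba = require("jieba", "pip install jieba")
    return [" ".join(jieba.cut(t.strip())) for t in texts]
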
14 changes: 14 additions & 0 deletions lm_eval/tasks/truthfulqa.py
@@ -27,6 +27,14 @@
from lm_eval.metrics import mean


try:
import bleurt

HAS_BLEURT = True
except ImportError:
HAS_BLEURT = False


_CITATION = """
@misc{lin2021truthfulqa,
title={TruthfulQA: Measuring How Models Mimic Human Falsehoods},
@@ -164,6 +172,12 @@ class TruthfulQAGeneration(Task):

def __init__(self):
super().__init__()
if not HAS_BLEURT:
raise ImportError(
"`TruthfulQAGeneration` requires the `bleurt` package. Please install it with:\n"
"pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt"
"\nWARNING: Installing any other version of bleurt may result in different results."
)
self.bleurt = datasets.load_metric("bleurt")

def has_training_docs(self):
3 changes: 2 additions & 1 deletion lm_eval/utils.py
@@ -5,7 +5,6 @@
import functools
import inspect
import sys
import pytest
from typing import List


@@ -187,6 +186,8 @@ def run_task_tests(task_list: List[str]):
"""
Find the package root and run the tests for the given tasks
"""
import pytest

package_root = find_test_root(start_path=pathlib.Path(__file__))
task_string = " or ".join(task_list)
args = [
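
Here `import pytest` moves from the top of lm_eval/utils.py into run_task_tests, so pytest (listed only under the `dev` extra in setup.py below) is not needed just to import the module; it is only required when task integrity tests actually run. A small sketch of that deferred-import idea, with an illustrative function name that is not the harness API:

from typing import List


def run_selected_tests(task_list: List[str]) -> int:
    # Imported lazily so the surrounding module does not require pytest.
    import pytest

    # "-k" restricts collection to tests whose names match any listed task.
    return pytest.main(["--verbose", "-k", " or ".join(task_list), "tests/"])
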
61 changes: 40 additions & 21 deletions scripts/make_table_tasks.py
@@ -1,33 +1,52 @@
"""
Usage:
python make_table_tasks.py --output <markdown_filename>
"""
import argparse
import logging
from lm_eval import tasks
from pytablewriter import MarkdownTableWriter

writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]

values = []
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def chk(tf):
def check(tf):
if tf:
return "✓"
else:
return " "


for tname, Task in tasks.TASK_REGISTRY.items():
task = Task()

v = [
tname,
chk(task.has_training_docs()),
chk(task.has_validation_docs()),
chk(task.has_test_docs()),
len(list(task.test_docs() if task.has_test_docs() else task.validation_docs())),
", ".join(task.aggregation().keys()),
]
print(v)
values.append(v)

writer.value_matrix = values

print(writer.dumps())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output", type=str, default="task_table.md")
args = parser.parse_args()

writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
values = []

tasks = tasks.TASK_REGISTRY.items()
tasks = sorted(tasks, key=lambda x: x[0])
for tname, Task in tasks:
task = Task()
v = [
tname,
check(task.has_training_docs()),
check(task.has_validation_docs()),
check(task.has_test_docs()),
len(
list(
task.test_docs() if task.has_test_docs() else task.validation_docs()
)
),
", ".join(task.aggregation().keys()),
]
logger.info(v)
values.append(v)
writer.value_matrix = values
table = writer.dumps()
with open(args.output, "w") as f:
f.write(table)
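
The script is now wrapped in an argparse CLI with a __main__ guard, sorts the task registry, and writes the Markdown table to --output; docs/task_table.md above is presumably generated this way, per the usage note `python make_table_tasks.py --output <markdown_filename>`. A minimal sketch of the pytablewriter usage it builds on; the two rows below are illustrative placeholders, not real registry output:

from pytablewriter import MarkdownTableWriter

writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]
writer.value_matrix = [
    ["example_task_a", "✓", "✓", "✓", 1000, "acc"],
    ["example_task_b", " ", "✓", " ", 250, "acc, acc_norm"],
]

with open("task_table.md", "w") as f:
    f.write(writer.dumps())  # dumps() renders the whole table as Markdown
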
37 changes: 16 additions & 21 deletions setup.py
@@ -14,36 +14,31 @@
url="https://github.com/EleutherAI/lm-evaluation-harness",
packages=setuptools.find_packages(),
classifiers=[
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires=">=3.6",
install_requires=[
"datasets>=2.0.0",
"click>=7.1",
"jsonlines",
"numexpr",
"openai>=0.6.4",
"pybind11>=2.6.2",
"pycountry",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu==1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.7",
"tqdm-multiprocess",
"transformers>=4.1",
"sqlitedict==1.6.0",
"pytablewriter==0.58.0",
"sacrebleu==1.5.0",
"rouge-score==0.0.4",
"pycountry==20.7.3",
"numexpr>=2.7.2",
"lm_dataformat==0.0.20",
"pybind11==2.6.2",
"tqdm-multiprocess==0.0.11",
"zstandard==0.15.2",
"jsonlines==2.0.0",
"mock==4.0.3",
"openai==0.6.4",
"jieba==0.42.1",
"nagisa==0.2.7",
"bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
],
dependency_links=[
"https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
"zstandard",
],
extras_require={"dev": ["pytest", "black", "pre-commit"]},
extras_require={
"dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
},
)
3 changes: 2 additions & 1 deletion tests/test_models.py
@@ -258,8 +258,9 @@ def textsynth_mock_completion(**kwargs):
import requests

os.makedirs("tests/testdata", exist_ok=True)
hash_kwargs = {k: v for k, v in kwargs.items() if k != "headers"}
hash = hashlib.sha256(
json.dumps(kwargs, sort_keys=True).encode("utf-8")
json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
).hexdigest()
fname = f"tests/testdata/textsynth_test_{hash}.pkl"

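
The mock's cache key now hashes every request kwarg except `headers`, so changing the API key no longer changes the recorded-fixture filename. A standalone sketch of that cache-key construction; the kwargs shown are made-up placeholders:

import hashlib
import json

kwargs = {
    "url": "https://example.invalid/v1/completions",  # placeholder endpoint
    "headers": {"Authorization": "Bearer <secret>"},   # excluded from the hash
    "json": {"prompt": "Hello", "max_tokens": 16},
}

hash_kwargs = {k: v for k, v in kwargs.items() if k != "headers"}
cache_key = hashlib.sha256(
    json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
).hexdigest()
print(f"tests/testdata/textsynth_test_{cache_key}.pkl")
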
3 changes: 0 additions & 3 deletions tests/test_tasks.py
@@ -7,10 +7,7 @@
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class):
print("Evaluating task", taskname)
# dl = task_class.download
# task_class.download = MagicMock()
task = task_class()
# task_class.download = dl

assert task.has_training_docs() in [True, False]
assert task.has_validation_docs() in [True, False]
2 changes: 1 addition & 1 deletion tests/test_version_stable.py
@@ -51,7 +51,7 @@ def flatten(d, parent_key="", sep="."):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
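
collections.MutableMapping was a deprecated alias that was removed in Python 3.10; the abstract base class lives in collections.abc, which works on all supported versions. A self-contained sketch of the corrected flatten helper, filling in the return statement that falls outside the hunk:

import collections.abc


def flatten(d, parent_key="", sep="."):
    """Flatten nested mappings into dotted keys, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, collections.abc.MutableMapping):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


assert flatten({"a": {"b": 1}, "c": 2}) == {"a.b": 1, "c": 2}
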
15 binary files not shown.
