Skip to content

Commit c003529

Browse files
committed
1 parent 9007f36 commit c003529

54 files changed

Lines changed: 9619 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

lm_eval/__init__.py

Whitespace-only changes.

lm_eval/base.py

Lines changed: 891 additions & 0 deletions
Large diffs are not rendered by default.

lm_eval/evaluator.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
import collections
2+
import itertools
3+
import numpy as np
4+
import random
5+
import lm_eval.metrics
6+
import lm_eval.models
7+
import lm_eval.tasks
8+
import lm_eval.base
9+
from lm_eval.utils import positional_deprecated
10+
11+
12+
@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=None,
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
        Must not be empty; omitting it triggers the "No tasks specified" assertion.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to cache. Caching is only applied when `model` is given
        as a string name, because the cache file name is derived from the model
        name and `model_args`; for LM objects a warning is printed and the
        model is used uncached.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Currently not implemented: a truthy value raises NotImplementedError.
    :param decontamination_ngrams_path:
        Forwarded unchanged to `evaluate`.
    :return
        Dictionary of results as returned by `evaluate`, with an additional
        "config" entry recording the arguments of this call.
    """
    # Fixed seeds so repeated runs draw the same random numbers.
    random.seed(1234)
    np.random.seed(1234)

    # `None` replaces the former mutable `[]` default; normalise it so the
    # emptiness check below behaves exactly as before.
    if tasks is None:
        tasks = []
    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        if isinstance(model, str):
            # Make the argument string safe to use as part of a file name.
            safe_args = (
                model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            )
            lm = lm_eval.base.CachingLM(
                lm, "lm_cache/" + model + "_" + safe_args + ".db"
            )
        else:
            # An LM object has no name/argument string to key the cache file
            # on (string concatenation with it used to raise TypeError), and
            # guessing a key could make different models share one cache.
            print(
                "WARNING: caching is only supported when `model` is given by name; "
                "running without cache"
            )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        raise NotImplementedError

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
    )

    # add info about the model and few shot config
    # (`model` is recorded exactly as passed in, i.e. a name or an LM object)
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results
110+
111+
112+
# Suffix appended to a metric name by `evaluate` to key the values collected
# only from docs that are not listed as overlapping (contaminated) for a task.
decontaminate_suffix = "_decontaminate"
113+
114+
115+
@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
):
    """Evaluate an already instantiated model on a dictionary of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
        Tasks that have neither validation nor test docs are skipped.
    :param provide_description: bool
        Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param decontamination_ngrams_path:
        Decontamination is requested by passing any non-None value. It is not
        implemented: NotImplementedError is raised once all requests have been
        constructed.
    :return
        Dictionary with two entries: "results" maps task name to a dict of
        aggregated metrics (plus `<metric>_stderr` where available) and
        "versions" maps task name to the task's VERSION.
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    # Both keyed by request type; entry k of `requests_origin[t]` describes
    # where entry k of `requests[t]` came from.
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # {task_name: contaminated_docs}; never filled in here because the
    # decontamination step below is not implemented.
    overlaps = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
    #  - probably using an sqlite db because of all the moving parts we have

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test doc, fall back to val doc if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):

            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            # The same `rnd` that shuffled the docs is handed to the task, so
            # the order of calls here determines the few-shot context.
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        raise NotImplementedError

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. we could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        # The request type doubles as the name of the LM method to call.
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), indexed_resps in process_res_queue.items():
        indexed_resps.sort(key=lambda x: x[0])
        ordered_resps = [x[1] for x in indexed_resps]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, ordered_resps)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            # decontaminated still uses the same metric; strip only the
            # trailing suffix, not every occurrence inside the name
            real_metric = metric[: -len(decontaminate_suffix)]
        aggregate = task.aggregation()[real_metric]
        results[task_name][metric] = aggregate(items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this
        # (checked against `real_metric` so suffixed variants get the same cap)
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=aggregate,
            bootstrap_iters=min(bootstrap_iters, 1000)
            if real_metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}
293+
294+
295+
def make_table(result_dict):
    """Render evaluation results as a Markdown table and return it as a string.

    `result_dict` is read through its "results" entry (task name -> dict of
    metric name -> value) and its "versions" entry (task name -> version).
    Each metric that does not end in "_stderr" becomes one row; when a matching
    `<metric>_stderr` entry exists it is shown in the same row after a "±".
    The task name and version are printed only on the first row of each task.
    """
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    column_names = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    # Each writer gets its own copy of the header list.
    md_writer.headers = list(column_names)
    latex_writer.headers = list(column_names)

    rows = []

    for task_name, task_metrics in result_dict["results"].items():
        task_label = task_name
        version_label = result_dict["versions"][task_name]

        for metric_name, metric_value in task_metrics.items():
            if metric_name.endswith("_stderr"):
                # stderr entries are folded into the row of their metric
                continue

            value_text = "%.4f" % metric_value
            stderr_key = metric_name + "_stderr"
            if stderr_key in task_metrics:
                tail = ["±", "%.4f" % task_metrics[stderr_key]]
            else:
                tail = ["", ""]
            rows.append([task_label, version_label, metric_name, value_text] + tail)

            # blank out the labels for the remaining rows of this task
            task_label = ""
            version_label = ""

    md_writer.value_matrix = rows
    latex_writer.value_matrix = rows

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()

0 commit comments

Comments
 (0)