# text2code.py — generate code completions for HumanEval+/MBPP+ with a
# Magicoder-style model and write them as JSONL samples for evalplus.
# Forked from ise-uiuc/magicoder.
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, TypedDict, cast
from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl
from tqdm.auto import tqdm
from transformers import HfArgumentParser
from experiments.utils import wget
from magicoder.llm_wrapper import GenerationConfig, get_model_context
from magicoder.prompt_template import MAGICODER_PROMPT
from magicoder.utils import chunked, read_jsonl
class Text2CodeProblem(TypedDict):
    """A single text-to-code benchmark problem prepared for prompting."""

    # Unique task identifier (HumanEval task_id, or MBPP task_id as a string).
    id: str
    # Natural-language instruction presented to the model.
    instruction: str
    # Text the model's response is forced to begin with (opens a code block).
    response_prefix: str
# Prompt template for MBPP-style problems.
# NOTE(review): not referenced anywhere in this file — `map_mbpp_problem`
# builds its instruction inline instead; confirm before removing.
MBPP_INSTRUCTION = """{nl_description} Your code should satisfy the following assertion:
```python
{assertions}
```
Enclose your solution in ```python and ```"""
def get_mbpp_raw_problems() -> list[dict]:
    """Return the MBPP+ benchmark problems as a list of raw problem dicts."""
    return list(get_mbpp_plus().values())
def get_humaneval_raw_problems() -> list[dict]:
    """Return the HumanEval+ benchmark problems as a list of raw problem dicts."""
    return list(get_human_eval_plus().values())
def map_mbpp_problem(p: dict) -> Text2CodeProblem:
    """Convert a raw MBPP+ problem dict into a `Text2CodeProblem`.

    The raw MBPP prompt embeds the task description and its assertion inside
    a triple-quoted docstring; this extracts both and rebuilds an explicit
    instruction that quotes the assertion in a code block.

    Raises:
        ValueError: if the prompt has no docstring delimiters or no `assert`.
    """
    # Renamed from `id` to avoid shadowing the builtin.
    task_id = p["task_id"]
    prompt = p["prompt"]
    # The description + assertion live between the first and last docstring
    # delimiters of the raw prompt.
    start_index = prompt.index('"""')
    end_index = prompt.rindex('"""')
    prompt = prompt[start_index + 3 : end_index]
    # Everything before the first `assert` is the NL description; the rest is
    # the assertion(s) the generated solution must satisfy.
    assert_index = prompt.index("assert")
    instruction = prompt[:assert_index].strip()
    if not instruction.endswith("."):
        instruction += "."
    assertion = prompt[assert_index:].strip()
    instruction = f"""{instruction} Your code should satisfy the following assertion:
```python
{assertion}
```"""
    # Plain string — the original f-string had no placeholders.
    response_prefix = "```python"
    return Text2CodeProblem(
        id=str(task_id), instruction=instruction, response_prefix=response_prefix
    )
def map_humaneval_problem(p: dict) -> Text2CodeProblem:
    """Convert a raw HumanEval+ problem dict into a `Text2CodeProblem`.

    The full prompt (signature + docstring) is quoted verbatim in the
    instruction, and the response is seeded with the same prompt inside an
    open code block so the model only has to complete the function body.
    """
    # Renamed from `id` to avoid shadowing the builtin.
    task_id = p["task_id"]
    prompt = p["prompt"].strip()
    instruction = f"""Write a solution to the following problem:
```python
{prompt}
```"""
    response_prefix = f"""```python
{prompt}"""
    return Text2CodeProblem(
        id=task_id, instruction=instruction, response_prefix=response_prefix
    )
@dataclass(frozen=True)
class Args:
    """Command-line arguments, parsed via `HfArgumentParser`."""

    # Key identifying the model; passed to `get_model_context`.
    model_key: str
    # Which benchmark to evaluate.
    dataset: Literal["humaneval", "mbpp"]
    # Output path for the generated samples (JSON Lines, truncated on start).
    save_path: str
    # Number of sampling passes over each problem chunk.
    n_batches: int
    # Problems per chunk (upper bound asserted in the main loop).
    n_problems_per_batch: int
    # Completions generated per problem in each batch.
    n_samples_per_problem: int
    # prompted: bool
    # Optional local/remote weights path; passed to `get_model_context`.
    model_name_or_path: str | None = None
def main():
    """Sample completions for every benchmark problem and write them as JSONL.

    For each chunk of problems, each of `n_batches` passes generates
    `n_samples_per_problem` completions per problem, trims everything from
    the closing ``` fence onward, and appends the samples to `save_path`.
    """
    parser = HfArgumentParser((Args, GenerationConfig))
    args, generation_config = cast(
        tuple[Args, GenerationConfig],
        parser.parse_args_into_dataclasses(),
    )
    raw_problem_fn, map_problem_fn = (
        (get_humaneval_raw_problems, map_humaneval_problem)
        if args.dataset == "humaneval"
        else (get_mbpp_raw_problems, map_mbpp_problem)
    )
    problems = list(map(map_problem_fn, raw_problem_fn()))

    state = get_model_context(args.model_key, args.model_name_or_path)

    # `problems` is already a list; no extra copy needed before chunking.
    problems_chunked = list(chunked(problems, args.n_problems_per_batch))
    # Renamed from `iter` to avoid shadowing the builtin. Every (chunk, batch)
    # pair is visited; `n_batches` controls how often each chunk is re-sampled.
    batch_iter = itertools.product(problems_chunked, range(args.n_batches))
    n_total = len(problems_chunked) * args.n_batches

    # Truncate any previous output; samples are appended batch by batch below.
    Path(args.save_path).write_text("")
    # Renamed loop variable (was `problems`) so it no longer shadows the full
    # problem list above.
    for problem_chunk, batch_idx in tqdm(batch_iter, total=n_total):
        task_ids = [problem["id"] for problem in problem_chunk]
        prompts = [
            # TODO: make it generic for all models
            MAGICODER_PROMPT.format(
                instruction=problem["instruction"], response=problem["response_prefix"]
            )
            for problem in problem_chunk
        ]
        print("PROMPT")
        print(prompts[-1])
        # Duplicate each prompt/task id so a single generation call yields
        # `n_samples_per_problem` completions per problem, kept in lockstep.
        all_prompts = prompts * args.n_samples_per_problem
        all_task_ids = task_ids * args.n_samples_per_problem
        response = state.complete(generation_config, all_prompts)
        completions = response.decoded_outputs
        assert len(problem_chunk) <= args.n_problems_per_batch
        assert len(completions) == len(problem_chunk) * args.n_samples_per_problem
        print("COMPLETION")
        print(completions[-1])
        samples = [
            dict(
                task_id=task_id,
                # Keep only the code before the closing ``` fence, if present.
                completion=completion[
                    : index
                    if (index := completion.find("```")) != -1
                    else len(completion)
                ],
            )
            for task_id, completion in zip(all_task_ids, completions)
        ]
        write_jsonl(args.save_path, samples, append=True)


if __name__ == "__main__":
    main()