
Commit fc11482

add support for fine-tuning
1 parent 44995b8 commit fc11482

File tree

5 files changed: 305 additions, 1 deletion


README.md

Lines changed: 25 additions & 0 deletions
@@ -14,6 +14,7 @@ Table of Contents
- [Model Weights](#model-weights)
- [Inference Example](#inference-example)
- [Quantized through bitsandbytes](#quantized-through-bitsandbytes)
- [Fine-tuning example](#fine-tuning-example)
3. [Data for aiXcoder 7B](#data-for-aixcoder-7b)
4. [Training](#training)
- [Training Hyperparameters](#training-hyperparameters)
@@ -298,6 +299,30 @@ load_in_8bit=True:

### Fine-tuning example

If you want to fine-tune the model on your own code, you can get started quickly with Hugging Face's PEFT tools. Before doing so, install the required libraries with `pip install -r requirements_peft.txt`.

Then execute the training command:

```bash
accelerate launch finetune.py \
    --model_id "aiXcoder/aixcoder-7b-base" \
    --dataset_name "bigcode/the-stack-smol" \
    --subset "data/rust" \
    --dataset_text_field "content" \
    --split "train" \
    --max_seq_length 1024 \
    --max_steps 10000 \
    --micro_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --learning_rate 5e-6 \
    --warmup_steps 20 \
    --fim_rate 0.5 \
    --num_proc "$(nproc)"
```

In the fine-tuning script, we construct a simple random FIM (fill-in-the-middle) training task that trains the model's completion and generation capabilities on your own data. Note that aiXcoder-7b-base uses [structured FIM](#pre-training-tasks) during pre-training, which constructs a complete code block as the MIDDLE; building such training data requires syntactic parsing, which developers may need to implement themselves.
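
To make the sample format concrete, here is a minimal sketch of the random FIM construction used by the fine-tuning script: two random split points divide a source file into prefix, middle, and suffix, which are then joined with aiXcoder's span tokens in either SPM or PSM order. The helper name `make_random_fim_sample` and the toy input are illustrative assumptions; the span tokens and split logic mirror `finetune.py` from this commit.

```python
# Illustrative sketch only; mirrors the random FIM construction in finetune.py.
import numpy as np

rng = np.random.RandomState(seed=3574)

def make_random_fim_sample(content: str, spm_rate: float = 0.5) -> str:
    # Pick two random boundaries and split the file into prefix / middle / suffix.
    lo, hi = sorted(rng.randint(low=0, high=len(content) + 1, size=2))
    prefix, middle, suffix = content[:lo], content[lo:hi], content[hi:]
    if rng.binomial(1, spm_rate):
        # SPM order: the suffix is given first; prefix and middle follow the MIDDLE token.
        return f"<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{prefix}{middle}</s>"
    # PSM order: prefix, then suffix, then the middle to be predicted.
    return f"<s>▁<AIX-SPAN-PRE>{prefix}▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{middle}</s>"

print(make_random_fim_sample("def add(a, b):\n    return a + b\n"))
```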

## Data for aiXcoder 7B

README_CN.md

Lines changed: 26 additions & 0 deletions
@@ -14,6 +14,7 @@
- [Model Weights](#模型权重)
- [Inference Example](#推理示例)
- [Quantized through bitsandbytes](#bitsandbytes-量化执行)
- [Fine-tuning example](#微调示例)
3. [Data for aiXcoder 7B](#aixcoder-7b-训练数据)
4. [Training](#训练)
- [Training Hyperparameters](#训练超参数)
@@ -291,6 +292,31 @@ load_in_8bit=True:

### Fine-tuning example

If you want to fine-tune the model on your own code, you can get started quickly with Hugging Face's PEFT tools. Before doing so, install the required dependencies with `pip install -r requirements_peft.txt`.

Then execute the training command:

```bash
accelerate launch finetune.py \
    --model_id "aiXcoder/aixcoder-7b-base" \
    --dataset_name "bigcode/the-stack-smol" \
    --subset "data/rust" \
    --dataset_text_field "content" \
    --split "train" \
    --max_seq_length 1024 \
    --max_steps 10000 \
    --micro_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --learning_rate 5e-6 \
    --warmup_steps 20 \
    --fim_rate 0.5 \
    --num_proc "$(nproc)"
```

In the fine-tuning script, we construct a simple random FIM (fill-in-the-middle) training task that trains the model's completion and generation capabilities on your own data. Note that aiXcoder-7b-base uses [structured FIM](#预训练任务) during pre-training, which constructs a complete code block as the MIDDLE; building such training data requires syntactic parsing, which developers may need to implement themselves.

## Data for aiXcoder 7B

aiXcoder's data is divided into a core dataset and an extended dataset. The core dataset consists of the programming languages most commonly used in our business scenarios, together with natural language closely related to code. Its programming languages cover nearly a hundred mainstream languages, chiefly C++, Python, Java, and JavaScript, while the natural-language portion mainly comprises StackOverflow Q&A, technical blogs, code documentation, and computer-science papers. The extended dataset mainly consists of filtered open-source code datasets, high-quality English natural-language datasets, and high-quality Chinese natural-language datasets.

finetune.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@

```python
# Code adapted from https://github.com/bigcode-project/starcoder2/blob/main/finetune.py
import argparse
import multiprocessing
import os
import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    logging,
    set_seed,
)
import numpy as np
import random
import warnings
import sys
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="aiXcoder/aixcoder-7b-base")
    parser.add_argument("--dataset_name", type=str, default="the-stack-smol")
    parser.add_argument("--subset", type=str, default="data/rust")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--fim_rate", type=float, default=0.5)
    parser.add_argument("--dataset_text_field", type=str, default="content")

    parser.add_argument("--max_seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=100)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--bf16", type=bool, default=True)

    parser.add_argument("--attention_dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=2e-6)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="finetune_aix_7b")
    parser.add_argument("--num_proc", type=int, default=None)
    parser.add_argument("--push_to_hub", type=bool, default=False)
    return parser.parse_args()


def print_rank_0(message):
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True, file=sys.stderr)
    else:
        print(message, flush=True)

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print_rank_0(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


class RandomFIMDataset(ConstantLengthDataset):
    """
    This class supports the random fill-in-the-middle (FIM) task. If `fim_rate` is greater than 0,
    it constructs data in the fill-in-the-middle format with a probability of `fim_rate`.
    The aiXcoder-7b-base model uses structured FIM during pre-training,
    where a complete code block is constructed as the MIDDLE.
    However, creating such training data involves syntactic parsing,
    and we currently do not plan to open source the processing code.

    """
    def __init__(self, tokenizer, dataset, dataset_text_field=None, fim_rate=0, formatting_func=None, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6, eos_token_id=0, shuffle=True, append_concat_token=True, add_special_tokens=True):
        self.fim_rate = fim_rate
        self.fim_spm_rate = 0.5
        self.np_rand = np.random.RandomState(seed=3574)
        if self.fim_rate > 0:
            print_rank_0(f"constructing data with FIM: fim_rate: {self.fim_rate}")
        super().__init__(tokenizer, dataset, dataset_text_field, formatting_func, infinite, seq_length, num_of_sequences, chars_per_token, eos_token_id, shuffle, append_concat_token, add_special_tokens)

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    if self.fim_rate > 0:
                        if self.np_rand.binomial(1, self.fim_rate):  # sample bernoulli dist

                            contents = self.formatting_func(next(iterator))

                            try:
                                boundaries = list(self.np_rand.randint(low=0, high=len(contents) + 1, size=2))
                                boundaries.sort()
                            except ValueError as e:
                                print(len(contents), contents)
                                print(e)
                                raise e

                            prefix = contents[:boundaries[0]]
                            middle = contents[boundaries[0]:boundaries[1]]
                            suffix = contents[boundaries[1]:]
                            if self.np_rand.binomial(1, self.fim_spm_rate):
                                # SPM order: the suffix is given first; prefix and middle follow the MIDDLE token
                                contents = f"<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{prefix}{middle}</s>"
                            else:
                                # PSM order: prefix, then suffix, then the middle to be predicted
                                contents = f"<s>▁<AIX-SPAN-PRE>{prefix}▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{middle}</s>"
                        else:
                            contents = f"<s>{self.formatting_func(next(iterator))}</s>"
                    else:
                        contents = f"<s>{self.formatting_func(next(iterator))}</s>"

                    buffer.append(contents)
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
                "input_ids"
            ]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input)
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }


def main(args):
    # config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    lora_config = LoraConfig(
        r=8,
        target_modules=[
            "q_proj",
            "o_proj",
            "k_proj",
            "v_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )

    # load model and dataset
    token = os.environ.get("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map={"": PartialState().process_index},
        attention_dropout=args.attention_dropout,
        attn_implementation='flash_attention_2'
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    print_trainable_parameters(model)

    data = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        token=token,
        num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
    )

    train_data = RandomFIMDataset(
        tokenizer=tokenizer, dataset=data, fim_rate=args.fim_rate, dataset_text_field=args.dataset_text_field,
        infinite=True, seq_length=args.max_seq_length, eos_token_id=tokenizer.eos_token_id
    )

    # setup the trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_data,
        max_seq_length=args.max_seq_length,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            logging_strategy="steps",
            logging_steps=10,
            output_dir=args.output_dir,
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=f"train-{args.model_id.split('/')[-1]}",
            report_to="none",
        ),
        peft_config=lora_config,
        dataset_text_field=args.dataset_text_field,
    )

    # launch
    print_rank_0("Training...")
    trainer.train()

    print_rank_0("Saving the last checkpoint of the model")
    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
    if args.push_to_hub:
        trainer.push_to_hub("Upload model")
    print_rank_0("Training Done! ")


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)
```
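
The script above saves the fine-tuned weights to `final_checkpoint/` under the output directory. If that checkpoint contains a LoRA adapter (the usual outcome when training through `SFTTrainer` with a `peft_config`), it can be loaded back onto the base model roughly as sketched below, assuming the `peft` and `transformers` versions pinned in `requirements_peft.txt`; the prompt and generation settings are placeholder assumptions, not part of the commit.

```python
# Hedged sketch (not part of the commit): load a saved LoRA adapter for inference.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained(
    "aiXcoder/aixcoder-7b-base", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "finetune_aix_7b/final_checkpoint")
model = model.merge_and_unload()  # optional: fold the adapter into the base weights

tokenizer = AutoTokenizer.from_pretrained("aiXcoder/aixcoder-7b-base")
inputs = tokenizer("# write a quick sort function\n", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0], skip_special_tokens=True))
```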

megatron_mini/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1259,7 +1259,7 @@ def encode(self, code_string: str, later_code: str, file_path: str) -> List[int]
             t = [self.bos_id] + t
         else:
             t = [self.bos_id, self.prefix_tok_id, self.suffix_tok_id] + self.__encode(later_code, None, True)
-            t = [self.middle_tok_id] + self.__encode(code_string, file_path, False)
+            t += [self.middle_tok_id] + self.__encode(code_string, file_path, False)

         return t
```

requirements_peft.txt

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@

```
accelerate==0.27.1
datasets>=2.16.1
bitsandbytes==0.41.3
peft==0.8.2
trl==0.7.10
wandb==0.16.3
huggingface_hub==0.20.3
```
