# Code adapted from https://github.com/bigcode-project/starcoder2/blob/main/finetune.py
import argparse
import multiprocessing
import os
import random
import sys
import warnings

import numpy as np
import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    logging,
    set_seed,
)
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="aiXcoder/aixcoder-7b-base")
    parser.add_argument("--dataset_name", type=str, default="the-stack-smol")
    parser.add_argument("--subset", type=str, default="data/rust")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--fim_rate", type=float, default=0.5)
    parser.add_argument("--dataset_text_field", type=str, default="content")

    parser.add_argument("--max_seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=100)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--bf16", type=bool, default=True)

    parser.add_argument("--attention_dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=2e-6)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="finetune_aix_7b")
    parser.add_argument("--num_proc", type=int, default=None)
    parser.add_argument("--push_to_hub", type=bool, default=False)
    return parser.parse_args()


def print_rank_0(message):
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True, file=sys.stderr)
    else:
        print(message, flush=True)


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print_rank_0(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


class RandomFIMDataset(ConstantLengthDataset):
    """
    This class supports the random fill-in-the-middle (FIM) task. If `fim_rate` is greater than 0,
    it constructs data in the fill-in-the-middle format with a probability of `fim_rate`.
    The aiXcoder-7b-base model uses structured FIM during pre-training,
    where a complete code block is constructed as the MIDDLE.
    However, creating such training data involves syntactic parsing,
    and we currently do not plan to open-source the processing code.
    """
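    # For reference, the two prompt layouts emitted by `__iter__` below, written with
    # placeholder {prefix}/{middle}/{suffix} pieces; the split points are chosen at
    # random, and the special tokens are copied verbatim from the templates in `__iter__`:
    #   PSM-style: <s>▁<AIX-SPAN-PRE>{prefix}▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{middle}</s>
    #   SPM-style: <s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{prefix}{middle}</s>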
    def __init__(self, tokenizer, dataset, dataset_text_field=None, fim_rate=0, formatting_func=None,
                 infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6,
                 eos_token_id=0, shuffle=True, append_concat_token=True, add_special_tokens=True):
        self.fim_rate = fim_rate
        self.fim_spm_rate = 0.5
        self.np_rand = np.random.RandomState(seed=3574)
        if self.fim_rate > 0:
            print_rank_0(f"constructing data with FIM: fim_rate: {self.fim_rate}")
        super().__init__(tokenizer, dataset, dataset_text_field, formatting_func, infinite, seq_length,
                         num_of_sequences, chars_per_token, eos_token_id, shuffle, append_concat_token,
                         add_special_tokens)

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    if self.fim_rate > 0:
                        if self.np_rand.binomial(1, self.fim_rate):  # sample from a Bernoulli distribution
                            contents = self.formatting_func(next(iterator))

                            try:
                                boundaries = list(self.np_rand.randint(low=0, high=len(contents) + 1, size=2))
                                boundaries.sort()
                            except ValueError as e:
                                print(len(contents), contents)
                                print(e)
                                raise e

                            prefix = contents[: boundaries[0]]
                            middle = contents[boundaries[0] : boundaries[1]]
                            suffix = contents[boundaries[1] :]
                            if self.np_rand.binomial(1, self.fim_spm_rate):
                                contents = f"<s>▁<AIX-SPAN-PRE>▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{prefix}{middle}</s>"
                            else:
                                contents = f"<s>▁<AIX-SPAN-PRE>{prefix}▁<AIX-SPAN-POST>{suffix}▁<AIX-SPAN-MIDDLE>{middle}</s>"
                        else:
                            contents = f"<s>{self.formatting_func(next(iterator))}</s>"
                    else:
                        contents = f"<s>{self.formatting_func(next(iterator))}</s>"

                    buffer.append(contents)
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        warnings.warn("The dataset reached end and the iterator is reset to the start.")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)[
                "input_ids"
            ]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input)
            examples = []
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    examples.append(input_ids)
            if self.shuffle:
                random.shuffle(examples)
            for example in examples:
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(example),
                    "labels": torch.LongTensor(example),
                }

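# A minimal usage sketch for RandomFIMDataset (illustrative only, not executed by this
# script): the toy records, field name, and sequence length below are assumptions made
# for the example, and `tokenizer` is any already-loaded tokenizer.
#
#   toy = [{"content": "fn main() { println!(\"hello\"); }"}] * 64
#   ds = RandomFIMDataset(tokenizer, toy, dataset_text_field="content",
#                         fim_rate=0.5, seq_length=64)
#   sample = next(iter(ds))  # {"input_ids": LongTensor[64], "labels": LongTensor[64]}
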
def main(args):
    # config: 4-bit NF4 quantization for the frozen base model plus LoRA adapters
    # (a QLoRA-style setup), with bf16 compute for the quantized layers
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    lora_config = LoraConfig(
        r=8,
        target_modules=[
            "q_proj",
            "o_proj",
            "k_proj",
            "v_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )
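    # Note: the target_modules above are LLaMA-style attention/MLP projection names;
    # if the model passed via --model_id uses different module names, adjust this list
    # (for example, print(model) after loading to inspect the layer names).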

    # load model and dataset
    token = os.environ.get("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map={"": PartialState().process_index},
        attention_dropout=args.attention_dropout,
        attn_implementation="flash_attention_2",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    print_trainable_parameters(model)

    data = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        token=token,
        num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
    )

    train_data = RandomFIMDataset(
        tokenizer=tokenizer, dataset=data, fim_rate=args.fim_rate, dataset_text_field=args.dataset_text_field,
        infinite=True, seq_length=args.max_seq_length, eos_token_id=tokenizer.eos_token_id
    )

    # setup the trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_data,
        max_seq_length=args.max_seq_length,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            logging_strategy="steps",
            logging_steps=10,
            output_dir=args.output_dir,
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=f"train-{args.model_id.split('/')[-1]}",
            report_to="none",
        ),
        peft_config=lora_config,
        dataset_text_field=args.dataset_text_field,
    )

    # launch
    print_rank_0("Training...")
    trainer.train()

    print_rank_0("Saving the last checkpoint of the model")
    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
    if args.push_to_hub:
        trainer.push_to_hub("Upload model")
    print_rank_0("Training Done!")


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)
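
# Example invocation (a sketch: the filename `finetune.py` is assumed, flag values mirror
# the argparse defaults above, and multi-GPU runs assume `accelerate config` has been run;
# a plain `python finetune.py ...` also works on a single GPU):
#
#   accelerate launch finetune.py \
#       --model_id aiXcoder/aixcoder-7b-base \
#       --subset data/rust --max_steps 100 --micro_batch_size 1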