train.py
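
# BitDelta training script: compress the weight delta between a base model and
# its fine-tune down to 1 bit per parameter, then calibrate the remaining
# per-matrix scale factors by distilling the fine-tuned model's logits.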
import os
import torch
import torch.nn.functional as F
from bitdelta.diff import compress_diff, save_diff, save_full_model
from bitdelta.misc import find_corr_stddev
from bitdelta.utils import get_model, parse_args, get_tokenizer
from tqdm import tqdm
from bitdelta.data import get_dataset, get_dataloader
import json
args = parse_args()
# create save_dir if it doesn't exist
os.makedirs(args.save_dir, exist_ok=True)
tokenizer = get_tokenizer(args.base_model)
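
# Load the reference copies of the base and fine-tuned models; no gradients are
# ever needed for these. The *_memory_map arguments appear to control how the
# weights are placed across devices.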
with torch.no_grad():
    base_model = get_model(args.base_model, args.base_model_device, args.base_model_memory_map)
    finetuned_model = get_model(args.finetuned_model, args.finetuned_model_device, args.finetuned_model_memory_map)
# get corr/stddev stats
if args.debug:
    print("finding corr/stddev stats...")
    corrs, stddevs = find_corr_stddev(base_model, finetuned_model)
    corr = sum(corrs) / len(corrs)
    stddev = sum(stddevs) / len(stddevs)
    # save in args.save_dir as csv
    with open(os.path.join(args.save_dir, "corr_stddev.csv"), "w") as f:
        f.write(f"corr,stddev\n{corr},{stddev}")
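
# A third copy of the fine-tuned model is loaded to be compressed in place:
# compress_diff replaces each weight delta with a 1-bit sign matrix scaled by
# a trainable per-matrix factor.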
finetuned_compressed_model = get_model(args.finetuned_model, args.finetuned_compressed_model_device, args.finetuned_compressed_model_memory_map)
print(f"compressing diff...")
compress_diff(base_model, finetuned_model, finetuned_compressed_model)
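
# Calibration data: sized so the loader yields exactly one batch per optimizer step.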
train_num_samples = args.batch_size * args.num_steps
train_dataset = get_dataset(
    args.dataset_name,
    args.subset,
    "train",
    size=train_num_samples,
)
train_dataloader = get_dataloader(
    train_dataset,
    tokenizer,
    args.batch_size,
    num_workers=4,
    max_length=args.max_length,
)
# save untrained delta
save_diff(finetuned_compressed_model, os.path.join(args.save_dir, "diff_untrained.pt"))
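
# The optimizer updates the compressed model; per BitDelta, the quantity being
# calibrated here is the per-matrix scale factor of each 1-bit delta.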
optimizer = torch.optim.AdamW(finetuned_compressed_model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_steps)
bar = tqdm(train_dataloader)
train_loss_list = []
# Train loop
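# Each step distills the fine-tuned model's logits into the compressed model:
# both models see the same batch, and the MSE between their logits is
# minimized (no labels are needed).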
for step, batch in enumerate(bar):
    batch1 = {k: v.to(finetuned_model.device) for k, v in batch.items()}
    with torch.inference_mode():
        finetuned_outputs = finetuned_model(**batch1)

    batch2 = {k: v.to(finetuned_compressed_model.device) for k, v in batch.items()}
    finetuned_compressed_outputs = finetuned_compressed_model(**batch2)

    loss = F.mse_loss(
        finetuned_outputs.logits.clone().to(finetuned_compressed_outputs.logits.device),
        finetuned_compressed_outputs.logits,
    )

    train_loss_list.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    bar.set_description(f"train loss: {loss.item()}")
# save loss list
if args.debug:
    with open(os.path.join(args.save_dir, f"train_loss_{args.num_groups}.json"), "w") as f:
        json.dump(train_loss_list, f)
# save trained delta
save_diff(finetuned_compressed_model, os.path.join(args.save_dir, "diff.pt"))
del base_model, finetuned_model, finetuned_compressed_model
torch.cuda.empty_cache()
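
# Optionally materialize full checkpoints by applying the saved diffs to the
# base weights on CPU, both before ("uncalibrated") and after ("calibrated")
# scale training.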
if args.save_full_model:
    print("saving uncalibrated model")
    save_full_model(args.base_model, args.finetuned_model, os.path.join(args.save_dir, "diff_untrained.pt"), os.path.join(args.save_dir, "uncalibrated_model"), device="cpu")
    print("saving calibrated model")
    save_full_model(args.base_model, args.finetuned_model, os.path.join(args.save_dir, "diff.pt"), os.path.join(args.save_dir, "calibrated_model"), device="cpu")
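
# Example invocation (hypothetical: flag names are inferred from the args.*
# accesses above and the values are purely illustrative; see parse_args for
# the real definitions):
#
#   python train.py \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --finetuned_model lmsys/vicuna-7b-v1.5 \
#       --save_dir ./checkpoints \
#       --dataset_name c4 --subset en \
#       --batch_size 4 --num_steps 200 --max_length 128 --lr 1e-4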