-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdiff.py
More file actions
116 lines (91 loc) · 4.02 KB
/
diff.py
File metadata and controls
116 lines (91 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import gc
from bitdelta.binary_gemm_kernel import pack, unpack, binary_bmm
from bitdelta.utils import get_model, get_tokenizer
class BinaryDiff(nn.Module):
    """1-bit compression of a weight delta.

    Stores W_ft - W_base as a packed sign matrix plus one trainable scalar,
    so the finetuned layer is approximated as:
        x @ W_base.T + coeff * (x @ sign(W_ft - W_base).T)

    Buffers/parameters:
        mask:  packed (via `pack`) boolean sign matrix of the delta, transposed.
        base:  the base weight, transposed for the x @ base matmul in forward.
        coeff: trainable fp32 scalar scale, initialized to mean(|delta|).
    """

    def __init__(self, base, finetune):
        super().__init__()
        diff = finetune - base
        # L1-optimal scale for a +/-1 sign matrix is the mean absolute delta.
        scale = diff.float().abs().mean()
        # 1 where the delta is non-negative, 0 where negative.
        sign = torch.ones_like(diff)
        sign[diff < 0] = 0
        self.register_buffer("mask", pack(sign.bool().T))
        self.register_buffer("base", base.T)
        # NOTE: torch.tensor(existing_tensor) copies and warns; the documented
        # equivalent is detach().clone(). nn.Parameter sets requires_grad=True.
        self.register_parameter(
            "coeff",
            nn.Parameter(
                scale.detach().clone().to(device=base.device, dtype=torch.float32)
            ),
        )
        # Drop local references so the uncompressed tensors can be freed.
        del base, finetune, diff

    def forward(self, x):
        # print(x.shape, self.base.shape, self.coeff.shape, self.mask.shape)
        # [B, seq, in] @ [in, out] + [B, seq, in] @ [B, in/32, out]
        # TODO: This can be faster
        repeated_mask = self.mask.unsqueeze(0).repeat(x.size(0), 1, 1)
        return x @ self.base + self.coeff * binary_bmm(x, repeated_mask)
def compress_diff(base_model, finetuned_model, finetuned_compressed_model):
    """Swap every '*proj' linear inside mlp/self_attn blocks of
    `finetuned_compressed_model` for a BinaryDiff built from the matching
    weights of `base_model` and `finetuned_model`. Mutates the model in place.
    """

    def _install_compressed(parent_name, child_name, parent, child):
        # Build the replacement on the same device as the weight it replaces.
        device = child.weight.device
        w_base = (
            base_model.get_submodule(f"{parent_name}.{child_name}")
            .weight.detach()
            .to(device)
        )
        w_finetuned = (
            finetuned_model.get_submodule(f"{parent_name}.{child_name}")
            .weight.detach()
            .to(device)
        )
        replacement = BinaryDiff(base=w_base, finetune=w_finetuned).to(device)

        # Detach the old module and release references before installing the
        # replacement, so the original full-precision weights can be freed.
        del child, w_base
        setattr(parent, child_name, None)
        gc.collect()
        torch.cuda.empty_cache()
        setattr(parent, child_name, replacement)

    # TODO: this can be parallelized
    for parent_name, parent in finetuned_compressed_model.named_modules():
        if "mlp" in parent_name or "self_attn" in parent_name:
            for child_name, child in parent.named_children():
                if "proj" in child_name:
                    _install_compressed(parent_name, child_name, parent, child)
def save_diff(finetuned_compressed_model, save_dir):
    """Serialize the BitDelta diff to `save_dir` via torch.save.

    Collects, on CPU: each BinaryDiff module's packed mask and coeff, plus
    every parameter of the model that still has requires_grad=True.
    """
    state = {}
    for module_name, module in finetuned_compressed_model.named_modules():
        if isinstance(module, BinaryDiff):
            # diff_dict[name + ".mask"] = (module.mask == 1).bool().cpu()
            state[module_name + ".mask"] = module.mask.cpu()
            state[module_name + ".coeff"] = module.coeff.cpu()

    for param_name, param in finetuned_compressed_model.named_parameters():
        if not param.requires_grad:
            continue
        state[param_name] = param.cpu()

    torch.save(state, save_dir)
@torch.no_grad()
def load_diff(model, diff_dir):
    """Apply a saved diff (produced by save_diff) to `model` in place.

    Three entry kinds are recognized per module name:
      - "<name>.mask" / "<name>.coeff": unpack 1-bit signs ({0,1} -> {-1,+1}),
        scale by coeff, and add to the module's weight (BitDelta diff).
      - "<name>.weight": replace the module's weight outright.
      - "<name>.A" / "<name>.B": add the low-rank update (A @ B).T (LoRA-style).
    """
    device = model.device
    # map_location avoids deserializing to CPU and then copying every tensor
    # over; entries land directly on the target device.
    diff_dict = torch.load(diff_dir, map_location=device)
    for name, module in model.named_modules():
        if name + ".mask" in diff_dict:
            coeff = diff_dict[name + ".coeff"].to(device)
            mask = diff_dict[name + ".mask"].to(device)
            # unpack -> {0,1}; *2-1 -> {-1,+1} signs; scale by the scalar coeff.
            weight = (unpack(mask) * 2 - 1) * coeff
            module.weight.add_(weight.T.to(module.weight.dtype))
        elif name + ".weight" in diff_dict:
            module.weight = nn.Parameter(
                diff_dict[name + ".weight"].to(device).to(module.weight.dtype)
            )
        elif name + '.A' in diff_dict:
            A = diff_dict[name + '.A'].to(device)
            B = diff_dict[name + '.B'].to(device)
            module.weight.add_((A @ B).T.to(module.weight.dtype))
    # The saved diff may have resized the output embedding; keep the config
    # consistent with the actual lm_head shape.
    model.config.vocab_size = model.lm_head.weight.size(0)
def save_full_model(base_model_name, finetuned_model_name, diff_dir, save_dir, device):
    """Reconstruct the finetuned model (base weights + saved diff) and write
    it, together with the finetuned model's tokenizer, to `save_dir`.
    """
    reconstructed = get_model(base_model_name, device)
    tokenizer = get_tokenizer(finetuned_model_name)

    # Patch the base weights in place with the saved BitDelta diff.
    load_diff(reconstructed, diff_dir)

    reconstructed.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

    del reconstructed