
Commit 1308a12

Add support for loading quantized weights
1 parent 7080875 commit 1308a12

1 file changed

trainer/quantize.py

Lines changed: 33 additions & 24 deletions
@@ -7,27 +7,36 @@
 QAB = QA * QB
 
 def quantize(model, bin_path):
-    # Extract weights from the model
-    feature_weights = model.feature_transformer.weight.detach().cpu().numpy()
-    feature_bias = model.feature_transformer.bias.detach().cpu().numpy()
-    output_weights = model.output_layer.weight.detach().cpu().numpy()
-    output_bias = model.output_layer.bias.detach().cpu().numpy()
-
-    # Quantized weights
-    feature_weight_quantized = (feature_weights * QA).T.astype(np.int16)
-    feature_bias_quantized = (feature_bias * QA).astype(np.int16)
-    output_weight_quantized = (output_weights * QB).astype(np.int16)
-    output_bias_quantized = (output_bias * QAB).astype(np.int16)
-
-    # Flatten to 1D lists
-    feature_weight_values = feature_weight_quantized.flatten().tolist()
-    feature_bias_values = feature_bias_quantized.flatten().tolist()
-    output_weight_values = output_weight_quantized.flatten().tolist()
-    output_bias_values = output_bias_quantized.flatten().tolist()
-
-    # Save to binary file
-    with open(bin_path, "wb") as bin_file:
-        bin_file.write(struct.pack('<' + 'h' * len(feature_weight_values), *feature_weight_values))
-        bin_file.write(struct.pack('<' + 'h' * len(feature_bias_values), *feature_bias_values))
-        bin_file.write(struct.pack('<' + 'h' * len(output_weight_values), *output_weight_values))
-        bin_file.write(struct.pack('<' + 'h' * len(output_bias_values), *output_bias_values))
+    def quant(arr, scale):
+        return (arr * scale).astype(np.int16).flatten()
+
+    f_w = model.feature_transformer.weight.detach().cpu().numpy()
+    f_b = model.feature_transformer.bias.detach().cpu().numpy()
+    o_w = model.output_layer.weight.detach().cpu().numpy()
+    o_b = model.output_layer.bias.detach().cpu().numpy()
+
+    quantized_data = np.concatenate([
+        quant(f_w.T, QA),
+        quant(f_b, QA),
+        quant(o_w, QB),
+        quant(o_b, QAB)
+    ])
+
+    quantized_data.tofile(bin_path)
+
+
+def load_quantized_net(bin_path, hl_size, qa, qb):
+    with open(bin_path, "rb") as bin_file:
+        # Read feature weights
+        feature_weights = struct.unpack(f'<{768 * hl_size}h', bin_file.read(768 * hl_size * 2))
+        feature_bias = struct.unpack(f'<{hl_size}h', bin_file.read(hl_size * 2))
+        output_weights = struct.unpack(f'<{2 * hl_size}h', bin_file.read(2 * hl_size * 2))
+        output_bias = struct.unpack('<1h', bin_file.read(1 * 2))
+
+    model = PerspectiveNetwork(hl_size)
+    model.feature_transformer.weight.data = torch.tensor(np.array(feature_weights).reshape(768, hl_size).T / qa, dtype=torch.float32)
+    model.feature_transformer.bias.data = torch.tensor(np.array(feature_bias) / qa, dtype=torch.float32)
+    model.output_layer.weight.data = torch.tensor(np.array(output_weights).reshape(1, 2 * hl_size) / qb, dtype=torch.float32)
+    model.output_layer.bias.data = torch.tensor(np.array(output_bias) / (qa * qb), dtype=torch.float32)
+
+    return model
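The reads in load_quantized_net mirror the order in which quantize concatenates its arrays, so the two functions round-trip: transposed feature weights (768 × hl_size values), feature biases (hl_size), output weights (2 × hl_size), and the single output bias, all as int16. One caveat: numpy's tofile writes in the machine's native byte order, while the struct formats request little-endian explicitly, so the pair only agrees on little-endian hosts. A minimal round-trip check, assuming quantize and load_quantized_net from this file are in scope; the PerspectiveNetwork body and the scales QA = 255, QB = 64 below are placeholders, not taken from this repository:

    import numpy as np
    import torch
    import torch.nn as nn

    QA, QB = 255, 64        # placeholder scales; the real values live elsewhere in trainer/
    QAB = QA * QB

    class PerspectiveNetwork(nn.Module):
        # Placeholder matching the layer names and shapes implied by the diff
        def __init__(self, hl_size):
            super().__init__()
            self.feature_transformer = nn.Linear(768, hl_size)
            self.output_layer = nn.Linear(2 * hl_size, 1)

    hl_size = 128
    net = PerspectiveNetwork(hl_size)
    quantize(net, "net.bin")
    restored = load_quantized_net("net.bin", hl_size, QA, QB)

    # astype(np.int16) truncates toward zero, so each restored value is within
    # one quantization step (1/QA or 1/QB) of the original.
    assert torch.allclose(net.feature_transformer.weight,
                          restored.feature_transformer.weight, atol=1.0 / QA)
    assert torch.allclose(net.output_layer.weight,
                          restored.output_layer.weight, atol=1.0 / QB)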

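For a loader in another language, the implied file layout is a flat stream of 16-bit integers with no header, so the expected size follows directly from hl_size. A quick sanity check against the file written above:

    import os

    hl_size = 128
    n_values = 768 * hl_size + hl_size + 2 * hl_size + 1  # arrays in write order
    assert os.path.getsize("net.bin") == n_values * 2     # int16 = 2 bytes each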