Skip to content

Commit 11d4540

Browse files
committed
Enable quantized net loading
1 parent 5ded834 commit 11d4540

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

trainer/quantize.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import struct
33
import numpy as np
44

5+
from model import PerspectiveNetwork
6+
57
QA = 403
68
QB = 64
79
QAB = QA * QB
@@ -31,3 +33,20 @@ def quantize(model, bin_path):
3133
bin_file.write(struct.pack('<' + 'h' * len(feature_bias_values), *feature_bias_values))
3234
bin_file.write(struct.pack('<' + 'h' * len(output_weight_values), *output_weight_values))
3335
bin_file.write(struct.pack('<' + 'h' * len(output_bias_values), *output_bias_values))
36+
37+
38+
def load_quantized_net(bin_path, hl_size, qa, qb):
    """Load a quantized network binary and return a dequantized PerspectiveNetwork.

    Inverse of ``quantize``: reads four consecutive little-endian int16 arrays
    (feature weights, feature bias, output weights, output bias) and divides
    each by its quantization factor to recover float32 parameters.

    Args:
        bin_path: Path to the binary file written by ``quantize``.
        hl_size:  Hidden-layer size the file was written with; the expected
                  byte layout is derived from this, so it must match exactly.
        qa:       Quantization factor for the feature transformer (weights
                  and bias are divided by ``qa``).
        qb:       Quantization factor for the output layer (output weights
                  divided by ``qb``; output bias by ``qa * qb``).

    Returns:
        A ``PerspectiveNetwork(hl_size)`` with its parameters overwritten by
        the dequantized values from the file.

    Raises:
        ValueError: If the file is truncated or contains trailing bytes,
                    i.e. its size does not match ``hl_size``.
    """
    def _read_exact(f, count):
        # Fail with a clear message on truncated/mis-sized files instead of
        # letting struct.unpack raise an opaque struct.error.
        data = f.read(count)
        if len(data) != count:
            raise ValueError(
                f"Truncated quantized net '{bin_path}': "
                f"expected {count} bytes, got {len(data)}"
            )
        return data

    with open(bin_path, "rb") as bin_file:
        # Read the four int16 arrays in the same order quantize() writes them.
        feature_weights = struct.unpack(f'<{768 * hl_size}h', _read_exact(bin_file, 768 * hl_size * 2))
        feature_bias = struct.unpack(f'<{hl_size}h', _read_exact(bin_file, hl_size * 2))
        output_weights = struct.unpack(f'<{2 * hl_size}h', _read_exact(bin_file, 2 * hl_size * 2))
        output_bias = struct.unpack('<1h', _read_exact(bin_file, 1 * 2))
        # A correctly-sized file is fully consumed here; leftover bytes mean
        # hl_size does not match the file, which would silently corrupt weights.
        if bin_file.read(1):
            raise ValueError(
                f"Unexpected trailing bytes in '{bin_path}': "
                f"hl_size={hl_size} does not match the file"
            )

    # NOTE(review): the (768, hl_size) -> .T reshape assumes the weights were
    # serialized feature-major; this mirrors the original code — confirm
    # against quantize()'s flattening order.
    # NOTE(review): assumes `torch` is imported at the top of this module
    # (file line 1 is outside the visible diff hunk) — verify.
    model = PerspectiveNetwork(hl_size)
    model.feature_transformer.weight.data = torch.tensor(np.array(feature_weights).reshape(768, hl_size).T / qa, dtype=torch.float32)
    model.feature_transformer.bias.data = torch.tensor(np.array(feature_bias) / qa, dtype=torch.float32)
    model.output_layer.weight.data = torch.tensor(np.array(output_weights).reshape(1, 2 * hl_size) / qb, dtype=torch.float32)
    model.output_layer.bias.data = torch.tensor(np.array(output_bias) / (qa * qb), dtype=torch.float32)

    return model

0 commit comments

Comments
 (0)