
Commit d4a5bb2

Decouple quantization from model and serialize (official-stockfish#357)
* Decouple quantization from model and serialize
* fix type checking circular dependency
* fix more errors
* ruff format
* try fix
* revert no_grad changes
* revert model.eval change
* fix tensor type
* clarify name
1 parent d189a14 commit d4a5bb2

12 files changed: +283 additions, -119 deletions
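The practical effect for callers is that a QuantizationConfig now travels alongside ModelConfig wherever a model is constructed or deserialized. A minimal sketch of the updated calling convention, mirroring the call sites changed below (the feature-set name, checkpoint/net paths, and L1 size are placeholders, not values taken from this commit):

import model as M
from model import NNUE, NNUEReader, ModelConfig, QuantizationConfig

feature_set = M.get_feature_set_from_name("HalfKAv2_hm")  # placeholder name
config = ModelConfig(L1=3072)                              # placeholder L1 size
quantize_config = QuantizationConfig()                     # defaults: 600.0 / 64.0 / 16.0 / 127.0

# From a Lightning checkpoint:
nnue = NNUE.load_from_checkpoint(
    "last.ckpt",                                           # placeholder path
    feature_set=feature_set,
    config=config,
    quantize_config=quantize_config,
)

# From a serialized .nnue file:
with open("network.nnue", "rb") as f:                      # placeholder path
    model = NNUEReader(f, feature_set, config, quantize_config).model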

cross_check_eval.py
Lines changed: 15 additions & 4 deletions

@@ -12,12 +12,18 @@
     NNUE,
     NNUEReader,
     ModelConfig,
+    QuantizationConfig,
 )


-def read_model(nnue_path, feature_set: FeatureSet, config: ModelConfig):
+def read_model(
+    nnue_path,
+    feature_set: FeatureSet,
+    config: ModelConfig,
+    quantize_config: QuantizationConfig,
+):
     with open(nnue_path, "rb") as f:
-        reader = NNUEReader(f, feature_set, config)
+        reader = NNUEReader(f, feature_set, config, quantize_config)
     return reader.model


@@ -177,10 +183,15 @@ def main():
     feature_set = get_feature_set_from_name(args.features)
     if args.checkpoint:
         model = NNUE.load_from_checkpoint(
-            args.checkpoint, feature_set=feature_set, config=ModelConfig(L1=args.l1)
+            args.checkpoint,
+            feature_set=feature_set,
+            config=ModelConfig(L1=args.l1),
+            quantize_config=QuantizationConfig(),
         )
     else:
-        model = read_model(args.net, feature_set, ModelConfig(L1=args.l1))
+        model = read_model(
+            args.net, feature_set, ModelConfig(L1=args.l1), QuantizationConfig()
+        )
     model.eval()
     fen_batch_provider = make_fen_batch_provider(args.data, batch_size)

ftperm.py
Lines changed: 34 additions & 12 deletions

@@ -41,7 +41,14 @@

 import data_loader
 import model as M
-from model import FeatureSet, NNUE, NNUEModel, NNUEReader, ModelConfig
+from model import (
+    FeatureSet,
+    NNUE,
+    NNUEModel,
+    NNUEReader,
+    ModelConfig,
+    QuantizationConfig,
+)


 """
@@ -391,9 +398,14 @@ def find_perm_impl(actmat, use_cupy, L1: int):
 # -------------------------------------------------------------


-def read_model(nnue_path, feature_set: FeatureSet, config: ModelConfig):
+def read_model(
+    nnue_path,
+    feature_set: FeatureSet,
+    config: ModelConfig,
+    quantize_config: QuantizationConfig,
+):
     with open(nnue_path, "rb") as f:
-        reader = NNUEReader(f, feature_set, config)
+        reader = NNUEReader(f, feature_set, config, quantize_config)
     return reader.model


@@ -419,9 +431,13 @@ def filter_fens(fens):
     return filtered_fens


-def quantize_ft(model):
-    model.input.weight.data = model.input.weight.data.mul(model.quantized_one).round()
-    model.input.bias.data = model.input.bias.data.mul(model.quantized_one).round()
+def quantize_ft(model: NNUEModel):
+    model.input.weight.data = model.input.weight.data.mul(
+        model.quantization.quantized_one
+    ).round()
+    model.input.bias.data = model.input.bias.data.mul(
+        model.quantization.quantized_one
+    ).round()


 def forward_ft(
@@ -508,7 +524,7 @@ def ft_permute(model, ft_perm_path):
     ft_permute_impl(model, permutation)


-def gather_impl(model, dataset, count):
+def gather_impl(model: NNUEModel, dataset, count):
     ZERO_POINT = 0.0  # Vary this to check hypothetical forced larger truncation to zero
     BATCH_SIZE = 1000

@@ -546,11 +562,17 @@ def gather_impl(model, dataset, count):
 def command_gather(args):
     feature_set = M.get_feature_set_from_name(args.features)
     if args.checkpoint:
-        model = NNUE.load_from_checkpoint(
-            args.checkpoint, feature_set=feature_set, config=ModelConfig(L1=args.l1)
+        nnue = NNUE.load_from_checkpoint(
+            args.checkpoint,
+            feature_set=feature_set,
+            config=ModelConfig(L1=args.l1),
+            quantize_config=QuantizationConfig(),
         )
+        model = nnue.model
     else:
-        model = read_model(args.net, feature_set, ModelConfig(L1=args.l1))
+        model = read_model(
+            args.net, feature_set, ModelConfig(L1=args.l1), QuantizationConfig()
+        )

     model.eval()

@@ -595,7 +617,7 @@ def command_find_perm(args):
     with open(args.data, "rb") as file:
         actmat = np.load(file)

-    perm = find_perm_impl(actmat, args.use_cupy)
+    perm = find_perm_impl(actmat, args.use_cupy, args.l1)

     # perm = np.random.permutation([i for i in range(L1)])
     with open(args.out, "wb") as file:
@@ -618,7 +640,7 @@ def ft_optimize(

     print("Finding permutation...")
     perm = find_perm_impl(actmat, use_cupy, model.L1)
-    if actmat_save_path is not None:
+    if perm_save_path is not None:
         with open(perm_save_path, "wb") as file:
             np.save(file, perm)

model/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from .features import add_feature_args, FeatureSet, get_feature_set_from_name
 from .lightning_module import NNUE
 from .model import NNUEModel
+from .quantize import QuantizationConfig
 from .utils import coalesce_ft_weights, load_model, NNUEReader, NNUEWriter


@@ -15,6 +16,7 @@
     "get_feature_set_from_name",
     "NNUE",
     "NNUEModel",
+    "QuantizationConfig",
     "coalesce_ft_weights",
     "load_model",
     "NNUEReader",
model/lightning_module.py
Lines changed: 4 additions & 2 deletions

@@ -6,6 +6,7 @@
 from .config import LossParams, ModelConfig
 from .features import FeatureSet
 from .model import NNUEModel
+from .quantize import QuantizationConfig


 def _get_parameters(layers: list[nn.Module]):
@@ -29,6 +30,7 @@ def __init__(
         self,
         feature_set: FeatureSet,
         config: ModelConfig,
+        quantize_config: QuantizationConfig,
         max_epoch=800,
         num_batches_per_epoch=int(100_000_000 / 16384),
         gamma=0.992,
@@ -40,7 +42,7 @@ def __init__(
     ):
         super().__init__()
         self.model: NNUEModel = NNUEModel(
-            feature_set, config, num_psqt_buckets, num_ls_buckets
+            feature_set, config, quantize_config, num_psqt_buckets, num_ls_buckets
         )
         self.loss_params = loss_params
         self.max_epoch = max_epoch
@@ -79,7 +81,7 @@ def step_(self, batch: tuple[Tensor, ...], batch_idx, loss_type):
                 psqt_indices,
                 layer_stack_indices,
             )
-            * self.model.nnue2score
+            * self.model.quantization.nnue2score
         )

         p = self.loss_params
model/model.py
Lines changed: 5 additions & 28 deletions

@@ -6,6 +6,7 @@
 from .config import ModelConfig
 from .feature_transformer import DoubleFeatureTransformerSlice
 from .features import FeatureSet
+from .quantize import QuantizationConfig, QuantizationManager


 class LayerStacks(nn.Module):
@@ -128,6 +129,7 @@ def __init__(
         self,
         feature_set: FeatureSet,
         config: ModelConfig,
+        quantize_config: QuantizationConfig,
         num_psqt_buckets: int = 8,
         num_ls_buckets: int = 8,
     ):
@@ -146,33 +148,8 @@ def __init__(
         self.feature_set = feature_set
         self.layer_stacks = LayerStacks(self.num_ls_buckets, config)

-        self.nnue2score = 600.0
-        self.weight_scale_hidden = 64.0
-        self.weight_scale_out = 16.0
-        self.quantized_one = 127.0
-
-        max_hidden_weight = self.quantized_one / self.weight_scale_hidden
-        max_out_weight = (self.quantized_one * self.quantized_one) / (
-            self.nnue2score * self.weight_scale_out
-        )
-        self.weight_clipping = [
-            {
-                "params": [self.layer_stacks.l1.weight],
-                "min_weight": -max_hidden_weight,
-                "max_weight": max_hidden_weight,
-                "virtual_params": self.layer_stacks.l1_fact.weight,
-            },
-            {
-                "params": [self.layer_stacks.l2.weight],
-                "min_weight": -max_hidden_weight,
-                "max_weight": max_hidden_weight,
-            },
-            {
-                "params": [self.layer_stacks.output.weight],
-                "min_weight": -max_out_weight,
-                "max_weight": max_out_weight,
-            },
-        ]
+        self.quantization = QuantizationManager(quantize_config)
+        self.weight_clipping = self.quantization.generate_weight_clipping_config(self)

         self._init_layers()

@@ -195,7 +172,7 @@ def _init_psqt(self):
         input_weights = self.input.weight
         input_bias = self.input.bias
         # 1.0 / kPonanzaConstant
-        scale = 1 / self.nnue2score
+        scale = 1 / self.quantization.nnue2score

         with torch.no_grad():
             initial_values = self.feature_set.get_initial_psqt_features()
model/quantize.py
Lines changed: 140 additions & 0 deletions

@@ -0,0 +1,140 @@
+from dataclasses import dataclass
+from typing import Callable, NotRequired, TypedDict, TYPE_CHECKING
+
+import torch
+
+if TYPE_CHECKING:
+    from .model import NNUEModel
+
+
+class WeightClippingConfig(TypedDict):
+    params: list[torch.Tensor]
+    min_weight: float
+    max_weight: float
+    virtual_params: NotRequired[torch.Tensor]
+
+
+@dataclass
+class QuantizationConfig:
+    nnue2score: float = 600.0
+    weight_scale_hidden: float = 64.0
+    weight_scale_out: float = 16.0
+    quantized_one: float = 127.0
+
+
+class QuantizationManager:
+    def __init__(self, config: QuantizationConfig):
+        self.nnue2score = config.nnue2score
+        self.weight_scale_hidden = config.weight_scale_hidden
+        self.weight_scale_out = config.weight_scale_out
+        self.quantized_one = config.quantized_one
+
+        self.max_hidden_weight = self.quantized_one / self.weight_scale_hidden
+        self.max_out_weight = (self.quantized_one * self.quantized_one) / (
+            self.nnue2score * self.weight_scale_out
+        )
+
+    def generate_weight_clipping_config(
+        self, model: "NNUEModel"
+    ) -> list[WeightClippingConfig]:
+        return [
+            {
+                "params": [model.layer_stacks.l1.weight],
+                "min_weight": -self.max_hidden_weight,
+                "max_weight": self.max_hidden_weight,
+                "virtual_params": model.layer_stacks.l1_fact.weight,
+            },
+            {
+                "params": [model.layer_stacks.l2.weight],
+                "min_weight": -self.max_hidden_weight,
+                "max_weight": self.max_hidden_weight,
+            },
+            {
+                "params": [model.layer_stacks.output.weight],
+                "min_weight": -self.max_out_weight,
+                "max_weight": self.max_out_weight,
+            },
+        ]
+
+    def quantize_feature_transformer(
+        self,
+        bias: torch.Tensor,
+        weight: torch.Tensor,
+        psqt_weight: torch.Tensor,
+        callback: Callable = lambda *args, **kwargs: None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        bias = bias.mul(self.quantized_one).round().to(torch.int16)
+        weight = weight.mul(self.quantized_one).round().to(torch.int16)
+        psqt_weight = (
+            psqt_weight.mul(self.nnue2score * self.weight_scale_out)
+            .round()
+            .to(torch.int32)
+        )
+
+        callback(bias, weight, psqt_weight)
+
+        return bias, weight, psqt_weight
+
+    def dequantize_feature_transformer(
+        self,
+        bias: torch.Tensor,
+        weight: torch.Tensor,
+        psqt_weight: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        bias = bias.divide(self.quantized_one)
+        weight = weight.divide(self.quantized_one)
+        psqt_weight = psqt_weight.divide(self.nnue2score * self.weight_scale_out)
+
+        return bias, weight, psqt_weight
+
+    def quantize_fc_layer(
+        self,
+        bias: torch.Tensor,
+        weight: torch.Tensor,
+        output_layer: bool = False,
+        callback: Callable = lambda *args, **kwargs: None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        kWeightScaleHidden = self.weight_scale_hidden
+        kWeightScaleOut = self.nnue2score * self.weight_scale_out / self.quantized_one
+        kWeightScale = kWeightScaleOut if output_layer else kWeightScaleHidden
+        kBiasScaleOut = self.weight_scale_out * self.nnue2score
+        kBiasScaleHidden = self.weight_scale_hidden * self.quantized_one
+        kBiasScale = kBiasScaleOut if output_layer else kBiasScaleHidden
+        kMaxWeight = self.quantized_one / kWeightScale
+
+        bias = bias.mul(kBiasScale).round().to(torch.int32)
+
+        clipped = torch.count_nonzero(weight.clamp(-kMaxWeight, kMaxWeight) - weight)
+        total_elements = torch.numel(weight)
+        clipped_max = torch.max(
+            torch.abs(weight.clamp(-kMaxWeight, kMaxWeight) - weight)
+        )
+
+        weight = (
+            weight.clamp(-kMaxWeight, kMaxWeight)
+            .mul(kWeightScale)
+            .round()
+            .to(torch.int8)
+        )
+
+        callback(bias, weight, clipped, total_elements, clipped_max, kMaxWeight)
+
+        return bias, weight
+
+    def dequantize_fc_layer(
+        self,
+        bias: torch.Tensor,
+        weight: torch.Tensor,
+        output_layer: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        kWeightScaleHidden = self.weight_scale_hidden
+        kWeightScaleOut = self.nnue2score * self.weight_scale_out / self.quantized_one
+        kWeightScale = kWeightScaleOut if output_layer else kWeightScaleHidden
+        kBiasScaleOut = self.weight_scale_out * self.nnue2score
+        kBiasScaleHidden = self.weight_scale_hidden * self.quantized_one
+        kBiasScale = kBiasScaleOut if output_layer else kBiasScaleHidden
+
+        bias = bias.divide(kBiasScale)
+        weight = weight.divide(kWeightScale)
+
+        return bias, weight
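As a usage sketch of the new manager, the snippet below exercises only quantize_fc_layer and dequantize_fc_layer on a hypothetical hidden fully-connected layer; the tensor shapes and value scale are arbitrary, chosen so that no weight hits the clipping bound:

import torch
from model.quantize import QuantizationConfig, QuantizationManager

q = QuantizationManager(QuantizationConfig())

# Hypothetical hidden FC layer; small values so clamp(-kMaxWeight, kMaxWeight) is a no-op.
weight = torch.randn(32, 16) * 0.01
bias = torch.randn(32) * 0.01

q_bias, q_weight = q.quantize_fc_layer(bias, weight)       # int32 bias, int8 weight
d_bias, d_weight = q.dequantize_fc_layer(q_bias, q_weight)  # back to float

print(q_weight.dtype, q_bias.dtype)  # torch.int8 torch.int32
# Round-trip error is bounded by the rounding step, 0.5 / weight_scale_hidden.
print(bool((d_weight - weight).abs().max() <= 0.5 / q.weight_scale_hidden))  # True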
