Commit 187f1c1

fix ftperm from checkpoint (official-stockfish#359)
* fix ftperm from checkpoint; closes official-stockfish#322
* fix bug
* also coalesce layerstacks
* fix bug
1 parent fc889c7 commit 187f1c1

8 files changed: +76 additions, −27 deletions

model/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -4,7 +4,13 @@
 from .lightning_module import NNUE
 from .model import NNUEModel
 from .quantize import QuantizationConfig
-from .utils import coalesce_ft_weights, load_model, NNUEReader, NNUEWriter
+from .utils import (
+    coalesce_ft_weights,
+    coalesce_ft_weights_inplace,
+    load_model,
+    NNUEReader,
+    NNUEWriter,
+)


 __all__ = [
@@ -18,6 +24,7 @@
     "NNUEModel",
     "QuantizationConfig",
     "coalesce_ft_weights",
+    "coalesce_ft_weights_inplace",
     "load_model",
     "NNUEReader",
     "NNUEWriter",

model/model.py

Lines changed: 33 additions & 17 deletions
@@ -98,30 +98,46 @@ def forward(self, x: Tensor, ls_indices: Tensor):

         return l3x_

+    @torch.no_grad()
     def get_coalesced_layer_stacks(
         self,
     ) -> Generator[tuple[nn.Linear, nn.Linear, nn.Linear], None, None]:
         # During training the buckets are represented by a single, wider, layer.
         # This representation needs to be transformed into individual layers
         # for the serializer, because the buckets are interpreted as separate layers.
         for i in range(self.count):
-            with torch.no_grad():
-                l1 = nn.Linear(2 * self.L1 // 2, self.L2 + 1)
-                l2 = nn.Linear(self.L2 * 2, self.L3)
-                output = nn.Linear(self.L3, 1)
-                l1.weight.data = (
-                    self.l1.weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :]
-                    + self.l1_fact.weight.data
-                )
-                l1.bias.data = (
-                    self.l1.bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)]
-                    + self.l1_fact.bias.data
-                )
-                l2.weight.data = self.l2.weight[i * self.L3 : (i + 1) * self.L3, :]
-                l2.bias.data = self.l2.bias[i * self.L3 : (i + 1) * self.L3]
-                output.weight.data = self.output.weight[i : (i + 1), :]
-                output.bias.data = self.output.bias[i : (i + 1)]
-                yield l1, l2, output
+            l1 = nn.Linear(2 * self.L1 // 2, self.L2 + 1)
+            l2 = nn.Linear(self.L2 * 2, self.L3)
+            output = nn.Linear(self.L3, 1)
+            l1.weight.data = (
+                self.l1.weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :]
+                + self.l1_fact.weight.data
+            )
+            l1.bias.data = (
+                self.l1.bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)]
+                + self.l1_fact.bias.data
+            )
+            l2.weight.data = self.l2.weight[i * self.L3 : (i + 1) * self.L3, :]
+            l2.bias.data = self.l2.bias[i * self.L3 : (i + 1) * self.L3]
+            output.weight.data = self.output.weight[i : (i + 1), :]
+            output.bias.data = self.output.bias[i : (i + 1)]
+            yield l1, l2, output
+
+    @torch.no_grad()
+    def coalesce_layer_stacks_inplace(self) -> None:
+        # During training the buckets are represented by a single, wider, layer.
+        # This representation needs to be transformed into individual layers
+        # for the serializer, because the buckets are interpreted as separate layers.
+        for i in range(self.count):
+            self.l1.weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :].add_(
+                self.l1_fact.weight
+            )
+            self.l1.bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)].add_(
+                self.l1_fact.bias
+            )
+
+        self.l1_fact.weight.zero_()
+        self.l1_fact.bias.zero_()


 class NNUEModel(nn.Module):

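Note: coalesce_layer_stacks_inplace folds the shared l1_fact factorization into each bucket's slice of l1 and then zeroes l1_fact, so the per-bucket layers yielded by get_coalesced_layer_stacks (and the model's outputs) are numerically unchanged. Below is a minimal verification sketch, not part of this commit; it assumes a layer-stacks object exposing the two methods above, and the helper name is hypothetical.

import torch

def check_layer_stack_coalescing(layer_stacks) -> None:
    # Hypothetical helper: the per-bucket layers yielded before and after the
    # in-place coalescing should be identical, because l1_fact is simply
    # folded into each bucket's l1 slice and then zeroed.
    def snapshot():
        return [
            (l1.weight.clone(), l1.bias.clone(),
             l2.weight.clone(), l2.bias.clone(),
             out.weight.clone(), out.bias.clone())
            for l1, l2, out in layer_stacks.get_coalesced_layer_stacks()
        ]

    before = snapshot()
    layer_stacks.coalesce_layer_stacks_inplace()
    after = snapshot()

    for bucket_before, bucket_after in zip(before, after):
        for t_before, t_after in zip(bucket_before, bucket_after):
            assert torch.allclose(t_before, t_after)

    # The factorized layer is now redundant and has been zeroed out.
    assert torch.count_nonzero(layer_stacks.l1_fact.weight) == 0
    assert torch.count_nonzero(layer_stacks.l1_fact.bias) == 0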
model/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,10 +1,11 @@
-from .coalesce_weights import coalesce_ft_weights
+from .coalesce_weights import coalesce_ft_weights, coalesce_ft_weights_inplace
 from .load_model import load_model
 from .serialize import NNUEReader, NNUEWriter


 __all__ = [
     "coalesce_ft_weights",
+    "coalesce_ft_weights_inplace",
     "load_model",
     "NNUEReader",
     "NNUEWriter",

model/utils/coalesce_weights.py

Lines changed: 21 additions & 4 deletions
@@ -1,15 +1,32 @@
-from ..model import NNUEModel
+import torch
+
+from ..features import FeatureSet
 from ..feature_transformer import BaseFeatureTransformerSlice


-def coalesce_ft_weights(model: NNUEModel, layer: BaseFeatureTransformerSlice):
+def coalesce_ft_weights(
+    feature_set: FeatureSet, layer: BaseFeatureTransformerSlice
+) -> torch.Tensor:
     weight = layer.weight.data
-    indices = model.feature_set.get_virtual_to_real_features_gather_indices()
+    indices = feature_set.get_virtual_to_real_features_gather_indices()
     weight_coalesced = weight.new_zeros(
-        (model.feature_set.num_real_features, weight.shape[1])
+        (feature_set.num_real_features, weight.shape[1])
     )
     for i_real, is_virtual in enumerate(indices):
         weight_coalesced[i_real, :] = sum(
             weight[i_virtual, :] for i_virtual in is_virtual
         )
     return weight_coalesced
+
+
+def coalesce_ft_weights_inplace(
+    feature_set: FeatureSet, layer: BaseFeatureTransformerSlice
+) -> None:
+    weight = layer.weight.data
+    indices = feature_set.get_virtual_to_real_features_gather_indices()
+    weight_coalesced = torch.zeros_like(weight)
+    for i_real, is_virtual in enumerate(indices):
+        weight_coalesced[i_real, :] = sum(
+            weight[i_virtual, :] for i_virtual in is_virtual
+        )
+    layer.weight.data = weight_coalesced

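The two variants differ in shape handling: coalesce_ft_weights returns a fresh (num_real_features, num_outputs) tensor, while coalesce_ft_weights_inplace keeps the layer's full weight shape, writes the summed values into the real-feature rows, and leaves the remaining (virtual) rows at zero, so the layer stays usable in place. A toy illustration of the gather-index logic follows; the numbers and index lists are hypothetical, not from the repo.

import torch

# Hypothetical toy data: 2 real features, 1 shared virtual feature, 2 outputs.
# Feature-transformer slices store weight as (num_features, num_outputs).
weight = torch.tensor(
    [
        [1.0, 0.0],  # real feature 0
        [0.0, 1.0],  # real feature 1
        [0.5, 0.5],  # virtual feature contributing to both real features
    ]
)
# Gather indices in the shape returned by
# get_virtual_to_real_features_gather_indices(): one list per real feature,
# containing the rows that sum into it.
indices = [[0, 2], [1, 2]]

# Out-of-place, as in coalesce_ft_weights: a new (num_real, num_outputs) tensor.
coalesced = torch.stack([sum(weight[j] for j in idx) for idx in indices])
# coalesced == [[1.5, 0.5], [0.5, 1.5]]

# In-place variant, as in coalesce_ft_weights_inplace: same full shape,
# real rows hold the sums, remaining (virtual) rows stay zero.
inplace = torch.zeros_like(weight)
for i_real, idx in enumerate(indices):
    inplace[i_real] = sum(weight[j] for j in idx)
# inplace == [[1.5, 0.5], [0.5, 1.5], [0.0, 0.0]]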
model/utils/serialize.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def write_feature_transformer(self, model: NNUEModel, ft_compression: str) -> None:

         bias = layer.bias.data[: model.L1]

-        all_weight = coalesce_ft_weights(model, layer)
+        all_weight = coalesce_ft_weights(model.feature_set, layer)
         weight = all_weight[:, : model.L1]
         psqt_weight = all_weight[:, model.L1 :]

serialize.py

Lines changed: 4 additions & 0 deletions
@@ -136,6 +136,10 @@ def main():
     if args.device is not None:
         ftperm.set_cupy_device(args.device)

+    if not args.source.endswith(".nnue"):
+        M.coalesce_ft_weights_inplace(nnue.model.feature_set, nnue.model.input)
+        nnue.model.layer_stacks.coalesce_layer_stacks_inplace()
+
     ftperm.ft_optimize(
         nnue.model,
         args.ft_optimize_data,

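Context for this guard: a .nnue file is written with already-coalesced weights (see write_feature_transformer above), whereas a training checkpoint still carries the factorized virtual feature rows and the shared l1_fact layer, which is what broke ftperm when starting from a checkpoint. A sketch of the resulting flow, assuming the `import model as M` alias used by these scripts and an NNUE module `nnue` already restored from a checkpoint; the helper name is hypothetical and the loading/argument parsing are elided.

import model as M  # assumed alias, matching how serialize.py and visualize.py call these helpers

def prepare_for_ft_optimize(nnue, source_path: str) -> None:
    # Hypothetical helper mirroring the new serialize.py code path.
    # .nnue sources are already coalesced on export, so only models loaded
    # from training checkpoints need the in-place coalescing before ftperm runs.
    if not source_path.endswith(".nnue"):
        # Fold virtual feature rows into the real rows of the feature transformer.
        M.coalesce_ft_weights_inplace(nnue.model.feature_set, nnue.model.input)
        # Fold the shared l1_fact layer into each bucket's l1 slice.
        nnue.model.layer_stacks.coalesce_layer_stacks_inplace()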
visualize.py

Lines changed: 4 additions & 2 deletions
@@ -39,12 +39,14 @@ def _process_fig(self, name, fig=None):

     def plot_input_weights(self):
         # Coalesce weights and transform them to Numpy domain.
-        weights = M.coalesce_ft_weights(self.model, self.model.input)
+        weights = M.coalesce_ft_weights(self.model.feature_set, self.model.input)
         weights = weights[:, : self.model.L1]
         weights = weights.flatten().numpy()

         if self.args.ref_model:
-            ref_weights = M.coalesce_ft_weights(self.ref_model, self.ref_model.input)
+            ref_weights = M.coalesce_ft_weights(
+                self.ref_model.feature_set, self.ref_model.input
+            )
             ref_weights = ref_weights[:, : self.model.L1]
             ref_weights = ref_weights.flatten().numpy()
             weights -= ref_weights

visualize_multi_hist.py

Lines changed: 3 additions & 1 deletion
@@ -88,7 +88,9 @@ def main():
         for m in args.models
     ]

-    coalesced_ins = [M.coalesce_ft_weights(model, model.input) for model in models]
+    coalesced_ins = [
+        M.coalesce_ft_weights(model.feature_set, model.input) for model in models
+    ]
     input_weights = [
         coalesced_in[:, : args.l1].flatten().numpy() for coalesced_in in coalesced_ins
     ]
