
Commit 782ab62

refactor layerstack (official-stockfish#361)
1 parent f08f7ac commit 782ab62

5 files changed: +124 -115 lines

ftperm.py

Lines changed: 2 additions & 2 deletions
@@ -520,7 +520,7 @@ def eval_ft(model: NNUEModel, batch: data_loader.SparseBatchPtr) -> torch.Tensor
 def ft_permute_impl(model: NNUEModel, perm: npt.NDArray[np.int_]) -> None:
     permutation = list(perm)
 
-    l1_size = model.layer_stacks.l1.in_features
+    l1_size = model.layer_stacks.l1.linear.in_features
     if l1_size != len(permutation) * 2:
         raise Exception(
             f"Invalid permutation size. Expected {l1_size}. Got {len(permutation) * 2}."
@@ -535,7 +535,7 @@ def ft_permute_impl(model: NNUEModel, perm: npt.NDArray[np.int_]) -> None:
     # Apply the permutation in place.
     model.input.weight.data = model.input.weight.data[:, ft_permutation]
     model.input.bias.data = model.input.bias.data[ft_permutation]
-    model.layer_stacks.l1.weight.data = model.layer_stacks.l1.weight.data[
+    model.layer_stacks.l1.linear.weight.data = model.layer_stacks.l1.linear.weight.data[
         :, permutation
     ]

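Note: after this change, model.layer_stacks.l1 is a wrapper module rather than a bare nn.Linear, so ft_permute_impl has to reach the underlying layer through .linear. The permutation logic itself is unchanged; the sketch below (toy sizes, a plain nn.Linear standing in for layer_stacks.l1.linear) illustrates why reordering the layer's input columns together with the features that feed it leaves its output intact:

import torch
import torch.nn as nn

# Toy stand-in for model.layer_stacks.l1.linear (sizes are made up).
linear = nn.Linear(16, 4)
perm = torch.randperm(16)

x = torch.randn(3, 16)
before = linear(x)

# Permute the layer's input columns together with the features feeding it;
# the output is unchanged, which is what lets ft_permute_impl reorder the
# feature transformer freely.
linear.weight.data = linear.weight.data[:, perm]
after = linear(x[:, perm])

assert torch.allclose(before, after, atol=1e-6)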
model/lightning_module.py

Lines changed: 8 additions & 8 deletions
@@ -125,14 +125,14 @@ def configure_optimizers(self):
         LR = self.lr
         train_params = [
             {"params": _get_parameters([self.model.input]), "lr": LR, "gc_dim": 0},
-            {"params": [self.model.layer_stacks.l1_fact.weight], "lr": LR},
-            {"params": [self.model.layer_stacks.l1_fact.bias], "lr": LR},
-            {"params": [self.model.layer_stacks.l1.weight], "lr": LR},
-            {"params": [self.model.layer_stacks.l1.bias], "lr": LR},
-            {"params": [self.model.layer_stacks.l2.weight], "lr": LR},
-            {"params": [self.model.layer_stacks.l2.bias], "lr": LR},
-            {"params": [self.model.layer_stacks.output.weight], "lr": LR},
-            {"params": [self.model.layer_stacks.output.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.factorized_linear.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.factorized_linear.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.linear.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.linear.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.linear.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.linear.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.output.linear.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.output.linear.bias], "lr": LR},
         ]
 
         optimizer = ranger21.Ranger21(

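Note: with every weight now nested one level deeper (.linear / .factorized_linear), it is easy to drop a tensor when listing the parameter groups by hand. A small sanity check one could add (hypothetical helper, not part of this commit), assuming model is the NNUEModel and train_params is the list built above:

def assert_all_params_grouped(model, train_params):
    # Identities of every tensor the optimizer will actually see.
    grouped = {id(p) for group in train_params for p in group["params"]}

    # Every parameter of the layer stacks should appear in some group.
    missing = [
        name
        for name, p in model.layer_stacks.named_parameters()
        if id(p) not in grouped
    ]
    assert not missing, f"parameters missing from the optimizer: {missing}"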
model/model.py

Lines changed: 99 additions & 92 deletions
@@ -9,92 +9,126 @@
 from .quantize import QuantizationConfig, QuantizationManager
 
 
+class StackedLinear(nn.Module):
+    def __init__(self, in_features: int, out_features: int, count: int):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.count = count
+        self.linear = nn.Linear(in_features, out_features * count)
+
+        self._init_uniformly()
+
+    @torch.no_grad()
+    def _init_uniformly(self) -> None:
+        init_weight = self.linear.weight[0 : self.out_features, :]
+        init_bias = self.linear.bias[0 : self.out_features]
+
+        for i in range(1, self.count):
+            begin = i * self.out_features
+            end = (i + 1) * self.out_features
+
+            self.linear.weight[begin:end, :] = init_weight
+            self.linear.bias[begin:end] = init_bias
+
+    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
+        stacked_output = self.linear(x)
+        reshaped_output = stacked_output.reshape(-1, self.out_features)
+
+        idx_offset = torch.arange(
+            0, x.shape[0] * self.count, self.count, device=x.device
+        )
+        indices = ls_indices.flatten() + idx_offset
+
+        selected_output = reshaped_output[indices]
+
+        return selected_output
+
+    @torch.no_grad()
+    def at_index(self, index: int) -> nn.Linear:
+        layer = nn.Linear(self.in_features, self.out_features)
+
+        begin = index * self.out_features
+        end = (index + 1) * self.out_features
+
+        layer.weight.copy_(self.linear.weight[begin:end, :])
+        layer.bias.copy_(self.linear.bias[begin:end])
+
+        return layer
+
+
+class FactorizedStackedLinear(StackedLinear):
+    def __init__(self, in_features: int, out_features: int, count: int):
+        super().__init__(in_features, out_features, count)
+
+        self.factorized_linear = nn.Linear(in_features, out_features)
+
+        with torch.no_grad():
+            self.factorized_linear.weight.zero_()
+            self.factorized_linear.bias.zero_()
+
+    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
+        stacked_output = super().forward(x, ls_indices)
+        factorized_output = self.factorized_linear(x)
+
+        return stacked_output + factorized_output
+
+    @torch.no_grad()
+    def at_index(self, index: int) -> nn.Linear:
+        layer = super().at_index(index)
+
+        layer.weight.add_(self.factorized_linear.weight)
+        layer.bias.add_(self.factorized_linear.bias)
+
+        return layer
+
+    @torch.no_grad()
+    def coalesce_weights(self) -> None:
+        for i in range(self.count):
+            begin = i * self.out_features
+            end = (i + 1) * self.out_features
+
+            self.linear.weight[begin:end, :].add_(self.factorized_linear.weight)
+            self.linear.bias[begin:end].add_(self.factorized_linear.bias)
+
+        self.factorized_linear.weight.zero_()
+        self.factorized_linear.bias.zero_()
+
+
 class LayerStacks(nn.Module):
     def __init__(self, count: int, config: ModelConfig):
         super().__init__()
 
+        self.count = count
         self.L1 = config.L1
         self.L2 = config.L2
         self.L3 = config.L3
 
-        self.count = count
-        self.l1 = nn.Linear(2 * self.L1 // 2, (self.L2 + 1) * count)
         # Factorizer only for the first layer because later
         # there's a non-linearity and factorization breaks.
         # This is by design. The weights in the further layers should be
         # able to diverge a lot.
-        self.l1_fact = nn.Linear(2 * self.L1 // 2, self.L2 + 1, bias=True)
-        self.l2 = nn.Linear(self.L2 * 2, self.L3 * count)
-        self.output = nn.Linear(self.L3, 1 * count)
-
-        self._init_layers()
-
-    def _init_layers(self):
-        l1_weight = self.l1.weight
-        l1_bias = self.l1.bias
-        l1_fact_weight = self.l1_fact.weight
-        l1_fact_bias = self.l1_fact.bias
-        l2_weight = self.l2.weight
-        l2_bias = self.l2.bias
-        output_weight = self.output.weight
-        output_bias = self.output.bias
+        self.l1 = FactorizedStackedLinear(2 * self.L1 // 2, self.L2 + 1, count)
+        self.l2 = StackedLinear(self.L2 * 2, self.L3, count)
+        self.output = StackedLinear(self.L3, 1, count)
 
         with torch.no_grad():
-            l1_fact_weight.fill_(0.0)
-            l1_fact_bias.fill_(0.0)
-            output_bias.fill_(0.0)
-
-            for i in range(1, self.count):
-                # Force all layer stacks to be initialized in the same way.
-                l1_weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :] = l1_weight[
-                    0 : (self.L2 + 1), :
-                ]
-                l1_bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)] = l1_bias[
-                    0 : (self.L2 + 1)
-                ]
-                l2_weight[i * self.L3 : (i + 1) * self.L3, :] = l2_weight[
-                    0 : self.L3, :
-                ]
-                l2_bias[i * self.L3 : (i + 1) * self.L3] = l2_bias[0 : self.L3]
-                output_weight[i : i + 1, :] = output_weight[0:1, :]
-
-        self.l1.weight = nn.Parameter(l1_weight)
-        self.l1.bias = nn.Parameter(l1_bias)
-        self.l1_fact.weight = nn.Parameter(l1_fact_weight)
-        self.l1_fact.bias = nn.Parameter(l1_fact_bias)
-        self.l2.weight = nn.Parameter(l2_weight)
-        self.l2.bias = nn.Parameter(l2_bias)
-        self.output.weight = nn.Parameter(output_weight)
-        self.output.bias = nn.Parameter(output_bias)
+            self.output.linear.bias.zero_()
 
     def forward(self, x: Tensor, ls_indices: Tensor):
-        idx_offset = torch.arange(
-            0, x.shape[0] * self.count, self.count, device=x.device
-        )
-
-        indices = ls_indices.flatten() + idx_offset
-
-        l1s_ = self.l1(x).reshape((-1, self.count, self.L2 + 1))
-        l1f_ = self.l1_fact(x)
-        # https://stackoverflow.com/questions/55881002/pytorch-tensor-indexing-how-to-gather-rows-by-tensor-containing-indices
-        # basically we present it as a list of individual results and pick not only based on
-        # the ls index but also based on batch (they are combined into one index)
-        l1c_ = l1s_.view(-1, self.L2 + 1)[indices]
-        l1c_, l1c_out = l1c_.split(self.L2, dim=1)
-        l1f_, l1f_out = l1f_.split(self.L2, dim=1)
-        l1x_ = l1c_ + l1f_
+        l1c_ = self.l1(x, ls_indices)
+        l1x_, l1x_out = l1c_.split(self.L2, dim=1)
         # multiply sqr crelu result by (127/128) to match quantized version
         l1x_ = torch.clamp(
             torch.cat([torch.pow(l1x_, 2.0) * (127 / 128), l1x_], dim=1), 0.0, 1.0
         )
 
-        l2s_ = self.l2(l1x_).reshape((-1, self.count, self.L3))
-        l2c_ = l2s_.view(-1, self.L3)[indices]
+        l2c_ = self.l2(l1x_, ls_indices)
         l2x_ = torch.clamp(l2c_, 0.0, 1.0)
 
-        l3s_ = self.output(l2x_).reshape((-1, self.count, 1))
-        l3c_ = l3s_.view(-1, 1)[indices]
-        l3x_ = l3c_ + l1f_out + l1c_out
+        l3c_ = self.output(l2x_, ls_indices)
+        l3x_ = l3c_ + l1x_out
 
         return l3x_
 
@@ -106,38 +140,11 @@ def get_coalesced_layer_stacks(
         # This representation needs to be transformed into individual layers
         # for the serializer, because the buckets are interpreted as separate layers.
         for i in range(self.count):
-            l1 = nn.Linear(2 * self.L1 // 2, self.L2 + 1)
-            l2 = nn.Linear(self.L2 * 2, self.L3)
-            output = nn.Linear(self.L3, 1)
-            l1.weight.data = (
-                self.l1.weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :]
-                + self.l1_fact.weight.data
-            )
-            l1.bias.data = (
-                self.l1.bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)]
-                + self.l1_fact.bias.data
-            )
-            l2.weight.data = self.l2.weight[i * self.L3 : (i + 1) * self.L3, :]
-            l2.bias.data = self.l2.bias[i * self.L3 : (i + 1) * self.L3]
-            output.weight.data = self.output.weight[i : (i + 1), :]
-            output.bias.data = self.output.bias[i : (i + 1)]
-            yield l1, l2, output
+            yield self.l1.at_index(i), self.l2.at_index(i), self.output.at_index(i)
 
     @torch.no_grad()
     def coalesce_layer_stacks_inplace(self) -> None:
-        # During training the buckets are represented by a single, wider, layer.
-        # This representation needs to be transformed into individual layers
-        # for the serializer, because the buckets are interpreted as separate layers.
-        for i in range(self.count):
-            self.l1.weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :].add_(
-                self.l1_fact.weight
-            )
-            self.l1.bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)].add_(
-                self.l1_fact.bias
-            )
-
-        self.l1_fact.weight.zero_()
-        self.l1_fact.bias.zero_()
+        self.l1.coalesce_weights()
 
 
 class NNUEModel(nn.Module):
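Note: the heart of the refactor is that StackedLinear.forward keeps the old trick of evaluating all buckets with one wide matmul and then gathering a single row per sample via ls_indices. A standalone sketch of that selection with toy sizes (not the repo's classes), checked against a naive per-sample loop:

import torch
import torch.nn as nn

count, in_features, out_features = 4, 8, 3
linear = nn.Linear(in_features, out_features * count)  # all buckets side by side

x = torch.randn(5, in_features)
ls_indices = torch.randint(0, count, (5,))  # which bucket each sample uses

# Stacked evaluation: one matmul, then pick the (sample, bucket) row per sample.
reshaped = linear(x).reshape(-1, out_features)        # (batch * count, out_features)
offsets = torch.arange(0, x.shape[0] * count, count)  # start of each sample's block
selected = reshaped[ls_indices.flatten() + offsets]

# Naive reference: slice out the chosen bucket's weights for every sample.
rows = []
for xi, b in zip(x, ls_indices.tolist()):
    w = linear.weight[b * out_features : (b + 1) * out_features]
    bias = linear.bias[b * out_features : (b + 1) * out_features]
    rows.append(xi @ w.t() + bias)

assert torch.allclose(selected, torch.stack(rows), atol=1e-6)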

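Note: the factorizer comment above is also why coalesce_weights can exist at all: a shared linear term added to every bucket can be folded into each bucket's own weights precisely because no non-linearity sits between them. A quick illustrative check using the FactorizedStackedLinear class added in this diff (toy sizes; a test sketch, not part of the commit):

import torch

fsl = FactorizedStackedLinear(in_features=8, out_features=3, count=4)
with torch.no_grad():
    fsl.linear.weight.normal_()
    fsl.factorized_linear.weight.normal_()
    fsl.factorized_linear.bias.normal_()

x = torch.randn(6, 8)
ls_indices = torch.randint(0, 4, (6,))

before = fsl(x, ls_indices)
fsl.coalesce_weights()      # fold the shared term into every bucket and zero it
after = fsl(x, ls_indices)  # the factorized part now contributes nothing

assert torch.allclose(before, after, atol=1e-5)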
model/quantize.py

Lines changed: 4 additions & 4 deletions
@@ -39,18 +39,18 @@ def generate_weight_clipping_config(
     ) -> list[WeightClippingConfig]:
         return [
             {
-                "params": [model.layer_stacks.l1.weight],
+                "params": [model.layer_stacks.l1.linear.weight],
                 "min_weight": -self.max_hidden_weight,
                 "max_weight": self.max_hidden_weight,
-                "virtual_params": model.layer_stacks.l1_fact.weight,
+                "virtual_params": model.layer_stacks.l1.factorized_linear.weight,
             },
             {
-                "params": [model.layer_stacks.l2.weight],
+                "params": [model.layer_stacks.l2.linear.weight],
                 "min_weight": -self.max_hidden_weight,
                 "max_weight": self.max_hidden_weight,
             },
             {
-                "params": [model.layer_stacks.output.weight],
+                "params": [model.layer_stacks.output.linear.weight],
                 "min_weight": -self.max_out_weight,
                 "max_weight": self.max_out_weight,
             },

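Note: the virtual_params entry pairs the stacked l1.linear.weight with the shared l1.factorized_linear.weight, since the weight that ultimately matters is the sum of the two. A rough sketch of how a clipping step could honour that pairing; this is an assumption about how the config is consumed, not code from the commit (only the dict keys above are taken from the diff):

import torch

@torch.no_grad()
def clip_group(entry):
    lo, hi = entry["min_weight"], entry["max_weight"]
    virtual = entry.get("virtual_params")

    for param in entry["params"]:
        if virtual is None:
            param.clamp_(lo, hi)
        else:
            # The shared factorized weight is tiled across every bucket, so the
            # *effective* weight (stacked + shared) is what must stay in range.
            tiled = virtual.repeat(param.shape[0] // virtual.shape[0], 1)
            clipped = (param + tiled).clamp_(lo, hi)
            param.copy_(clipped - tiled)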
model/utils/serialize.py

Lines changed: 11 additions & 9 deletions
@@ -103,9 +103,9 @@ def fc_hash(model: NNUEModel) -> int:
 
     # Fully connected layers
     layers = [
-        model.layer_stacks.l1,
-        model.layer_stacks.l2,
-        model.layer_stacks.output,
+        model.layer_stacks.l1.linear,
+        model.layer_stacks.l2.linear,
+        model.layer_stacks.output.linear,
     ]
     for layer in layers:
         layer_hash = 0xCC03DAE4
@@ -239,20 +239,22 @@ def __init__(
             self.read_fc_layer(l2)
             self.read_fc_layer(output, is_output=True)
 
-            self.model.layer_stacks.l1.weight.data[
+            self.model.layer_stacks.l1.linear.weight.data[
                 i * (self.config.L2 + 1) : (i + 1) * (self.config.L2 + 1), :
             ] = l1.weight
-            self.model.layer_stacks.l1.bias.data[
+            self.model.layer_stacks.l1.linear.bias.data[
                 i * (self.config.L2 + 1) : (i + 1) * (self.config.L2 + 1)
             ] = l1.bias
-            self.model.layer_stacks.l2.weight.data[
+            self.model.layer_stacks.l2.linear.weight.data[
                 i * self.config.L3 : (i + 1) * self.config.L3, :
             ] = l2.weight
-            self.model.layer_stacks.l2.bias.data[
+            self.model.layer_stacks.l2.linear.bias.data[
                 i * self.config.L3 : (i + 1) * self.config.L3
             ] = l2.bias
-            self.model.layer_stacks.output.weight.data[i : (i + 1), :] = output.weight
-            self.model.layer_stacks.output.bias.data[i : (i + 1)] = output.bias
+            self.model.layer_stacks.output.linear.weight.data[i : (i + 1), :] = (
+                output.weight
+            )
+            self.model.layer_stacks.output.linear.bias.data[i : (i + 1)] = output.bias
 
     def read_header(self, feature_set: FeatureSet, fc_hash: int) -> None:
         self.read_int32(VERSION)  # version

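Note: the serializer now reads and writes slices of *.linear.weight directly, while get_coalesced_layer_stacks goes through at_index, so a cheap consistency check is to compare the two views bucket by bucket. A sketch, assuming model is an NNUEModel and that get_coalesced_layer_stacks takes no arguments (as the hunk in model/model.py suggests):

import torch

@torch.no_grad()
def check_bucket_views(model):
    ls = model.layer_stacks
    L2, L3 = ls.L2, ls.L3
    for i, (l1, l2, out) in enumerate(ls.get_coalesced_layer_stacks()):
        # l1 buckets carry the shared factorized weight folded in.
        wide = ls.l1.linear.weight[i * (L2 + 1) : (i + 1) * (L2 + 1)]
        assert torch.equal(l1.weight, wide + ls.l1.factorized_linear.weight)
        # l2 and output buckets are plain slices of the wide layers.
        assert torch.equal(l2.weight, ls.l2.linear.weight[i * L3 : (i + 1) * L3])
        assert torch.equal(out.weight, ls.output.linear.weight[i : i + 1])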