
Commit 3e21b51

Modularize layers (official-stockfish#370)
* modularize
* ruff format
* rename layers to modules
* Split off test code
* fix
* format
* fix some issues
* fix unintended quantization change
1 parent f8364be commit 3e21b51
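The practical effect, visible in the diffs below, is that the layer classes move out of model/model.py into a model.modules package and are imported from there. A minimal sketch of the resulting import surface (not code from the commit itself; the count value is arbitrary, ModelConfig is assumed to be default-constructible, and the (count, config) signature is taken from the LayerStacks class as it previously appeared in model.py):

    # Sketch of the post-refactor import surface; illustrative only.
    from model.config import ModelConfig
    from model.modules import DoubleFeatureTransformerSlice, LayerStacks

    config = ModelConfig()                        # assumes ModelConfig() works with defaults
    stacks = LayerStacks(count=8, config=config)  # (count, config) signature as in the old in-file class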


12 files changed, +943 -907 lines changed


model/feature_transformer.py

Lines changed: 0 additions & 747 deletions
This file was deleted.

model/model.py

Lines changed: 10 additions & 158 deletions
@@ -1,160 +1,12 @@
-from typing import Generator
-
 import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
+from torch import nn
 
 from .config import ModelConfig
-from .feature_transformer import DoubleFeatureTransformerSlice
 from .features import FeatureSet
+from .modules import DoubleFeatureTransformerSlice, LayerStacks
 from .quantize import QuantizationConfig, QuantizationManager
 
 
-class StackedLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, count: int):
-        super().__init__()
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.count = count
-        self.linear = nn.Linear(in_features, out_features * count)
-
-        self._init_uniformly()
-
-    @torch.no_grad()
-    def _init_uniformly(self) -> None:
-        init_weight = self.linear.weight[0 : self.out_features, :]
-        init_bias = self.linear.bias[0 : self.out_features]
-
-        self.linear.weight.copy_(init_weight.repeat(self.count, 1))
-        self.linear.bias.copy_(init_bias.repeat(self.count))
-
-    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
-        stacked_output = self.linear(x)
-
-        return self.select_output(stacked_output, ls_indices)
-
-    def select_output(self, stacked_output: Tensor, ls_indices: Tensor) -> Tensor:
-        reshaped_output = stacked_output.reshape(-1, self.out_features)
-
-        idx_offset = torch.arange(
-            0,
-            ls_indices.shape[0] * self.count,
-            self.count,
-            device=stacked_output.device,
-        )
-        indices = ls_indices.flatten() + idx_offset
-
-        selected_output = reshaped_output[indices]
-
-        return selected_output
-
-    @torch.no_grad()
-    def at_index(self, index: int) -> nn.Linear:
-        layer = nn.Linear(self.in_features, self.out_features)
-
-        begin = index * self.out_features
-        end = (index + 1) * self.out_features
-
-        layer.weight.copy_(self.linear.weight[begin:end, :])
-        layer.bias.copy_(self.linear.bias[begin:end])
-
-        return layer
-
-
-class FactorizedStackedLinear(StackedLinear):
-    def __init__(self, in_features: int, out_features: int, count: int):
-        super().__init__(in_features, out_features, count)
-
-        self.factorized_linear = nn.Linear(in_features, out_features)
-
-        with torch.no_grad():
-            self.factorized_linear.weight.zero_()
-            self.factorized_linear.bias.zero_()
-
-    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
-        merged_weight = self.linear.weight + self.factorized_linear.weight.repeat(
-            self.count, 1
-        )
-        merged_bias = self.linear.bias + self.factorized_linear.bias.repeat(self.count)
-
-        stacked_output = F.linear(x, merged_weight, merged_bias)
-
-        return self.select_output(stacked_output, ls_indices)
-
-    @torch.no_grad()
-    def at_index(self, index: int) -> nn.Linear:
-        layer = super().at_index(index)
-
-        layer.weight.add_(self.factorized_linear.weight)
-        layer.bias.add_(self.factorized_linear.bias)
-
-        return layer
-
-    @torch.no_grad()
-    def coalesce_weights(self) -> None:
-        for i in range(self.count):
-            begin = i * self.out_features
-            end = (i + 1) * self.out_features
-
-            self.linear.weight[begin:end, :].add_(self.factorized_linear.weight)
-            self.linear.bias[begin:end].add_(self.factorized_linear.bias)
-
-        self.factorized_linear.weight.zero_()
-        self.factorized_linear.bias.zero_()
-
-
-class LayerStacks(nn.Module):
-    def __init__(self, count: int, config: ModelConfig):
-        super().__init__()
-
-        self.count = count
-        self.L1 = config.L1
-        self.L2 = config.L2
-        self.L3 = config.L3
-
-        # Factorizer only for the first layer because later
-        # there's a non-linearity and factorization breaks.
-        # This is by design. The weights in the further layers should be
-        # able to diverge a lot.
-        self.l1 = FactorizedStackedLinear(2 * self.L1 // 2, self.L2 + 1, count)
-        self.l2 = StackedLinear(self.L2 * 2, self.L3, count)
-        self.output = StackedLinear(self.L3, 1, count)
-
-        with torch.no_grad():
-            self.output.linear.bias.zero_()
-
-    def forward(self, x: Tensor, ls_indices: Tensor):
-        l1c_ = self.l1(x, ls_indices)
-        l1x_, l1x_out = l1c_.split(self.L2, dim=1)
-        # multiply sqr crelu result by (127/128) to match quantized version
-        l1x_ = torch.clamp(
-            torch.cat([torch.pow(l1x_, 2.0) * (127 / 128), l1x_], dim=1), 0.0, 1.0
-        )
-
-        l2c_ = self.l2(l1x_, ls_indices)
-        l2x_ = torch.clamp(l2c_, 0.0, 1.0)
-
-        l3c_ = self.output(l2x_, ls_indices)
-        l3x_ = l3c_ + l1x_out
-
-        return l3x_
-
-    @torch.no_grad()
-    def get_coalesced_layer_stacks(
-        self,
-    ) -> Generator[tuple[nn.Linear, nn.Linear, nn.Linear], None, None]:
-        # During training the buckets are represented by a single, wider, layer.
-        # This representation needs to be transformed into individual layers
-        # for the serializer, because the buckets are interpreted as separate layers.
-        for i in range(self.count):
-            yield self.l1.at_index(i), self.l2.at_index(i), self.output.at_index(i)
-
-    @torch.no_grad()
-    def coalesce_layer_stacks_inplace(self) -> None:
-        self.l1.coalesce_weights()
-
-
 class NNUEModel(nn.Module):
     def __init__(
         self,
@@ -315,14 +167,14 @@ def set_feature_set(self, new_feature_set: FeatureSet):
 
     def forward(
         self,
-        us: Tensor,
-        them: Tensor,
-        white_indices: Tensor,
-        white_values: Tensor,
-        black_indices: Tensor,
-        black_values: Tensor,
-        psqt_indices: Tensor,
-        layer_stack_indices: Tensor,
+        us: torch.Tensor,
+        them: torch.Tensor,
+        white_indices: torch.Tensor,
+        white_values: torch.Tensor,
+        black_indices: torch.Tensor,
+        black_values: torch.Tensor,
+        psqt_indices: torch.Tensor,
+        layer_stack_indices: torch.Tensor,
     ):
         wp, bp = self.input(white_indices, white_values, black_indices, black_values)
         w, wpsqt = torch.split(wp, self.L1, dim=1)
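Two details of the removed StackedLinear/LayerStacks code above (now living under model/modules) are worth unpacking. First, the bucket selection in StackedLinear.select_output: the single wide nn.Linear evaluates all `count` buckets for every sample, and the reshape plus arange offset then picks row `sample_index * count + bucket_index`. A self-contained sketch of just that indexing, with toy sizes of my own choosing:

    import torch

    batch, out_features, count = 4, 3, 8
    stacked_output = torch.randn(batch, out_features * count)  # wide linear output: all buckets at once
    ls_indices = torch.tensor([0, 2, 5, 7])                    # chosen bucket per sample

    reshaped = stacked_output.reshape(-1, out_features)        # row k = (sample k // count, bucket k % count)
    idx_offset = torch.arange(0, batch * count, count)         # sample_index * count
    selected = reshaped[ls_indices + idx_offset]               # one (out_features,) row per sample

    # Equivalent direct slicing, to confirm the arithmetic.
    for i, b in enumerate(ls_indices.tolist()):
        expected = stacked_output[i, b * out_features : (b + 1) * out_features]
        assert torch.equal(selected[i], expected)

Second, the `(127 / 128)` factor on the squared clipped ReLU. The inline comment only says it matches the quantized version; one plausible reading (mine, not spelled out in this commit) is that an activation a in [0, 1] stored as round(127 * a) and squared with an integer divide by 128 dequantizes to roughly a^2 * 127/128 rather than a^2. A quick numeric check of that reading:

    import torch

    a = torch.rand(100_000)                    # float activations in [0, 1]
    q = torch.round(a * 127.0)                 # quantized activations on the 0..127 scale
    sqr_q = torch.floor(q * q / 128.0)         # integer square followed by a divide-by-128 (>> 7)
    dequant = sqr_q / 127.0                    # back to the float activation scale

    float_side = torch.clamp(a.pow(2.0) * (127.0 / 128.0), 0.0, 1.0)
    print((dequant - float_side).abs().max())  # small; the gap is rounding/flooring error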

model/modules/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from .feature_transformer import (
+    BaseFeatureTransformerSlice,
+    DoubleFeatureTransformerSlice,
+    FeatureTransformerSlice,
+)
+from .layer_stacks import LayerStacks
+
+__all__ = [
+    "BaseFeatureTransformerSlice",
+    "DoubleFeatureTransformerSlice",
+    "FeatureTransformerSlice",
+    "LayerStacks",
+]
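This package __init__ simply re-exports the public classes, so the updated model.py can import everything from model.modules without knowing which submodule defines what. Since the __init__ pulls LayerStacks from .layer_stacks, both of the following spellings should resolve to the same class; a small sketch:

    from model.modules import LayerStacks                      # via the package re-export
    from model.modules.layer_stacks import LayerStacks as LS   # direct submodule path

    assert LayerStacks is LS  # one class object, two import spellings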
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+from .module import (
+    BaseFeatureTransformerSlice,
+    DoubleFeatureTransformerSlice,
+    FeatureTransformerSlice,
+)
+
+__all__ = [
+    "BaseFeatureTransformerSlice",
+    "DoubleFeatureTransformerSlice",
+    "FeatureTransformerSlice",
+]
