@@ -1,160 +1,12 @@
-from typing import Generator
-
 import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
+from torch import nn
 
 from .config import ModelConfig
-from .feature_transformer import DoubleFeatureTransformerSlice
 from .features import FeatureSet
+from .modules import DoubleFeatureTransformerSlice, LayerStacks
 from .quantize import QuantizationConfig, QuantizationManager
 
 
-class StackedLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, count: int):
-        super().__init__()
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.count = count
-        self.linear = nn.Linear(in_features, out_features * count)
-
-        self._init_uniformly()
-
-    @torch.no_grad()
-    def _init_uniformly(self) -> None:
-        init_weight = self.linear.weight[0 : self.out_features, :]
-        init_bias = self.linear.bias[0 : self.out_features]
-
-        self.linear.weight.copy_(init_weight.repeat(self.count, 1))
-        self.linear.bias.copy_(init_bias.repeat(self.count))
-
-    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
-        stacked_output = self.linear(x)
-
-        return self.select_output(stacked_output, ls_indices)
-
-    def select_output(self, stacked_output: Tensor, ls_indices: Tensor) -> Tensor:
-        reshaped_output = stacked_output.reshape(-1, self.out_features)
-
-        idx_offset = torch.arange(
-            0,
-            ls_indices.shape[0] * self.count,
-            self.count,
-            device=stacked_output.device,
-        )
-        indices = ls_indices.flatten() + idx_offset
-
-        selected_output = reshaped_output[indices]
-
-        return selected_output
-
-    @torch.no_grad()
-    def at_index(self, index: int) -> nn.Linear:
-        layer = nn.Linear(self.in_features, self.out_features)
-
-        begin = index * self.out_features
-        end = (index + 1) * self.out_features
-
-        layer.weight.copy_(self.linear.weight[begin:end, :])
-        layer.bias.copy_(self.linear.bias[begin:end])
-
-        return layer
-
-
-class FactorizedStackedLinear(StackedLinear):
-    def __init__(self, in_features: int, out_features: int, count: int):
-        super().__init__(in_features, out_features, count)
-
-        self.factorized_linear = nn.Linear(in_features, out_features)
-
-        with torch.no_grad():
-            self.factorized_linear.weight.zero_()
-            self.factorized_linear.bias.zero_()
-
-    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
-        merged_weight = self.linear.weight + self.factorized_linear.weight.repeat(
-            self.count, 1
-        )
-        merged_bias = self.linear.bias + self.factorized_linear.bias.repeat(self.count)
-
-        stacked_output = F.linear(x, merged_weight, merged_bias)
-
-        return self.select_output(stacked_output, ls_indices)
-
-    @torch.no_grad()
-    def at_index(self, index: int) -> nn.Linear:
-        layer = super().at_index(index)
-
-        layer.weight.add_(self.factorized_linear.weight)
-        layer.bias.add_(self.factorized_linear.bias)
-
-        return layer
-
-    @torch.no_grad()
-    def coalesce_weights(self) -> None:
-        for i in range(self.count):
-            begin = i * self.out_features
-            end = (i + 1) * self.out_features
-
-            self.linear.weight[begin:end, :].add_(self.factorized_linear.weight)
-            self.linear.bias[begin:end].add_(self.factorized_linear.bias)
-
-        self.factorized_linear.weight.zero_()
-        self.factorized_linear.bias.zero_()
-
-
-class LayerStacks(nn.Module):
-    def __init__(self, count: int, config: ModelConfig):
-        super().__init__()
-
-        self.count = count
-        self.L1 = config.L1
-        self.L2 = config.L2
-        self.L3 = config.L3
-
-        # Factorizer only for the first layer because later
-        # there's a non-linearity and factorization breaks.
-        # This is by design. The weights in the further layers should be
-        # able to diverge a lot.
-        self.l1 = FactorizedStackedLinear(2 * self.L1 // 2, self.L2 + 1, count)
-        self.l2 = StackedLinear(self.L2 * 2, self.L3, count)
-        self.output = StackedLinear(self.L3, 1, count)
-
-        with torch.no_grad():
-            self.output.linear.bias.zero_()
-
-    def forward(self, x: Tensor, ls_indices: Tensor):
-        l1c_ = self.l1(x, ls_indices)
-        l1x_, l1x_out = l1c_.split(self.L2, dim=1)
-        # multiply sqr crelu result by (127/128) to match quantized version
-        l1x_ = torch.clamp(
-            torch.cat([torch.pow(l1x_, 2.0) * (127 / 128), l1x_], dim=1), 0.0, 1.0
-        )
-
-        l2c_ = self.l2(l1x_, ls_indices)
-        l2x_ = torch.clamp(l2c_, 0.0, 1.0)
-
-        l3c_ = self.output(l2x_, ls_indices)
-        l3x_ = l3c_ + l1x_out
-
-        return l3x_
-
-    @torch.no_grad()
-    def get_coalesced_layer_stacks(
-        self,
-    ) -> Generator[tuple[nn.Linear, nn.Linear, nn.Linear], None, None]:
-        # During training the buckets are represented by a single, wider, layer.
-        # This representation needs to be transformed into individual layers
-        # for the serializer, because the buckets are interpreted as separate layers.
-        for i in range(self.count):
-            yield self.l1.at_index(i), self.l2.at_index(i), self.output.at_index(i)
-
-    @torch.no_grad()
-    def coalesce_layer_stacks_inplace(self) -> None:
-        self.l1.coalesce_weights()
-
-
 class NNUEModel(nn.Module):
     def __init__(
         self,
@@ -315,14 +167,14 @@ def set_feature_set(self, new_feature_set: FeatureSet):
 
     def forward(
         self,
-        us: Tensor,
-        them: Tensor,
-        white_indices: Tensor,
-        white_values: Tensor,
-        black_indices: Tensor,
-        black_values: Tensor,
-        psqt_indices: Tensor,
-        layer_stack_indices: Tensor,
+        us: torch.Tensor,
+        them: torch.Tensor,
+        white_indices: torch.Tensor,
+        white_values: torch.Tensor,
+        black_indices: torch.Tensor,
+        black_values: torch.Tensor,
+        psqt_indices: torch.Tensor,
+        layer_stack_indices: torch.Tensor,
     ):
        wp, bp = self.input(white_indices, white_values, black_indices, black_values)
        w, wpsqt = torch.split(wp, self.L1, dim=1)
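For reference, a minimal standalone sketch of the per-bucket selection that StackedLinear.select_output performs in the code removed above (the same logic now lives in .modules). The tensor sizes and index values below are illustrative assumptions, not values taken from this commit.

import torch

# Hypothetical sizes: 4 positions in the batch, 3 outputs per bucket, 8 buckets.
batch_size, out_features, count = 4, 3, 8
stacked = torch.randn(batch_size, out_features * count)  # output of the wide nn.Linear
ls_indices = torch.tensor([0, 2, 7, 5])  # bucket chosen for each sample

# Flatten to (batch_size * count, out_features): row i * count + b holds bucket b
# of sample i, so offsetting each sample's bucket index by i * count lets a single
# indexing operation pick one bucket's row per sample.
reshaped = stacked.reshape(-1, out_features)
idx_offset = torch.arange(0, batch_size * count, count)
selected = reshaped[ls_indices + idx_offset]  # shape: (batch_size, out_features)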