Skip to content

Commit bc0cfca

Browse files
authored
Merge Factorizer on LayerStacks forward (official-stockfish#364)
* merge factorizer before forward
* ?
* ruff format
1 parent 12006e1 commit bc0cfca

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

model/model.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from torch import nn, Tensor
5+
import torch.nn.functional as F
56

67
from .config import ModelConfig
78
from .feature_transformer import DoubleFeatureTransformerSlice
@@ -34,10 +35,17 @@ def _init_uniformly(self) -> None:
3435

3536
def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
3637
stacked_output = self.linear(x)
38+
39+
return self.select_output(stacked_output, ls_indices)
40+
41+
def select_output(self, stacked_output: Tensor, ls_indices: Tensor) -> Tensor:
3742
reshaped_output = stacked_output.reshape(-1, self.out_features)
3843

3944
idx_offset = torch.arange(
40-
0, x.shape[0] * self.count, self.count, device=x.device
45+
0,
46+
ls_indices.shape[0] * self.count,
47+
self.count,
48+
device=stacked_output.device,
4149
)
4250
indices = ls_indices.flatten() + idx_offset
4351

@@ -69,10 +77,14 @@ def __init__(self, in_features: int, out_features: int, count: int):
6977
self.factorized_linear.bias.zero_()
7078

7179
def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
72-
stacked_output = super().forward(x, ls_indices)
73-
factorized_output = self.factorized_linear(x)
80+
merged_weight = self.linear.weight + self.factorized_linear.weight.repeat(
81+
self.count, 1
82+
)
83+
merged_bias = self.linear.bias + self.factorized_linear.bias.repeat(self.count)
84+
85+
stacked_output = F.linear(x, merged_weight, merged_bias)
7486

75-
return stacked_output + factorized_output
87+
return self.select_output(stacked_output, ls_indices)
7688

7789
@torch.no_grad()
7890
def at_index(self, index: int) -> nn.Linear:

0 commit comments

Comments (0)