@@ -2,6 +2,7 @@
 
 import torch
 from torch import nn, Tensor
+import torch.nn.functional as F
 
 from .config import ModelConfig
 from .feature_transformer import DoubleFeatureTransformerSlice
@@ -34,10 +35,17 @@ def _init_uniformly(self) -> None:
 
     def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
         stacked_output = self.linear(x)
+
+        return self.select_output(stacked_output, ls_indices)
+
+    def select_output(self, stacked_output: Tensor, ls_indices: Tensor) -> Tensor:
         reshaped_output = stacked_output.reshape(-1, self.out_features)
 
         idx_offset = torch.arange(
-            0, x.shape[0] * self.count, self.count, device=x.device
+            0,
+            ls_indices.shape[0] * self.count,
+            self.count,
+            device=stacked_output.device,
         )
         indices = ls_indices.flatten() + idx_offset
 
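For context: the new select_output helper picks, for each sample, the slice of the stacked output chosen by its layer-stack index, and the arange now takes its length and device from ls_indices and stacked_output rather than from x. A minimal standalone sketch of that gather, with hypothetical shapes (N=4, count=3, out_features=2; nothing below is part of the commit):

import torch

# Hypothetical shapes: batch N=4, count=3 stacked slices, out_features=2.
N, count, out_features = 4, 3, 2

stacked_output = torch.randn(N, count * out_features)
ls_indices = torch.randint(0, count, (N,))  # one slice index per sample

# (N * count, out_features): row i*count + j holds sample i's output from slice j.
reshaped_output = stacked_output.reshape(-1, out_features)

# Offset each sample's slice index into its own block of `count` rows.
idx_offset = torch.arange(0, N * count, count)
selected = reshaped_output[ls_indices.flatten() + idx_offset]  # (N, out_features)

# Cross-check against a direct per-sample gather.
expected = stacked_output.view(N, count, out_features)[torch.arange(N), ls_indices]
assert torch.equal(selected, expected)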
@@ -69,10 +77,14 @@ def __init__(self, in_features: int, out_features: int, count: int):
         self.factorized_linear.bias.zero_()
 
     def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
-        stacked_output = super().forward(x, ls_indices)
-        factorized_output = self.factorized_linear(x)
+        merged_weight = self.linear.weight + self.factorized_linear.weight.repeat(
+            self.count, 1
+        )
+        merged_bias = self.linear.bias + self.factorized_linear.bias.repeat(self.count)
+
+        stacked_output = F.linear(x, merged_weight, merged_bias)
 
-        return stacked_output + factorized_output
+        return self.select_output(stacked_output, ls_indices)
 
     @torch.no_grad()
     def at_index(self, index: int) -> nn.Linear:
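The rewritten forward folds the factorized layer into every stacked slice before a single F.linear call, instead of selecting a slice first and adding factorized_linear(x) afterwards. A quick equivalence sketch under assumed sizes (all names below are hypothetical, not taken from the commit):

import torch
import torch.nn.functional as F

# Hypothetical sizes: in_features=5, out_features=2, count=3, batch N=4.
in_features, out_features, count, N = 5, 2, 3, 4

linear = torch.nn.Linear(in_features, out_features * count)  # stacked slices
factorized = torch.nn.Linear(in_features, out_features)      # shared layer
x = torch.randn(N, in_features)

# New path: fold the factorized weights into every slice, then one matmul.
merged_weight = linear.weight + factorized.weight.repeat(count, 1)
merged_bias = linear.bias + factorized.bias.repeat(count)
merged_out = F.linear(x, merged_weight, merged_bias)

# Old path: stacked output plus the factorized output added to each slice.
old_out = linear(x) + factorized(x).repeat(1, count)

assert torch.allclose(merged_out, old_out, atol=1e-6)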