 from .quantize import QuantizationConfig, QuantizationManager


+class StackedLinear(nn.Module):
+    def __init__(self, in_features: int, out_features: int, count: int):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.count = count
+        self.linear = nn.Linear(in_features, out_features * count)
+
+        self._init_uniformly()
+
+    @torch.no_grad()
+    def _init_uniformly(self) -> None:
+        init_weight = self.linear.weight[0 : self.out_features, :]
+        init_bias = self.linear.bias[0 : self.out_features]
+
+        for i in range(1, self.count):
+            begin = i * self.out_features
+            end = (i + 1) * self.out_features
+
+            self.linear.weight[begin:end, :] = init_weight
+            self.linear.bias[begin:end] = init_bias
+
+    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
+        stacked_output = self.linear(x)
+        reshaped_output = stacked_output.reshape(-1, self.out_features)
+
+        idx_offset = torch.arange(
+            0, x.shape[0] * self.count, self.count, device=x.device
+        )
+        indices = ls_indices.flatten() + idx_offset
+
+        selected_output = reshaped_output[indices]
+
+        return selected_output
+
+    @torch.no_grad()
+    def at_index(self, index: int) -> nn.Linear:
+        layer = nn.Linear(self.in_features, self.out_features)
+
+        begin = index * self.out_features
+        end = (index + 1) * self.out_features
+
+        layer.weight.copy_(self.linear.weight[begin:end, :])
+        layer.bias.copy_(self.linear.bias[begin:end])
+
+        return layer
+
+
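The bucket selection in `StackedLinear.forward` flattens the stacked output to shape `(batch * count, out_features)` and offsets each sample's bucket index by its position in the batch, so exactly one row is gathered per sample. A minimal sketch of that indexing, with made-up sizes and values that are not part of the commit:

```python
import torch

batch, count, out_features = 3, 4, 2

# Row b * count + i of the flattened output holds bucket i of sample b.
stacked = torch.arange(batch * count * out_features, dtype=torch.float32)
stacked = stacked.reshape(batch * count, out_features)

# One bucket index per sample (illustrative values).
ls_indices = torch.tensor([2, 0, 3])

# Offset each bucket index into its sample's block of `count` rows.
idx_offset = torch.arange(0, batch * count, count)
selected = stacked[ls_indices + idx_offset]

print(selected.shape)  # torch.Size([3, 2])
```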
+class FactorizedStackedLinear(StackedLinear):
+    def __init__(self, in_features: int, out_features: int, count: int):
+        super().__init__(in_features, out_features, count)
+
+        self.factorized_linear = nn.Linear(in_features, out_features)
+
+        with torch.no_grad():
+            self.factorized_linear.weight.zero_()
+            self.factorized_linear.bias.zero_()
+
+    def forward(self, x: Tensor, ls_indices: Tensor) -> Tensor:
+        stacked_output = super().forward(x, ls_indices)
+        factorized_output = self.factorized_linear(x)
+
+        return stacked_output + factorized_output
+
+    @torch.no_grad()
+    def at_index(self, index: int) -> nn.Linear:
+        layer = super().at_index(index)
+
+        layer.weight.add_(self.factorized_linear.weight)
+        layer.bias.add_(self.factorized_linear.bias)
+
+        return layer
+
+    @torch.no_grad()
+    def coalesce_weights(self) -> None:
+        for i in range(self.count):
+            begin = i * self.out_features
+            end = (i + 1) * self.out_features
+
+            self.linear.weight[begin:end, :].add_(self.factorized_linear.weight)
+            self.linear.bias[begin:end].add_(self.factorized_linear.bias)
+
+        self.factorized_linear.weight.zero_()
+        self.factorized_linear.bias.zero_()
+
+
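`FactorizedStackedLinear` adds a shared (factorized) layer on top of the per-bucket weights, so during training each bucket effectively uses its own slice plus the shared part; `coalesce_weights` folds the shared part back into every bucket and zeroes it, which should leave the outputs unchanged. A hedged sanity-check sketch, assuming the classes above are in scope and using arbitrary sizes:

```python
import torch

m = FactorizedStackedLinear(in_features=8, out_features=4, count=3)

x = torch.randn(5, 8)
ls_indices = torch.randint(0, 3, (5,))

before = m(x, ls_indices)
m.coalesce_weights()      # fold the shared weights and bias into every bucket
after = m(x, ls_indices)  # factorized part is now zero, output unchanged

assert torch.allclose(before, after, atol=1e-5)
```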
 class LayerStacks(nn.Module):
     def __init__(self, count: int, config: ModelConfig):
         super().__init__()

+        self.count = count
         self.L1 = config.L1
         self.L2 = config.L2
         self.L3 = config.L3

-        self.count = count
-        self.l1 = nn.Linear(2 * self.L1 // 2, (self.L2 + 1) * count)
         # Factorizer only for the first layer because later
         # there's a non-linearity and factorization breaks.
         # This is by design. The weights in the further layers should be
         # able to diverge a lot.
-        self.l1_fact = nn.Linear(2 * self.L1 // 2, self.L2 + 1, bias=True)
-        self.l2 = nn.Linear(self.L2 * 2, self.L3 * count)
-        self.output = nn.Linear(self.L3, 1 * count)
-
-        self._init_layers()
-
-    def _init_layers(self):
-        l1_weight = self.l1.weight
-        l1_bias = self.l1.bias
-        l1_fact_weight = self.l1_fact.weight
-        l1_fact_bias = self.l1_fact.bias
-        l2_weight = self.l2.weight
-        l2_bias = self.l2.bias
-        output_weight = self.output.weight
-        output_bias = self.output.bias
+        self.l1 = FactorizedStackedLinear(2 * self.L1 // 2, self.L2 + 1, count)
+        self.l2 = StackedLinear(self.L2 * 2, self.L3, count)
+        self.output = StackedLinear(self.L3, 1, count)

         with torch.no_grad():
-            l1_fact_weight.fill_(0.0)
-            l1_fact_bias.fill_(0.0)
-            output_bias.fill_(0.0)
-
-            for i in range(1, self.count):
-                # Force all layer stacks to be initialized in the same way.
-                l1_weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :] = l1_weight[
-                    0 : (self.L2 + 1), :
-                ]
-                l1_bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)] = l1_bias[
-                    0 : (self.L2 + 1)
-                ]
-                l2_weight[i * self.L3 : (i + 1) * self.L3, :] = l2_weight[
-                    0 : self.L3, :
-                ]
-                l2_bias[i * self.L3 : (i + 1) * self.L3] = l2_bias[0 : self.L3]
-                output_weight[i : i + 1, :] = output_weight[0:1, :]
-
-        self.l1.weight = nn.Parameter(l1_weight)
-        self.l1.bias = nn.Parameter(l1_bias)
-        self.l1_fact.weight = nn.Parameter(l1_fact_weight)
-        self.l1_fact.bias = nn.Parameter(l1_fact_bias)
-        self.l2.weight = nn.Parameter(l2_weight)
-        self.l2.bias = nn.Parameter(l2_bias)
-        self.output.weight = nn.Parameter(output_weight)
-        self.output.bias = nn.Parameter(output_bias)
+            self.output.linear.bias.zero_()

     def forward(self, x: Tensor, ls_indices: Tensor):
-        idx_offset = torch.arange(
-            0, x.shape[0] * self.count, self.count, device=x.device
-        )
-
-        indices = ls_indices.flatten() + idx_offset
-
-        l1s_ = self.l1(x).reshape((-1, self.count, self.L2 + 1))
-        l1f_ = self.l1_fact(x)
-        # https://stackoverflow.com/questions/55881002/pytorch-tensor-indexing-how-to-gather-rows-by-tensor-containing-indices
-        # basically we present it as a list of individual results and pick not only based on
-        # the ls index but also based on batch (they are combined into one index)
-        l1c_ = l1s_.view(-1, self.L2 + 1)[indices]
-        l1c_, l1c_out = l1c_.split(self.L2, dim=1)
-        l1f_, l1f_out = l1f_.split(self.L2, dim=1)
-        l1x_ = l1c_ + l1f_
+        l1c_ = self.l1(x, ls_indices)
+        l1x_, l1x_out = l1c_.split(self.L2, dim=1)
         # multiply sqr crelu result by (127/128) to match quantized version
         l1x_ = torch.clamp(
             torch.cat([torch.pow(l1x_, 2.0) * (127 / 128), l1x_], dim=1), 0.0, 1.0
         )

-        l2s_ = self.l2(l1x_).reshape((-1, self.count, self.L3))
-        l2c_ = l2s_.view(-1, self.L3)[indices]
+        l2c_ = self.l2(l1x_, ls_indices)
         l2x_ = torch.clamp(l2c_, 0.0, 1.0)

-        l3s_ = self.output(l2x_).reshape((-1, self.count, 1))
-        l3c_ = l3s_.view(-1, 1)[indices]
-        l3x_ = l3c_ + l1f_out + l1c_out
+        l3c_ = self.output(l2x_, ls_indices)
+        l3x_ = l3c_ + l1x_out

         return l3x_

@@ -106,38 +140,11 @@ def get_coalesced_layer_stacks(
         # This representation needs to be transformed into individual layers
         # for the serializer, because the buckets are interpreted as separate layers.
         for i in range(self.count):
-            l1 = nn.Linear(2 * self.L1 // 2, self.L2 + 1)
-            l2 = nn.Linear(self.L2 * 2, self.L3)
-            output = nn.Linear(self.L3, 1)
-            l1.weight.data = (
-                self.l1.weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :]
-                + self.l1_fact.weight.data
-            )
-            l1.bias.data = (
-                self.l1.bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)]
-                + self.l1_fact.bias.data
-            )
-            l2.weight.data = self.l2.weight[i * self.L3 : (i + 1) * self.L3, :]
-            l2.bias.data = self.l2.bias[i * self.L3 : (i + 1) * self.L3]
-            output.weight.data = self.output.weight[i : (i + 1), :]
-            output.bias.data = self.output.bias[i : (i + 1)]
-            yield l1, l2, output
+            yield self.l1.at_index(i), self.l2.at_index(i), self.output.at_index(i)

     @torch.no_grad()
     def coalesce_layer_stacks_inplace(self) -> None:
-        # During training the buckets are represented by a single, wider, layer.
-        # This representation needs to be transformed into individual layers
-        # for the serializer, because the buckets are interpreted as separate layers.
-        for i in range(self.count):
-            self.l1.weight[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1), :].add_(
-                self.l1_fact.weight
-            )
-            self.l1.bias[i * (self.L2 + 1) : (i + 1) * (self.L2 + 1)].add_(
-                self.l1_fact.bias
-            )
-
-        self.l1_fact.weight.zero_()
-        self.l1_fact.bias.zero_()
+        self.l1.coalesce_weights()

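The per-bucket layers that `get_coalesced_layer_stacks` previously assembled by manual slicing now come from `at_index`, which copies the bucket's slice of the wide layer (and, in the factorized case, adds the shared weights). A small hedged check of that equivalence, again assuming the classes above are in scope and with arbitrary sizes:

```python
import torch

s = StackedLinear(in_features=6, out_features=3, count=4)
i = 2
layer = s.at_index(i)

# at_index(i) copies rows i*out_features:(i+1)*out_features of the wide layer.
begin, end = i * s.out_features, (i + 1) * s.out_features
assert torch.equal(layer.weight, s.linear.weight[begin:end, :])
assert torch.equal(layer.bias, s.linear.bias[begin:end])
```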

 class NNUEModel(nn.Module):