@@ -98,30 +98,46 @@ def forward(self, x: Tensor, ls_indices: Tensor):
9898
9999 return l3x_
100100
101+ @torch .no_grad ()
101102 def get_coalesced_layer_stacks (
102103 self ,
103104 ) -> Generator [tuple [nn .Linear , nn .Linear , nn .Linear ], None , None ]:
104105 # During training the buckets are represented by a single, wider, layer.
105106 # This representation needs to be transformed into individual layers
106107 # for the serializer, because the buckets are interpreted as separate layers.
107108 for i in range (self .count ):
108- with torch .no_grad ():
109- l1 = nn .Linear (2 * self .L1 // 2 , self .L2 + 1 )
110- l2 = nn .Linear (self .L2 * 2 , self .L3 )
111- output = nn .Linear (self .L3 , 1 )
112- l1 .weight .data = (
113- self .l1 .weight [i * (self .L2 + 1 ) : (i + 1 ) * (self .L2 + 1 ), :]
114- + self .l1_fact .weight .data
115- )
116- l1 .bias .data = (
117- self .l1 .bias [i * (self .L2 + 1 ) : (i + 1 ) * (self .L2 + 1 )]
118- + self .l1_fact .bias .data
119- )
120- l2 .weight .data = self .l2 .weight [i * self .L3 : (i + 1 ) * self .L3 , :]
121- l2 .bias .data = self .l2 .bias [i * self .L3 : (i + 1 ) * self .L3 ]
122- output .weight .data = self .output .weight [i : (i + 1 ), :]
123- output .bias .data = self .output .bias [i : (i + 1 )]
124- yield l1 , l2 , output
109+ l1 = nn .Linear (2 * self .L1 // 2 , self .L2 + 1 )
110+ l2 = nn .Linear (self .L2 * 2 , self .L3 )
111+ output = nn .Linear (self .L3 , 1 )
112+ l1 .weight .data = (
113+ self .l1 .weight [i * (self .L2 + 1 ) : (i + 1 ) * (self .L2 + 1 ), :]
114+ + self .l1_fact .weight .data
115+ )
116+ l1 .bias .data = (
117+ self .l1 .bias [i * (self .L2 + 1 ) : (i + 1 ) * (self .L2 + 1 )]
118+ + self .l1_fact .bias .data
119+ )
120+ l2 .weight .data = self .l2 .weight [i * self .L3 : (i + 1 ) * self .L3 , :]
121+ l2 .bias .data = self .l2 .bias [i * self .L3 : (i + 1 ) * self .L3 ]
122+ output .weight .data = self .output .weight [i : (i + 1 ), :]
123+ output .bias .data = self .output .bias [i : (i + 1 )]
124+ yield l1 , l2 , output
125+
126+ @torch .no_grad ()
127+ def coalesce_layer_stacks_inplace (self ) -> None :
128+ # During training the buckets are represented by a single, wider, layer.
129+ # This representation needs to be transformed into individual layers
130+ # for the serializer, because the buckets are interpreted as separate layers.
131+ for i in range (self .count ):
132+ self .l1 .weight [i * (self .L2 + 1 ) : (i + 1 ) * (self .L2 + 1 ), :].add_ (
133+ self .l1_fact .weight
134+ )
135+ self .l1 .bias [i * (self .L2 + 1 ) : (i + 1 ) * (self .L2 + 1 )].add_ (
136+ self .l1_fact .bias
137+ )
138+
139+ self .l1_fact .weight .zero_ ()
140+ self .l1_fact .bias .zero_ ()
125141
126142
127143class NNUEModel (nn .Module ):
0 commit comments