 import lightning as L
 from dataclasses import dataclass
 from features.feature_set import FeatureSet
-from typing import List, Tuple

 # 3 layer fully connected network
 L1 = 3072
@@ -43,10 +42,6 @@ def __init__(self, count: int):
         self.l2 = nn.Linear(L2 * 2, L3 * count)
         self.output = nn.Linear(L3, 1 * count)

-        # Cached helper tensor for choosing outputs by bucket indices.
-        # Initialized lazily in forward.
-        self.idx_offset = None
-
         self._init_layers()

     def _init_layers(self):
@@ -83,9 +78,14 @@ def _init_layers(self):
         self.output.bias = nn.Parameter(output_bias)

     def forward(self, x: Tensor, ls_indices: Tensor):
-        assert self.idx_offset is not None and self.idx_offset.shape[0] == x.shape[0]
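+        # Per-sample offsets into the flattened (batch * count) bucket
+        # dimension: sample i's chosen stack lives at i * count + ls_indices[i].
+        # Built on the fly on x's device instead of being cached on the module.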
+        idx_offset = torch.arange(
+            0,
+            x.shape[0] * self.count,
+            self.count,
+            device=x.device,
+        )

-        indices = ls_indices.flatten() + self.idx_offset
+        indices = ls_indices.flatten() + idx_offset

         l1s_ = self.l1(x).reshape((-1, self.count, L2 + 1))
         l1f_ = self.l1_fact(x)
@@ -135,45 +135,22 @@ def get_coalesced_layer_stacks(self):
                 yield l1, l2, output


-class NNUE(L.LightningModule):
-    """
-    feature_set - an instance of FeatureSet defining the input features
-
-    lambda_ = 0.0 - purely based on game results
-    0.0 < lambda_ < 1.0 - interpolated score and result
-    lambda_ = 1.0 - purely based on search scores
-
-    gamma - the multiplicative factor applied to the learning rate after each epoch
-
-    lr - the initial learning rate
-    """
-
+class NNUEModel(nn.Module):
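+    # The network itself as a plain nn.Module; training concerns (loss,
+    # optimizer, LR schedule) move to the NNUE LightningModule below.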
     def __init__(
         self,
         feature_set: FeatureSet,
-        max_epoch=800,
-        num_batches_per_epoch=int(100_000_000 / 16384),
-        gamma=0.992,
-        lr=8.75e-4,
-        param_index=0,
-        num_psqt_buckets=8,
-        num_ls_buckets=8,
-        loss_params=LossParams(),
+        num_psqt_buckets: int = 8,
+        num_ls_buckets: int = 8,
     ):
-        super(NNUE, self).__init__()
+        super().__init__()
         self.num_psqt_buckets = num_psqt_buckets
         self.num_ls_buckets = num_ls_buckets
+
         self.input = DoubleFeatureTransformerSlice(
             feature_set.num_features, L1 + self.num_psqt_buckets
         )
         self.feature_set = feature_set
         self.layer_stacks = LayerStacks(self.num_ls_buckets)
-        self.loss_params = loss_params
-        self.max_epoch = max_epoch
-        self.num_batches_per_epoch = num_batches_per_epoch
-        self.gamma = gamma
-        self.lr = lr
-        self.param_index = param_index

         self.nnue2score = 600.0
         self.weight_scale_hidden = 64.0
@@ -205,22 +182,21 @@ def __init__(

         self._init_layers()

208- """
209- We zero all virtual feature weights because there's not need for them
210- to be initialized; they only aid the training of correlated features.
211- """
185+ def _init_layers (self ):
186+ self ._zero_virtual_feature_weights ()
187+ self ._init_psqt ()

     def _zero_virtual_feature_weights(self):
+        """
+        We zero all virtual feature weights because there's no need for them
+        to be initialized; they only aid the training of correlated features.
+        """
         weights = self.input.weight
         with torch.no_grad():
             for a, b in self.feature_set.get_virtual_feature_ranges():
                 weights[a:b, :] = 0.0
         self.input.weight = nn.Parameter(weights)

-    def _init_layers(self):
-        self._zero_virtual_feature_weights()
-        self._init_psqt()
-
     def _init_psqt(self):
         input_weights = self.input.weight
         input_bias = self.input.bias
@@ -251,12 +227,11 @@ def _init_psqt(self):
         self.input.weight = nn.Parameter(input_weights)
         self.input.bias = nn.Parameter(input_bias)

-    """
-    Clips the weights of the model based on the min/max values allowed
-    by the quantization scheme.
-    """
-
     def _clip_weights(self):
+        """
+        Clips the weights of the model based on the min/max values allowed
+        by the quantization scheme.
+        """
         for group in self.weight_clipping:
             for p in group["params"]:
                 if "min_weight" in group or "max_weight" in group:
@@ -287,12 +262,11 @@ def _clip_weights(self):
                     raise Exception("Not supported.")
                 p.data.copy_(p_data_fp32)

-    """
-    This method attempts to convert the model from using the self.feature_set
-    to new_feature_set. Currently only works for adding virtual features.
-    """
-
     def set_feature_set(self, new_feature_set: FeatureSet):
+        """
+        This method attempts to convert the model from using self.feature_set
+        to new_feature_set. Currently only works for adding virtual features.
+        """
         if self.feature_set.name == new_feature_set.name:
             return

@@ -370,13 +344,51 @@ def forward(

         return x

-    def step_(self, batch: Tuple[Tensor, ...], batch_idx, loss_type):
+
+class NNUE(L.LightningModule):
+    """
+    feature_set - an instance of FeatureSet defining the input features
+
+    lambda_ = 0.0 - purely based on game results
+    0.0 < lambda_ < 1.0 - interpolated score and result
+    lambda_ = 1.0 - purely based on search scores
+
+    gamma - the multiplicative factor applied to the learning rate after each epoch
+
+    lr - the initial learning rate
+    """
+
+    def __init__(
+        self,
+        feature_set: FeatureSet,
+        max_epoch=800,
+        num_batches_per_epoch=int(100_000_000 / 16384),
+        gamma=0.992,
+        lr=8.75e-4,
+        param_index=0,
+        num_psqt_buckets=8,
+        num_ls_buckets=8,
+        loss_params=LossParams(),
+    ):
+        super().__init__()
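+        # The wrapper owns the network as a submodule and keeps only
+        # training configuration on the LightningModule itself.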
+        self.model: NNUEModel = NNUEModel(feature_set, num_psqt_buckets, num_ls_buckets)
+        self.loss_params = loss_params
+        self.max_epoch = max_epoch
+        self.num_batches_per_epoch = num_batches_per_epoch
+        self.gamma = gamma
+        self.lr = lr
+        self.param_index = param_index
+
+    def forward(self, *args, **kwargs):
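+        # Delegate straight to the wrapped NNUEModel.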
+        return self.model(*args, **kwargs)
+
+    def step_(self, batch: tuple[Tensor, ...], batch_idx, loss_type):
         _ = batch_idx  # unused, but required by pytorch-lightning

         # We clip weights at the start of each step. This means that after
         # the last step the weights might be outside of the desired range.
         # They should also be clipped accordingly in the serializer.
-        self._clip_weights()
+        self.model._clip_weights()

         (
             us,
@@ -392,7 +404,7 @@ def step_(self, batch: Tuple[Tensor, ...], batch_idx, loss_type):
         ) = batch

         scorenet = (
-            self(
+            self.model(
                 us,
                 them,
                 white_indices,
@@ -402,7 +414,7 @@ def step_(self, batch: Tuple[Tensor, ...], batch_idx, loss_type):
                 psqt_indices,
                 layer_stack_indices,
             )
-            * self.nnue2score
+            * self.model.nnue2score
         )

         p = self.loss_params
@@ -445,15 +457,15 @@ def test_step(self, batch, batch_idx):
     def configure_optimizers(self):
         LR = self.lr
         train_params = [
-            {"params": get_parameters([self.input]), "lr": LR, "gc_dim": 0},
-            {"params": [self.layer_stacks.l1_fact.weight], "lr": LR},
-            {"params": [self.layer_stacks.l1_fact.bias], "lr": LR},
-            {"params": [self.layer_stacks.l1.weight], "lr": LR},
-            {"params": [self.layer_stacks.l1.bias], "lr": LR},
-            {"params": [self.layer_stacks.l2.weight], "lr": LR},
-            {"params": [self.layer_stacks.l2.bias], "lr": LR},
-            {"params": [self.layer_stacks.output.weight], "lr": LR},
-            {"params": [self.layer_stacks.output.bias], "lr": LR},
+            {"params": get_parameters([self.model.input]), "lr": LR, "gc_dim": 0},
+            {"params": [self.model.layer_stacks.l1_fact.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1_fact.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.output.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.output.bias], "lr": LR},
         ]

         optimizer = ranger21.Ranger21(
@@ -479,7 +491,7 @@ def configure_optimizers(self):
         return [optimizer], [scheduler]


-def coalesce_ft_weights(model: NNUE, layer: BaseFeatureTransformerSlice):
+def coalesce_ft_weights(model: NNUEModel, layer: BaseFeatureTransformerSlice):
     weight = layer.weight.data
     indices = model.feature_set.get_virtual_to_real_features_gather_indices()
     weight_coalesced = weight.new_zeros(
@@ -492,5 +504,5 @@ def coalesce_ft_weights(model: NNUE, layer: BaseFeatureTransformerSlice):
     return weight_coalesced


-def get_parameters(layers: List[nn.Module]):
+def get_parameters(layers: list[nn.Module]):
     return [p for layer in layers for p in layer.parameters()]
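
A minimal sketch of how the split classes compose after this change. The feature-set lookup helper and the feature-set name below are illustrative assumptions, not part of this diff; the class names and constructor signatures follow the code above.

```python
# Hypothetical usage sketch (names marked "assumed" are not from this PR).
from features import get_feature_set_from_name  # assumed helper

feature_set = get_feature_set_from_name("HalfKAv2_hm^")  # assumed name

# Plain nn.Module: just the network, usable without Lightning
# (e.g. for serialization or standalone evaluation).
net = NNUEModel(feature_set, num_psqt_buckets=8, num_ls_buckets=8)

# LightningModule wrapper: owns an NNUEModel plus training hyperparameters.
module = NNUE(feature_set, lr=8.75e-4, gamma=0.992)
assert isinstance(module.model, NNUEModel)
```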