
 """

-import time
 import argparse
-import chess
-import torch
 import copy
+from dataclasses import dataclass
+import time
+from typing import Callable, Generator, TypeAlias
+
+import chess
 import cupy as cp
 import numpy as np
+import numpy.typing as npt
+import torch

 import data_loader
 import model as M
@@ -61,7 +65,7 @@
 VERBOSE = False


-def batched(arr, batch_size):
+def batched(arr: npt.NDArray, batch_size: int) -> Generator[npt.NDArray, None, None]:
     """
     Utility generator that yields chunks of array `arr` of size `batch_size`
     Expects arr to be a numpy-like array
@@ -73,14 +77,14 @@ def batched(arr, batch_size):
         idx += batch_size


-def apply_swap(perm, i, j):
+def apply_swap(perm: npt.NDArray, i: int, j: int) -> None:
     """
     Swap `i`-th and `j`-th elements in the array `perm`.
     """
     perm[i], perm[j] = perm[j], perm[i]


-def apply_rotate_right(perm, indices):
+def apply_rotate_right(perm: npt.NDArray, indices: tuple[int, ...]) -> None:
     """
     Rotates right the values in `perm` at selected indices `indices`.
     The rotation is performed as-if the selected indices were laid out in the order
@@ -92,7 +96,9 @@ def apply_rotate_right(perm, indices):
         perm[i] = j


-def get_swapped_zero_positive_count(actmat_flat, use_cupy=True):
+def get_swapped_zero_positive_count(
+    actmat_flat: npt.NDArray[np.bool_], use_cupy: bool = True
+) -> int:
     if use_cupy:
         actmat_flat = cp.asarray(actmat_flat, dtype=cp.int8)

@@ -148,7 +154,9 @@ def get_swapped_zero_positive_count(actmat_flat, use_cupy=True):
     return swapped_zero_count


-def get_swapped_zero_increase(actmat, use_cupy=True):
+def get_swapped_zero_increase(
+    actmat: npt.NDArray[np.bool_], use_cupy: bool = True
+) -> npt.NDArray[np.int_]:
     n_neurons = actmat.shape[1]
     swapped_zero_count = 0

@@ -179,7 +187,9 @@ def get_swapped_zero_increase(actmat, use_cupy=True):
     return swapped_zero_increase


-def get_score_change(actmat, use_cupy=True):
+def get_score_change(
+    actmat: npt.NDArray[np.bool_], use_cupy: bool = True
+) -> npt.NDArray[np.int_]:
     # actmat is a boolean matrix of shape (N, L1) with "True" meaning 0

     n_neurons = actmat.shape[1]
@@ -193,7 +203,16 @@ def get_score_change(actmat, use_cupy=True):
     return score_change


-def make_swaps_2(actmat, use_cupy=True):
+@dataclass
+class SwapResult:
+    swaps: list[tuple[int, ...]]
+    score_change: float
+
+
+SwapFunction: TypeAlias = Callable[[npt.NDArray[np.bool_], bool], SwapResult]
+
+
+def make_swaps_2(actmat: npt.NDArray[np.bool_], use_cupy: bool = True) -> SwapResult:
     """
     Returns a series of independent 2-swap operations that collectively improve the objective function.
     """
@@ -210,7 +229,7 @@ def make_swaps_2(actmat, use_cupy=True):
     # Sum score_change[i, j] + score_change[j, i] to get the cumulative impact of the swap.
     score_change = score_change + score_change.T

-    def all_indices_in_same_block(i):
+    def all_indices_in_same_block(i: int) -> list[int]:
         """Returns a list of indices of all neurons in the same block as the i-th neuron."""
         # Floor to the start of the block.
         base = i // ZERO_BLOCK_SIZE * ZERO_BLOCK_SIZE
@@ -248,10 +267,10 @@ def all_indices_in_same_block(i):
     print(f"Time elapsed: {time.time() - start_time:0.3f}")
     print(f"Improvement this iteration: {total_improvement:0.3f}")

-    return swaps, total_improvement
+    return SwapResult(swaps, total_improvement)


-def make_swaps_3(actmat, use_cupy=True):
+def make_swaps_3(actmat: npt.NDArray[np.bool_], use_cupy: bool = True) -> SwapResult:
     """
     Returns a series of independent left-rotate operations that collectively improve the objective function.
     """
@@ -340,10 +359,12 @@ def make_swaps_3(actmat, use_cupy=True):
     total_improvement = total_score_change / n_samples / (n_neurons // 4) * 100
     print(f"Time elapsed: {time.time() - start_time:0.3f}")
     print(f"Improvement this iteration: {total_improvement:0.3f}")
-    return cycles, total_improvement
+    return SwapResult(cycles, total_improvement)


-def find_perm_impl(actmat, use_cupy, L1: int):
+def find_perm_impl(
+    actmat: npt.NDArray[np.bool_], use_cupy: bool, L1: int
+) -> npt.NDArray[np.int_]:
     actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2))
     if use_cupy:
         actmat = cp.asarray(actmat, dtype=cp.int8)
@@ -352,7 +373,7 @@ def find_perm_impl(actmat, use_cupy, L1: int):
     total_score_change = 0
     perm = np.arange(L1 // 2)

-    stages = [make_swaps_2, make_swaps_3]
+    stages: list[SwapFunction] = [make_swaps_2, make_swaps_3]
     # The optimization routines are deterministic, so no need to retry.
     stages_max_fails = [0, 0]
     stage_id = 0
@@ -370,15 +391,15 @@ def find_perm_impl(actmat, use_cupy, L1: int):

         # Calculate a set of independent right rotates (so swaps for 2 element case)
         # that when applied improve the objective function
-        swaps, score_change = swap_fn(actmat, use_cupy)
-        for cycle in swaps:
+        swap_result = swap_fn(actmat, use_cupy)
+        for cycle in swap_result.swaps:
             # Update the current best permutation with the newly found adjustments.
             apply_rotate_right(perm, cycle)

-        total_score_change += score_change
+        total_score_change += swap_result.score_change
         print(f"Total improvement: {total_score_change}\n")

-        if score_change == 0:
+        if swap_result.score_change == 0:
             num_fails += 1
             if num_fails > stages_max_fails[stage_id]:
                 num_fails = 0
@@ -399,17 +420,19 @@ def find_perm_impl(actmat, use_cupy, L1: int):


 def read_model(
-    nnue_path,
+    nnue_path: str,
     feature_set: FeatureSet,
     config: ModelConfig,
     quantize_config: QuantizationConfig,
-):
+) -> NNUEModel:
     with open(nnue_path, "rb") as f:
         reader = NNUEReader(f, feature_set, config, quantize_config)
     return reader.model


-def make_fen_batch_provider(data_path, batch_size):
+def make_fen_batch_provider(
+    data_path: str, batch_size: int
+) -> data_loader.FenBatchProvider:
     return data_loader.FenBatchProvider(
         data_path,
         True,
@@ -421,7 +444,7 @@ def make_fen_batch_provider(data_path, batch_size):
     )


-def filter_fens(fens):
+def filter_fens(fens: list[str]) -> list[str]:
     # We don't want fens where a king is in check, as these cannot be evaluated by the engine.
     filtered_fens = []
     for fen in fens:
@@ -431,7 +454,7 @@ def filter_fens(fens):
     return filtered_fens


-def quantize_ft(model: NNUEModel):
+def quantize_ft(model: NNUEModel) -> None:
     model.input.weight.data = model.input.weight.data.mul(
         model.quantization.quantized_one
     ).round()
@@ -441,16 +464,16 @@ def quantize_ft(model: NNUEModel):


 def forward_ft(
-    model,
-    us,
-    them,
-    white_indices,
-    white_values,
-    black_indices,
-    black_values,
-    psqt_indices,
-    layer_stack_indices,
-):
+    model: NNUEModel,
+    us: torch.Tensor,
+    them: torch.Tensor,
+    white_indices: torch.Tensor,
+    white_values: torch.Tensor,
+    black_indices: torch.Tensor,
+    black_values: torch.Tensor,
+    psqt_indices: torch.Tensor,
+    layer_stack_indices: torch.Tensor,
+) -> torch.Tensor:
     wp, bp = model.input(white_indices, white_values, black_indices, black_values)
     w, _ = torch.split(wp, model.L1, dim=1)
     b, _ = torch.split(bp, model.L1, dim=1)
@@ -466,7 +489,7 @@ def forward_ft(
     return l0_.round()


-def eval_ft(model, batch: data_loader.SparseBatchPtr):
+def eval_ft(model: NNUEModel, batch: data_loader.SparseBatchPtr) -> torch.Tensor:
     with torch.no_grad():
         (
             us,
@@ -494,8 +517,8 @@ def eval_ft(model, batch: data_loader.SparseBatchPtr):
         return res


-def ft_permute_impl(model, permutation):
-    permutation = list(permutation)
+def ft_permute_impl(model: NNUEModel, perm: npt.NDArray[np.int_]) -> None:
+    permutation = list(perm)

     l1_size = model.layer_stacks.l1.in_features
     if l1_size != len(permutation) * 2:
@@ -517,14 +540,14 @@ def ft_permute_impl(model, permutation):
     ]


-def ft_permute(model, ft_perm_path):
+def ft_permute(model: NNUEModel, ft_perm_path: str) -> None:
     with open(ft_perm_path, "rb") as f:
         permutation = np.load(f)

     ft_permute_impl(model, permutation)


-def gather_impl(model: NNUEModel, dataset, count):
+def gather_impl(model: NNUEModel, dataset: str, count: int) -> npt.NDArray[np.bool_]:
     ZERO_POINT = 0.0  # Vary this to check hypothetical forced larger truncation to zero
     BATCH_SIZE = 1000

@@ -559,7 +582,7 @@ def gather_impl(model: NNUEModel, dataset, count):
     return np.concatenate(actmats, axis=0)


-def command_gather(args):
+def command_gather(args: argparse.Namespace) -> None:
     feature_set = M.get_feature_set_from_name(args.features)
     if args.checkpoint:
         nnue = NNUE.load_from_checkpoint(
@@ -582,13 +605,15 @@ def command_gather(args):
         np.save(file, actmat)


-def eval_act_mat(actmat):
+def eval_act_mat(actmat: npt.NDArray[np.bool_]) -> float:
     actmat = actmat.reshape((actmat.shape[0], actmat.shape[1] // 4, 4))
     r = np.all(actmat, axis=2)
     return np.count_nonzero(r) / r.shape[0] / r.shape[1]


-def eval_perm_impl(actmat, perm=None):
+def eval_perm_impl(
+    actmat: npt.NDArray[np.bool_], perm: npt.NDArray[np.int_] | None = None
+) -> None:
     actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2))

     actmat_eval = eval_act_mat(actmat)
@@ -600,7 +625,7 @@ def eval_perm_impl(actmat, perm=None):
     print(f"Combined zeros in perm matrix: {perm_act_mat_eval * 100:0.6f}")


-def command_eval_perm(args):
+def command_eval_perm(args: argparse.Namespace) -> None:
     with open(args.data, "rb") as file:
         actmat = np.load(file)

@@ -613,7 +638,7 @@ def command_eval_perm(args):
     eval_perm_impl(actmat, perm)


-def command_find_perm(args):
+def command_find_perm(args: argparse.Namespace) -> None:
     with open(args.data, "rb") as file:
         actmat = np.load(file)

@@ -626,12 +651,12 @@ def command_find_perm(args):

 def ft_optimize(
     model: NNUEModel,
-    dataset_path,
-    count,
-    actmat_save_path=None,
-    perm_save_path=None,
-    use_cupy=True,
-):
+    dataset_path: str,
+    count: int,
+    actmat_save_path: str | None = None,
+    perm_save_path: str | None = None,
+    use_cupy: bool = True,
+) -> None:
     print("Gathering activation data...")
     actmat = gather_impl(model, dataset_path, count)
     if actmat_save_path is not None:
@@ -651,12 +676,12 @@ def ft_optimize(
     ft_permute_impl(model, perm)


-def set_cupy_device(device):
+def set_cupy_device(device: int | None) -> None:
     if device is not None:
         cp.cuda.runtime.setDevice(device)


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(description="")
     parser.add_argument(
         "--no-cupy",