@@ -43,8 +43,8 @@
 
 import data_loader
 import model as M
-from model import NNUE
-from features.feature_set import FeatureSet
+from model import NNUE, NNUEModel, ModelConfig
+from features import FeatureSet
 
 
 """
@@ -341,14 +341,14 @@ def make_swaps_3(actmat, use_cupy=True):
     return cycles, total_improvement
 
 
-def find_perm_impl(actmat, use_cupy):
+def find_perm_impl(actmat, use_cupy, L1: int):
     actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2))
     if use_cupy:
         actmat = cp.asarray(actmat, dtype=cp.int8)
     actmat_orig = actmat.copy()
 
     total_score_change = 0
-    perm = np.arange(M.L1 // 2)
+    perm = np.arange(L1 // 2)
 
     stages = [make_swaps_2, make_swaps_3]
     # The optimization routines are deterministic, so no need to retry.
@@ -396,9 +396,9 @@ def find_perm_impl(actmat, use_cupy):
 # -------------------------------------------------------------
 
 
-def read_model(nnue_path, feature_set: FeatureSet):
+def read_model(nnue_path, feature_set: FeatureSet, config: ModelConfig):
     with open(nnue_path, "rb") as f:
-        reader = serialize.NNUEReader(f, feature_set)
+        reader = serialize.NNUEReader(f, feature_set, config)
     return reader.model
 
 
@@ -441,12 +441,12 @@ def forward_ft(
     layer_stack_indices,
 ):
     wp, bp = model.input(white_indices, white_values, black_indices, black_values)
-    w, wpsqt = torch.split(wp, M.L1, dim=1)
-    b, bpsqt = torch.split(bp, M.L1, dim=1)
+    w, _ = torch.split(wp, model.L1, dim=1)
+    b, _ = torch.split(bp, model.L1, dim=1)
     l0_ = (us * torch.cat([w, b], dim=1)) + (them * torch.cat([b, w], dim=1))
     l0_ = torch.clamp(l0_, 0.0, 127.0)
 
-    l0_s = torch.split(l0_, M.L1 // 2, dim=1)
+    l0_s = torch.split(l0_, model.L1 // 2, dim=1)
     l0_s1 = [l0_s[0] * l0_s[1], l0_s[2] * l0_s[3]]
     # We multiply by 127/128 because in the quantized network 1.0 is represented by 127
     # and it's more efficient to divide by 128 instead.
@@ -551,9 +551,11 @@ def gather_impl(model, dataset, count):
 def command_gather(args):
     feature_set = features.get_feature_set_from_name(args.features)
     if args.checkpoint:
-        model = NNUE.load_from_checkpoint(args.checkpoint, feature_set=feature_set)
+        model = NNUE.load_from_checkpoint(
+            args.checkpoint, feature_set=feature_set, config=ModelConfig(L1=args.l1)
+        )
     else:
-        model = read_model(args.net, feature_set)
+        model = read_model(args.net, feature_set, ModelConfig(L1=args.l1))
 
     model.eval()
 
@@ -600,13 +602,13 @@ def command_find_perm(args):
 
     perm = find_perm_impl(actmat, args.use_cupy)
 
-    # perm = np.random.permutation([i for i in range(M.L1)])
+    # perm = np.random.permutation([i for i in range(L1)])
     with open(args.out, "wb") as file:
         np.save(file, perm)
 
 
 def ft_optimize(
-    model,
+    model: NNUEModel,
     dataset_path,
     count,
     actmat_save_path=None,
@@ -620,7 +622,7 @@ def ft_optimize(
             np.save(file, actmat)
 
     print("Finding permutation...")
-    perm = find_perm_impl(actmat, use_cupy)
+    perm = find_perm_impl(actmat, use_cupy, model.L1)
     if actmat_save_path is not None:
         with open(perm_save_path, "wb") as file:
             np.save(file, perm)
@@ -666,6 +668,7 @@ def main():
     parser_gather.add_argument(
         "--out", type=str, help="Filename under which to save the resulting ft matrix"
     )
+    parser_gather.add_argument("--l1", type=int, default=M.ModelConfig().L1)
     features.add_argparse_args(parser_gather)
     parser_gather.set_defaults(func=command_gather)
 
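
The permutation tooling no longer reads the feature-transformer width from the module-level constant M.L1; it now takes it from the model itself (model.L1), which is set through ModelConfig, and the gather subcommand gains a --l1 option defaulting to M.ModelConfig().L1. Below is a minimal usage sketch, not part of this commit: it assumes the file in this diff is importable as ftperm, that "HalfKAv2_hm" is an available feature-set name, and that "nn.nnue" exists on disk; it substitutes a random activation matrix for the output of gather_impl(), and L1=3072 is an illustrative value rather than a project default.

# Minimal usage sketch (assumptions noted in the comments below).
import numpy as np

import features
import ftperm  # assumption: module name of the file changed in this diff
from model import ModelConfig

feature_set = features.get_feature_set_from_name("HalfKAv2_hm")  # assumed feature-set name
config = ModelConfig(L1=3072)  # illustrative width; must match the serialized network

# read_model() now requires the ModelConfig so NNUEReader knows the layer sizes.
model = ftperm.read_model("nn.nnue", feature_set, config)  # "nn.nnue" is a placeholder path
model.eval()

# Stand-in for gather_impl(): a fake activation matrix with one row per
# sample and model.L1 columns, values in the quantized 0..127 range.
actmat = np.random.randint(0, 128, size=(4096, model.L1), dtype=np.int8)

# find_perm_impl() now takes L1 explicitly instead of reading M.L1.
perm = ftperm.find_perm_impl(actmat, use_cupy=False, L1=model.L1)
np.save("ft_perm.npy", perm)

Passing L1 explicitly lets the same process handle networks of different widths, which a module-level constant could not express; on the command line, the gather subcommand picks the width up from the new --l1 option.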