Commit 00bdf75

Move NNUE-related code to model (official-stockfish#351)

* move features to model
* move NNUEReader, NNUEWriter, and load_model to model.utils
* fix
* fix
* ruff format
* revert serialize renaming
* remove FeatureSet dependency in data_loader
* fix
* revert description changes
* fix nnue reader
* .ckpt

1 parent a7beb37 commit 00bdf75

23 files changed: +472 −471 lines

cross_check_eval.py

Lines changed: 12 additions & 8 deletions

```diff
@@ -4,16 +4,20 @@
 
 import chess
 
-import features
-import serialize
 import data_loader
-from model import NNUE, ModelConfig
-from features import FeatureSet
+from model import (
+    add_feature_args,
+    FeatureSet,
+    get_feature_set_from_name,
+    NNUE,
+    NNUEReader,
+    ModelConfig,
+)
 
 
 def read_model(nnue_path, feature_set: FeatureSet, config: ModelConfig):
     with open(nnue_path, "rb") as f:
-        reader = serialize.NNUEReader(f, feature_set, config)
+        reader = NNUEReader(f, feature_set, config)
     return reader.model
 
 
@@ -165,12 +169,12 @@ def main():
         "--count", type=int, default=100, help="number of datapoints to process"
     )
     parser.add_argument("--l1", type=int, default=ModelConfig().L1)
-    features.add_argparse_args(parser)
+    add_feature_args(parser)
     args = parser.parse_args()
 
     batch_size = 1000
 
-    feature_set = features.get_feature_set_from_name(args.features)
+    feature_set = get_feature_set_from_name(args.features)
     if args.checkpoint:
         model = NNUE.load_from_checkpoint(
             args.checkpoint, feature_set=feature_set, config=ModelConfig(L1=args.l1)
@@ -189,7 +193,7 @@ def main():
         fens = filter_fens(next(fen_batch_provider))
 
         b = data_loader.get_sparse_batch_from_fens(
-            feature_set, fens, [0] * len(fens), [1] * len(fens), [0] * len(fens)
+            feature_set.name, fens, [0] * len(fens), [1] * len(fens), [0] * len(fens)
         )
         model_evals += eval_model_batch(model, b)
         data_loader.destroy_sparse_batch(b)
```
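
With `serialize` and the top-level `features` module folded into the `model` package, scripts like this one resolve every NNUE-related name from a single import. A minimal sketch of the new call pattern, mirroring `read_model` above (the .nnue path is a placeholder):

```python
from model import ModelConfig, NNUEReader, get_feature_set_from_name

feature_set = get_feature_set_from_name("HalfKAv2_hm^")  # default feature set name
with open("nn.nnue", "rb") as f:  # placeholder path
    model = NNUEReader(f, feature_set, ModelConfig()).model
```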

data_loader/stream.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -2,7 +2,6 @@
 
 from ._native import c_lib, SparseBatchPtr, FenBatchPtr
 from .config import CDataloaderSkipConfig, DataloaderSkipConfig
-from features import FeatureSet
 
 
 def _to_c_str_array(str_list):
@@ -64,15 +63,15 @@ def destroy_sparse_batch_stream(stream: ctypes.c_void_p):
 
 
 def get_sparse_batch_from_fens(
-    feature_set: FeatureSet, fens, scores, plies, results
+    feature_set: str, fens, scores, plies, results
 ) -> SparseBatchPtr:
     assert len(fens) == len(scores) == len(plies) == len(results)
 
     def to_c_int_array(data):
         return (ctypes.c_int * len(data))(*data)
 
     return c_lib.dll.get_sparse_batch_from_fens(
-        feature_set.name.encode("utf-8"),
+        feature_set.encode("utf-8"),
         len(fens),
         _to_c_str_array(fens),
         to_c_int_array(scores),
```
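
The new signature removes `data_loader`'s last dependency on the features code: callers hand over the feature set's name as a plain `str`, and the function forwards the same UTF-8 bytes to the C library as before. A hedged sketch of the old versus new call styles (single-FEN batch for brevity):

```python
import data_loader
from model import get_feature_set_from_name

feature_set = get_feature_set_from_name("HalfKAv2_hm^")
fens = ["rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"]

# Before this commit, callers passed the FeatureSet object:
#   data_loader.get_sparse_batch_from_fens(feature_set, fens, [0], [1], [0])
# Now only the name crosses the module boundary:
b = data_loader.get_sparse_batch_from_fens(feature_set.name, fens, [0], [1], [0])
data_loader.destroy_sparse_batch(b)
```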

docs/features.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -12,4 +12,4 @@ The current semantics are as follows:
 2. When resuming training from a .pt model - `--features` specifies the feature set to use for learning. If the feature set specified doesn't match the feature set from the .pt model a conversion is attempted. Right now only a conversion of feature set with a single block from non-factorized to factorized is supported. The factorized block must have the non-factorized features as the first factor. The virtual feature weights are initialized to zero.
 3. When converting .ckpt to .nnue - `--features` specifies the features as stored in the .ckpt file. The user must pass the correct feature set through `--features` because it can't be inferred from the .ckpt. If the features from `--features` and the saved model don't match it'll likely stack trace on some dimension mismatch.
 4. When converting .pt to .nnue - `--features` is ignored, the `feature_set` from the saved model is used, the weights are coalesced when writing the .nnue file.
-5. When converting .nnue to .pt - `--features` specifies the features in the .nnue file. The resulting .pt model has the same feature_set. Note that when resuming training this model can be converted to a compatible feature_set, see point 2.
\ No newline at end of file
+5. When converting .nnue to .pt - `--features` specifies the features in the .nnue file. The resulting .pt model has the same feature_set. Note that when resuming training this model can be converted to a compatible feature_set, see point 2.
```

ftperm.py

Lines changed: 5 additions & 8 deletions

```diff
@@ -33,8 +33,6 @@
 
 import time
 import argparse
-import features
-import serialize
 import chess
 import torch
 import copy
@@ -43,8 +41,7 @@
 
 import data_loader
 import model as M
-from model import NNUE, NNUEModel, ModelConfig
-from features import FeatureSet
+from model import FeatureSet, NNUE, NNUEModel, NNUEReader, ModelConfig
 
 
 """
@@ -398,7 +395,7 @@ def find_perm_impl(actmat, use_cupy, L1: int):
 
 def read_model(nnue_path, feature_set: FeatureSet, config: ModelConfig):
     with open(nnue_path, "rb") as f:
-        reader = serialize.NNUEReader(f, feature_set, config)
+        reader = NNUEReader(f, feature_set, config)
     return reader.model
 
 
@@ -531,7 +528,7 @@ def gather_impl(model, dataset, count):
         fens = filter_fens(next(fen_batch_provider))
 
         b = data_loader.get_sparse_batch_from_fens(
-            quantized_model.feature_set,
+            quantized_model.feature_set.name,
             fens,
             [0] * len(fens),
             [1] * len(fens),
@@ -549,7 +546,7 @@ def gather_impl(model, dataset, count):
 
 
 def command_gather(args):
-    feature_set = features.get_feature_set_from_name(args.features)
+    feature_set = M.get_feature_set_from_name(args.features)
     if args.checkpoint:
         model = NNUE.load_from_checkpoint(
             args.checkpoint, feature_set=feature_set, config=ModelConfig(L1=args.l1)
@@ -669,7 +666,7 @@ def main():
         "--out", type=str, help="Filename under which to save the resulting ft matrix"
     )
     parser_gather.add_argument("--l1", type=int, default=M.ModelConfig().L1)
-    features.add_argparse_args(parser_gather)
+    M.add_feature_args(parser_gather)
     parser_gather.set_defaults(func=command_gather)
 
     parser_gather = subparsers.add_parser("find_perm", help="a help")
```
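
Because `ftperm.py` already aliases the package as `M`, the renamed argparse helper is reached through that alias. A small sketch of the wiring, assuming only what the features diff below shows (`add_feature_args` registers a `--features` option with a default):

```python
import argparse
import model as M

parser = argparse.ArgumentParser()
M.add_feature_args(parser)  # registers --features
args = parser.parse_args(["--features", "HalfKAv2_hm^"])

# The parsed name resolves to a concrete FeatureSet object.
feature_set = M.get_feature_set_from_name(args.features)
```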

model/__init__.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -1,15 +1,22 @@
 from .callbacks import WeightClippingCallback
 from .config import ModelConfig, LossParams
+from .features import add_feature_args, FeatureSet, get_feature_set_from_name
 from .lightning_module import NNUE
 from .model import NNUEModel
-from .utils import coalesce_ft_weights
+from .utils import coalesce_ft_weights, load_model, NNUEReader, NNUEWriter
 
 
 __all__ = [
     "WeightClippingCallback",
     "ModelConfig",
     "LossParams",
+    "add_feature_args",
+    "FeatureSet",
+    "get_feature_set_from_name",
     "NNUE",
     "NNUEModel",
     "coalesce_ft_weights",
+    "load_model",
+    "NNUEReader",
+    "NNUEWriter",
 ]
```
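
These re-exports make `model` the single public entry point for the training scripts; the former top-level `features` and `serialize` imports disappear from call sites. A sketch restating that surface, assuming nothing beyond the `__all__` list above:

```python
# One package now covers features, the Lightning module, and serialization.
from model import (
    FeatureSet,
    NNUE,
    NNUEReader,
    NNUEWriter,
    add_feature_args,
    coalesce_ft_weights,
    get_feature_set_from_name,
    load_model,  # moved into model.utils by this commit
)
```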

Lines changed: 2 additions & 2 deletions

```diff
@@ -32,7 +32,7 @@ def get_feature_blocks_from_names(names):
     return [_feature_blocks_by_name[name] for name in names]
 
 
-def get_feature_set_from_name(name):
+def get_feature_set_from_name(name) -> FeatureSet:
     feature_block_names = name.split("+")
     blocks = get_feature_blocks_from_names(feature_block_names)
     return FeatureSet(blocks)
@@ -42,7 +42,7 @@ def get_available_feature_blocks_names():
     return list(iter(_feature_blocks_by_name))
 
 
-def add_argparse_args(parser):
+def add_feature_args(parser):
     _default_feature_set_name = "HalfKAv2_hm^"
     parser.add_argument(
         "--features",
```

Lines changed: 27 additions & 31 deletions

```diff
@@ -34,14 +34,13 @@ def __init__(self, features):
         )
         self.num_features = sum(feature.num_features for feature in features)
 
-    """
-    This method returns the feature ranges for the virtual factors of the
-    underlying feature blocks. This is useful to know during initialization,
-    when we want to zero initialize the virtual feature weights, but give some other
-    values to the real feature weights.
-    """
-
     def get_virtual_feature_ranges(self):
+        """
+        This method returns the feature ranges for the virtual factors of the
+        underlying feature blocks. This is useful to know during initialization,
+        when we want to zero initialize the virtual feature weights, but give some other
+        values to the real feature weights.
+        """
         ranges = []
         offset = 0
         for feature in self.features:
@@ -62,14 +61,13 @@ def get_real_feature_ranges(self):
 
         return ranges
 
-    """
-    This method goes over all of the feature blocks and gathers the active features.
-    Each block has its own index space assigned so the features from two different
-    blocks will never have the same index here. Basically the thing you would expect
-    to happen after concatenating many feature blocks.
-    """
-
     def get_active_features(self, board):
+        """
+        This method goes over all of the feature blocks and gathers the active features.
+        Each block has its own index space assigned so the features from two different
+        blocks will never have the same index here. Basically the thing you would expect
+        to happen after concatenating many feature blocks.
+        """
         w = torch.zeros(0)
         b = torch.zeros(0)
 
@@ -84,13 +82,12 @@
 
         return w, b
 
-    """
-    This method takes a feature idx and looks for the block that owns it.
-    If it found the block it asks it to factorize the index, otherwise
-    it throws and Exception. The idx must refer to a real feature.
-    """
-
     def get_feature_factors(self, idx):
+        """
+        This method takes a feature idx and looks for the block that owns it.
+        If it found the block it asks it to factorize the index, otherwise
+        it throws and Exception. The idx must refer to a real feature.
+        """
         offset = 0
         for feature in self.features:
             if idx < offset + feature.num_real_features:
@@ -99,18 +96,17 @@ def get_feature_factors(self, idx):
 
         raise Exception("No feature block to factorize {}".format(idx))
 
-    """
-    This method does what get_feature_factors does but for all
-    valid features at the same time. It returns a list of length
-    self.num_real_features with ith element being a list of factors
-    of the ith feature.
-    This method is technically redundant but it allows to perform the operation
-    slightly faster when there's many feature blocks. It might be worth
-    to add a similar method to the FeatureBlock itself - to make it faster
-    for feature blocks with many factors.
-    """
-
     def get_virtual_to_real_features_gather_indices(self):
+        """
+        This method does what get_feature_factors does but for all
+        valid features at the same time. It returns a list of length
+        self.num_real_features with ith element being a list of factors
+        of the ith feature.
+        This method is technically redundant but it allows to perform the operation
+        slightly faster when there's many feature blocks. It might be worth
+        to add a similar method to the FeatureBlock itself - to make it faster
+        for feature blocks with many factors.
+        """
         indices = []
         real_offset = 0
         offset = 0
```