Commit 9a6cfde

Modularize NNUE (official-stockfish#349)
* move weight clipping into a callback
* modularize
* fix argparse
* clean up imports
* ruff format
* fix
* a
* refactor model configuration
* fix load from checkpoint
* fix
* ruff format
* fix reader model
* fix
* suggestions
* fix
1 parent 0ef8846 commit 9a6cfde

15 files changed: +383 additions, -308 deletions

cross_check_eval.py

Lines changed: 9 additions & 6 deletions
@@ -7,13 +7,13 @@
 import features
 import serialize
 import data_loader
-from model import NNUE
-from features.feature_set import FeatureSet
+from model import NNUE, ModelConfig
+from features import FeatureSet


-def read_model(nnue_path, feature_set: FeatureSet):
+def read_model(nnue_path, feature_set: FeatureSet, config: ModelConfig):
     with open(nnue_path, "rb") as f:
-        reader = serialize.NNUEReader(f, feature_set)
+        reader = serialize.NNUEReader(f, feature_set, config)
     return reader.model


@@ -164,16 +164,19 @@ def main():
     parser.add_argument(
         "--count", type=int, default=100, help="number of datapoints to process"
     )
+    parser.add_argument("--l1", type=int, default=ModelConfig().L1)
     features.add_argparse_args(parser)
     args = parser.parse_args()

     batch_size = 1000

     feature_set = features.get_feature_set_from_name(args.features)
     if args.checkpoint:
-        model = NNUE.load_from_checkpoint(args.checkpoint, feature_set=feature_set)
+        model = NNUE.load_from_checkpoint(
+            args.checkpoint, feature_set=feature_set, config=ModelConfig(L1=args.l1)
+        )
     else:
-        model = read_model(args.net, feature_set)
+        model = read_model(args.net, feature_set, ModelConfig(L1=args.l1))
     model.eval()
     fen_batch_provider = make_fen_batch_provider(args.data, batch_size)
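Note: with the global `--l1` argparse action removed (see features/__init__.py below), the network width now travels explicitly with the model. A minimal sketch of the new loading path; the net filename and feature-set name are illustrative placeholders, not part of this commit:

    import features
    from model import ModelConfig

    # "nn.nnue" and "HalfKAv2_hm" are illustrative placeholders.
    feature_set = features.get_feature_set_from_name("HalfKAv2_hm")
    model = read_model("nn.nnue", feature_set, ModelConfig(L1=3072))
    model.eval()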
179182

data_loader/stream.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 from ._native import c_lib, SparseBatchPtr, FenBatchPtr
 from .config import CDataloaderSkipConfig, DataloaderSkipConfig
-from features.feature_set import FeatureSet
+from features import FeatureSet


 def _to_c_str_array(str_list):

features/__init__.py

Lines changed: 0 additions & 9 deletions
@@ -1,7 +1,4 @@
-import argparse
-
 from .feature_set import FeatureSet
-import model as M

 """
 Each module that defines feature blocks must be imported here and
@@ -16,11 +13,6 @@
 _feature_blocks_by_name = dict()


-class SetNetworkSize(argparse.Action):
-    def __call__(self, parser, namespace, values, option_string=None):
-        M.L1 = int(values)
-
-
 def _add_feature_block(feature_block_cls):
     feature_block = feature_block_cls()
     _feature_blocks_by_name[feature_block.name] = feature_block
@@ -59,7 +51,6 @@ def add_argparse_args(parser):
         help='The feature set to use. Can be a union of feature blocks (for example P+HalfKP). "^" denotes a factorized block. Currently available feature blocks are: '
         + ", ".join(get_available_feature_blocks_names()),
     )
-    parser.add_argument("--l1", type=int, default=M.L1, action=SetNetworkSize)


 def _init():

ftperm.py

Lines changed: 17 additions & 14 deletions
@@ -43,8 +43,8 @@

 import data_loader
 import model as M
-from model import NNUE
-from features.feature_set import FeatureSet
+from model import NNUE, NNUEModel, ModelConfig
+from features import FeatureSet


 """
@@ -341,14 +341,14 @@ def make_swaps_3(actmat, use_cupy=True):
     return cycles, total_improvement


-def find_perm_impl(actmat, use_cupy):
+def find_perm_impl(actmat, use_cupy, L1: int):
     actmat = np.reshape(actmat, (actmat.shape[0] * 2, actmat.shape[1] // 2))
     if use_cupy:
         actmat = cp.asarray(actmat, dtype=cp.int8)
     actmat_orig = actmat.copy()

     total_score_change = 0
-    perm = np.arange(M.L1 // 2)
+    perm = np.arange(L1 // 2)

     stages = [make_swaps_2, make_swaps_3]
     # The optimization routines are deterministic, so no need to retry.
@@ -396,9 +396,9 @@ def find_perm_impl(actmat, use_cupy):
 # -------------------------------------------------------------


-def read_model(nnue_path, feature_set: FeatureSet):
+def read_model(nnue_path, feature_set: FeatureSet, config: ModelConfig):
     with open(nnue_path, "rb") as f:
-        reader = serialize.NNUEReader(f, feature_set)
+        reader = serialize.NNUEReader(f, feature_set, config)
     return reader.model


@@ -441,12 +441,12 @@ def forward_ft(
     layer_stack_indices,
 ):
     wp, bp = model.input(white_indices, white_values, black_indices, black_values)
-    w, wpsqt = torch.split(wp, M.L1, dim=1)
-    b, bpsqt = torch.split(bp, M.L1, dim=1)
+    w, _ = torch.split(wp, model.L1, dim=1)
+    b, _ = torch.split(bp, model.L1, dim=1)
     l0_ = (us * torch.cat([w, b], dim=1)) + (them * torch.cat([b, w], dim=1))
     l0_ = torch.clamp(l0_, 0.0, 127.0)

-    l0_s = torch.split(l0_, M.L1 // 2, dim=1)
+    l0_s = torch.split(l0_, model.L1 // 2, dim=1)
     l0_s1 = [l0_s[0] * l0_s[1], l0_s[2] * l0_s[3]]
     # We multiply by 127/128 because in the quantized network 1.0 is represented by 127
     # and it's more efficient to divide by 128 instead.
@@ -551,9 +551,11 @@ def gather_impl(model, dataset, count):
 def command_gather(args):
     feature_set = features.get_feature_set_from_name(args.features)
     if args.checkpoint:
-        model = NNUE.load_from_checkpoint(args.checkpoint, feature_set=feature_set)
+        model = NNUE.load_from_checkpoint(
+            args.checkpoint, feature_set=feature_set, config=ModelConfig(L1=args.l1)
+        )
     else:
-        model = read_model(args.net, feature_set)
+        model = read_model(args.net, feature_set, ModelConfig(L1=args.l1))

     model.eval()

@@ -600,13 +602,13 @@ def command_find_perm(args):

     perm = find_perm_impl(actmat, args.use_cupy)

-    # perm = np.random.permutation([i for i in range(M.L1)])
+    # perm = np.random.permutation([i for i in range(L1)])
     with open(args.out, "wb") as file:
         np.save(file, perm)


 def ft_optimize(
-    model,
+    model: NNUEModel,
     dataset_path,
     count,
     actmat_save_path=None,
@@ -620,7 +622,7 @@
             np.save(file, actmat)

     print("Finding permutation...")
-    perm = find_perm_impl(actmat, use_cupy)
+    perm = find_perm_impl(actmat, use_cupy, model.L1)
     if actmat_save_path is not None:
         with open(perm_save_path, "wb") as file:
             np.save(file, perm)
@@ -666,6 +668,7 @@ def main():
     parser_gather.add_argument(
         "--out", type=str, help="Filename under which to save the resulting ft matrix"
     )
+    parser_gather.add_argument("--l1", type=int, default=M.ModelConfig().L1)
     features.add_argparse_args(parser_gather)
     parser_gather.set_defaults(func=command_gather)


model/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+from .callbacks import WeightClippingCallback
+from .config import ModelConfig, LossParams
+from .lightning_module import NNUE
+from .model import NNUEModel
+from .utils import coalesce_ft_weights
+
+
+__all__ = [
+    "WeightClippingCallback",
+    "ModelConfig",
+    "LossParams",
+    "NNUE",
+    "NNUEModel",
+    "coalesce_ft_weights",
+]
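The package root now re-exports the public surface, so callers import from `model` directly instead of reaching into submodules, as the updated scripts above do:

    # Equivalent to the imports used in cross_check_eval.py and ftperm.py.
    from model import NNUE, NNUEModel, ModelConfig, coalesce_ft_weights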

model/callbacks.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import lightning as L
+
+
+class WeightClippingCallback(L.Callback):
+    def on_train_batch_start(
+        self,
+        trainer: L.Trainer,
+        pl_module: L.LightningModule,
+        batch,
+        batch_idx: int,
+    ) -> None:
+        pl_module.model.clip_weights()
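Weight clipping is now applied by a callback before every training batch rather than being hard-wired into the module, so it is opt-in per trainer. A minimal wiring sketch; `feature_set` and `train_loader` are assumed to exist and are illustrative:

    import lightning as L
    from model import NNUE, ModelConfig, WeightClippingCallback

    # Clip quantization-sensitive weights at the start of each training batch.
    module = NNUE(feature_set, ModelConfig())
    trainer = L.Trainer(max_epochs=800, callbacks=[WeightClippingCallback()])
    trainer.fit(module, train_dataloaders=train_loader)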

model/config.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+
+
+# 3 layer fully connected network
+@dataclass
+class ModelConfig:
+    L1: int = 3072
+    L2: int = 15
+    L3: int = 32
+
+
+# parameters needed for the definition of the loss
+@dataclass
+class LossParams:
+    in_offset: float = 270
+    out_offset: float = 270
+    in_scaling: float = 340
+    out_scaling: float = 380
+    start_lambda: float = 1.0
+    end_lambda: float = 1.0
+    pow_exp: float = 2.5
+    qp_asymmetry: float = 0.0
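Since both are plain dataclasses, variants are cheap to construct and compare:

    from dataclasses import replace
    from model import ModelConfig, LossParams

    small = ModelConfig(L1=128)                    # override just the FT width
    tuned = replace(LossParams(), end_lambda=0.7)  # vary a single loss knob
    assert ModelConfig() == ModelConfig(L1=3072, L2=15, L3=32)  # value semantics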

model/lightning_module.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+import lightning as L
+import ranger21
+import torch
+from torch import Tensor
+
+from features import FeatureSet
+from .config import LossParams, ModelConfig
+from .model import NNUEModel
+from .utils import get_parameters
+
+
+class NNUE(L.LightningModule):
+    """
+    feature_set - an instance of FeatureSet defining the input features
+
+    lambda_ = 0.0 - purely based on game results
+    0.0 < lambda_ < 1.0 - interpolated score and result
+    lambda_ = 1.0 - purely based on search scores
+
+    gamma - the multiplicative factor applied to the learning rate after each epoch
+
+    lr - the initial learning rate
+    """
+
+    def __init__(
+        self,
+        feature_set: FeatureSet,
+        config: ModelConfig,
+        max_epoch=800,
+        num_batches_per_epoch=int(100_000_000 / 16384),
+        gamma=0.992,
+        lr=8.75e-4,
+        param_index=0,
+        num_psqt_buckets=8,
+        num_ls_buckets=8,
+        loss_params=LossParams(),
+    ):
+        super().__init__()
+        self.model: NNUEModel = NNUEModel(
+            feature_set, config, num_psqt_buckets, num_ls_buckets
+        )
+        self.loss_params = loss_params
+        self.max_epoch = max_epoch
+        self.num_batches_per_epoch = num_batches_per_epoch
+        self.gamma = gamma
+        self.lr = lr
+        self.param_index = param_index
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def step_(self, batch: tuple[Tensor, ...], batch_idx, loss_type):
+        _ = batch_idx  # unused, but required by pytorch-lightning
+
+        (
+            us,
+            them,
+            white_indices,
+            white_values,
+            black_indices,
+            black_values,
+            outcome,
+            score,
+            psqt_indices,
+            layer_stack_indices,
+        ) = batch
+
+        scorenet = (
+            self.model(
+                us,
+                them,
+                white_indices,
+                white_values,
+                black_indices,
+                black_values,
+                psqt_indices,
+                layer_stack_indices,
+            )
+            * self.model.nnue2score
+        )
+
+        p = self.loss_params
+        # convert the network and search scores to an estimate match result
+        # based on the win_rate_model, with scalings and offsets optimized
+        q = (scorenet - p.in_offset) / p.in_scaling
+        qm = (-scorenet - p.in_offset) / p.in_scaling
+        qf = 0.5 * (1.0 + q.sigmoid() - qm.sigmoid())
+
+        s = (score - p.out_offset) / p.out_scaling
+        sm = (-score - p.out_offset) / p.out_scaling
+        pf = 0.5 * (1.0 + s.sigmoid() - sm.sigmoid())
+
+        # blend that eval based score with the actual game outcome
+        t = outcome
+        actual_lambda = p.start_lambda + (p.end_lambda - p.start_lambda) * (
+            self.current_epoch / self.max_epoch
+        )
+        pt = pf * actual_lambda + t * (1.0 - actual_lambda)
+
+        # use a MSE-like loss function
+        loss = torch.pow(torch.abs(pt - qf), p.pow_exp)
+        if p.qp_asymmetry != 0.0:
+            loss = loss * ((qf > pt) * p.qp_asymmetry + 1)
+        loss = loss.mean()
+
+        self.log(loss_type, loss)
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        return self.step_(batch, batch_idx, "train_loss")
+
+    def validation_step(self, batch, batch_idx):
+        self.step_(batch, batch_idx, "val_loss")
+
+    def test_step(self, batch, batch_idx):
+        self.step_(batch, batch_idx, "test_loss")
+
+    def configure_optimizers(self):
+        LR = self.lr
+        train_params = [
+            {"params": get_parameters([self.model.input]), "lr": LR, "gc_dim": 0},
+            {"params": [self.model.layer_stacks.l1_fact.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1_fact.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.output.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.output.bias], "lr": LR},
+        ]
+
+        optimizer = ranger21.Ranger21(
+            train_params,
+            lr=1.0,
+            betas=(0.9, 0.999),
+            eps=1.0e-7,
+            using_gc=False,
+            using_normgc=False,
+            weight_decay=0.0,
+            num_batches_per_epoch=self.num_batches_per_epoch,
+            num_epochs=self.max_epoch,
+            warmdown_active=False,
+            use_warmup=False,
+            use_adaptive_gradient_clipping=False,
+            softplus=False,
+            pnm_momentum_factor=0.0,
+        )
+
+        scheduler = torch.optim.lr_scheduler.StepLR(
+            optimizer, step_size=1, gamma=self.gamma
+        )
+
+        return [optimizer], [scheduler]
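For reference, the loss assembled in `step_` can be written compactly, with σ the logistic sigmoid, ŝ the network output scaled by nnue2score, s the search score, t the game outcome, and the offsets o, scalings c, exponent e = pow_exp, and asymmetry a = qp_asymmetry taken from LossParams:

\[
q_f = \tfrac{1}{2}\left(1 + \sigma\!\left(\frac{\hat{s} - o_{\mathrm{in}}}{c_{\mathrm{in}}}\right) - \sigma\!\left(\frac{-\hat{s} - o_{\mathrm{in}}}{c_{\mathrm{in}}}\right)\right),\qquad
p_f = \tfrac{1}{2}\left(1 + \sigma\!\left(\frac{s - o_{\mathrm{out}}}{c_{\mathrm{out}}}\right) - \sigma\!\left(\frac{-s - o_{\mathrm{out}}}{c_{\mathrm{out}}}\right)\right)
\]
\[
\lambda = \lambda_{\mathrm{start}} + (\lambda_{\mathrm{end}} - \lambda_{\mathrm{start}})\,\frac{\mathrm{epoch}}{\mathrm{max\_epoch}},\qquad
p_t = \lambda\,p_f + (1 - \lambda)\,t,\qquad
\mathcal{L} = \mathrm{mean}\!\left(\lvert p_t - q_f\rvert^{e}\,\bigl(1 + a\,\mathbb{1}[q_f > p_t]\bigr)\right)
\]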
