
Commit 0ef8846

Split NNUEModel from NNUE (official-stockfish#347)
Ensure CUDA graphs still work at standard performance.
1 parent 5a68d22 commit 0ef8846

File tree: 8 files changed, +112 -114 lines


data_loader/dataset.py

Lines changed: 3 additions & 5 deletions
@@ -7,8 +7,6 @@
 from . import stream
 from .config import DataloaderSkipConfig
 
-from typing import List
-
 
 class FenBatchProvider:
     def __init__(
@@ -64,7 +62,7 @@ def __init__(
         destroy_stream,
         fetch_next,
         destroy_part,
-        filenames: List[str],
+        filenames: list[str],
         cyclic,
         num_workers,
         batch_size=None,
@@ -116,7 +114,7 @@ class SparseBatchProvider(TrainingDataProvider):
     def __init__(
         self,
         feature_set: str,
-        filenames: List[str],
+        filenames: list[str],
         batch_size,
         cyclic=True,
         num_workers=1,
@@ -140,7 +138,7 @@ class SparseBatchDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
         feature_set: str,
-        filenames: List[str],
+        filenames: list[str],
         batch_size,
         cyclic=True,
         num_workers=1,
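
These annotation changes (here and in the other files below) swap typing.List for the built-in generic list, which Python 3.9+ accepts directly in annotations (PEP 585), so the typing import can be dropped. A minimal sketch of the style:

    # Python >= 3.9: built-in containers work as generic type hints (PEP 585).
    def total_count(filenames: list[str]) -> int:  # hypothetical helper, for illustration only
        return len(filenames)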

data_loader/stream.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,6 @@
 from ._native import c_lib, SparseBatchPtr, FenBatchPtr
 from .config import CDataloaderSkipConfig, DataloaderSkipConfig
 from features.feature_set import FeatureSet
-from typing import List
 
 
 def _to_c_str_array(str_list):
@@ -14,7 +13,7 @@ def _to_c_str_array(str_list):
 
 def create_fen_batch_stream(
     concurrency,
-    filenames: List[str],
+    filenames: list[str],
     batch_size,
     cyclic,
     config: DataloaderSkipConfig,
@@ -44,7 +43,7 @@ def destroy_fen_batch(fen_batch: FenBatchPtr):
 def create_sparse_batch_stream(
     feature_set: str,
     concurrency,
-    filenames: List[str],
+    filenames: list[str],
     batch_size,
     cyclic,
     config: DataloaderSkipConfig,

ftperm.py

Lines changed: 0 additions & 2 deletions
@@ -517,8 +517,6 @@ def gather_impl(model, dataset, count):
     ZERO_POINT = 0.0  # Vary this to check hypothetical forced larger truncation to zero
     BATCH_SIZE = 1000
 
-    old_device = model.device
-
     quantized_model = copy.deepcopy(model)
     quantize_ft(quantized_model)
     quantized_model.cuda()
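
A plausible reason for this removal: after the split in model.py below, the object passed around here is a plain nn.Module, which (unlike a LightningModule) exposes no .device property, and the variable was unused anyway. If a bare module's device is ever needed, a common pattern is to read it off a parameter (a sketch, assuming the module has parameters and they all live on one device):

    import torch
    import torch.nn as nn

    def module_device(module: nn.Module) -> torch.device:
        # All parameters are assumed to live on a single device here.
        return next(module.parameters()).device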

model.py

Lines changed: 80 additions & 68 deletions
@@ -8,7 +8,6 @@
 import lightning as L
 from dataclasses import dataclass
 from features.feature_set import FeatureSet
-from typing import List, Tuple
 
 # 3 layer fully connected network
 L1 = 3072
@@ -43,10 +42,6 @@ def __init__(self, count: int):
         self.l2 = nn.Linear(L2 * 2, L3 * count)
         self.output = nn.Linear(L3, 1 * count)
 
-        # Cached helper tensor for choosing outputs by bucket indices.
-        # Initialized lazily in forward.
-        self.idx_offset = None
-
         self._init_layers()
 
     def _init_layers(self):
@@ -83,9 +78,14 @@ def _init_layers(self):
         self.output.bias = nn.Parameter(output_bias)
 
     def forward(self, x: Tensor, ls_indices: Tensor):
-        assert self.idx_offset is not None and self.idx_offset.shape[0] == x.shape[0]
+        idx_offset = torch.arange(
+            0,
+            x.shape[0] * self.count,
+            self.count,
+            device=x.device
+        )
 
-        indices = ls_indices.flatten() + self.idx_offset
+        indices = ls_indices.flatten() + idx_offset
 
         l1s_ = self.l1(x).reshape((-1, self.count, L2 + 1))
         l1f_ = self.l1_fact(x)
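
This is the CUDA-graphs-relevant change: the lazily cached idx_offset buffer was hidden state tied to a particular batch size and device, which a capture/replay setup has to keep in sync externally; recomputing it with a cheap torch.arange inside forward removes that state at negligible cost. The indexing trick itself, as a standalone sketch (shapes are illustrative, not the real network's):

    import torch

    batch_size, count, width = 4, 8, 16
    out_all = torch.randn(batch_size * count, width)     # `count` candidate rows per sample
    ls_indices = torch.randint(0, count, (batch_size,))  # chosen bucket per sample

    # Sample i's bucket b lives at flat row index i * count + b.
    idx_offset = torch.arange(0, batch_size * count, count, device=out_all.device)
    picked = out_all[ls_indices.flatten() + idx_offset]
    assert picked.shape == (batch_size, width)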
@@ -135,45 +135,22 @@ def get_coalesced_layer_stacks(self):
             yield l1, l2, output
 
 
-class NNUE(L.LightningModule):
-    """
-    feature_set - an instance of FeatureSet defining the input features
-
-    lambda_ = 0.0 - purely based on game results
-    0.0 < lambda_ < 1.0 - interpolated score and result
-    lambda_ = 1.0 - purely based on search scores
-
-    gamma - the multiplicative factor applied to the learning rate after each epoch
-
-    lr - the initial learning rate
-    """
-
+class NNUEModel(nn.Module):
     def __init__(
         self,
         feature_set: FeatureSet,
-        max_epoch=800,
-        num_batches_per_epoch=int(100_000_000 / 16384),
-        gamma=0.992,
-        lr=8.75e-4,
-        param_index=0,
-        num_psqt_buckets=8,
-        num_ls_buckets=8,
-        loss_params=LossParams(),
+        num_psqt_buckets: int = 8,
+        num_ls_buckets: int = 8,
     ):
-        super(NNUE, self).__init__()
+        super().__init__()
         self.num_psqt_buckets = num_psqt_buckets
         self.num_ls_buckets = num_ls_buckets
+
         self.input = DoubleFeatureTransformerSlice(
             feature_set.num_features, L1 + self.num_psqt_buckets
         )
         self.feature_set = feature_set
         self.layer_stacks = LayerStacks(self.num_ls_buckets)
-        self.loss_params = loss_params
-        self.max_epoch = max_epoch
-        self.num_batches_per_epoch = num_batches_per_epoch
-        self.gamma = gamma
-        self.lr = lr
-        self.param_index = param_index
 
         self.nnue2score = 600.0
         self.weight_scale_hidden = 64.0
@@ -205,22 +182,21 @@ def __init__(
 
         self._init_layers()
 
-    """
-    We zero all virtual feature weights because there's no need for them
-    to be initialized; they only aid the training of correlated features.
-    """
+    def _init_layers(self):
+        self._zero_virtual_feature_weights()
+        self._init_psqt()
 
     def _zero_virtual_feature_weights(self):
+        """
+        We zero all virtual feature weights because there's no need for them
+        to be initialized; they only aid the training of correlated features.
+        """
         weights = self.input.weight
         with torch.no_grad():
             for a, b in self.feature_set.get_virtual_feature_ranges():
                 weights[a:b, :] = 0.0
         self.input.weight = nn.Parameter(weights)
 
-    def _init_layers(self):
-        self._zero_virtual_feature_weights()
-        self._init_psqt()
-
     def _init_psqt(self):
         input_weights = self.input.weight
         input_bias = self.input.bias
@@ -251,12 +227,11 @@ def _init_psqt(self):
         self.input.weight = nn.Parameter(input_weights)
         self.input.bias = nn.Parameter(input_bias)
 
-    """
-    Clips the weights of the model based on the min/max values allowed
-    by the quantization scheme.
-    """
-
     def _clip_weights(self):
+        """
+        Clips the weights of the model based on the min/max values allowed
+        by the quantization scheme.
+        """
         for group in self.weight_clipping:
             for p in group["params"]:
                 if "min_weight" in group or "max_weight" in group:
@@ -287,12 +262,11 @@ def _clip_weights(self):
                     raise Exception("Not supported.")
                 p.data.copy_(p_data_fp32)
 
-    """
-    This method attempts to convert the model from using self.feature_set
-    to new_feature_set. Currently only works for adding virtual features.
-    """
-
     def set_feature_set(self, new_feature_set: FeatureSet):
+        """
+        This method attempts to convert the model from using self.feature_set
+        to new_feature_set. Currently only works for adding virtual features.
+        """
         if self.feature_set.name == new_feature_set.name:
             return

@@ -370,13 +344,51 @@ def forward(
 
         return x
 
-    def step_(self, batch: Tuple[Tensor, ...], batch_idx, loss_type):
+
+class NNUE(L.LightningModule):
+    """
+    feature_set - an instance of FeatureSet defining the input features
+
+    lambda_ = 0.0 - purely based on game results
+    0.0 < lambda_ < 1.0 - interpolated score and result
+    lambda_ = 1.0 - purely based on search scores
+
+    gamma - the multiplicative factor applied to the learning rate after each epoch
+
+    lr - the initial learning rate
+    """
+
+    def __init__(
+        self,
+        feature_set: FeatureSet,
+        max_epoch=800,
+        num_batches_per_epoch=int(100_000_000 / 16384),
+        gamma=0.992,
+        lr=8.75e-4,
+        param_index=0,
+        num_psqt_buckets=8,
+        num_ls_buckets=8,
+        loss_params=LossParams(),
+    ):
+        super().__init__()
+        self.model: NNUEModel = NNUEModel(feature_set, num_psqt_buckets, num_ls_buckets)
+        self.loss_params = loss_params
+        self.max_epoch = max_epoch
+        self.num_batches_per_epoch = num_batches_per_epoch
+        self.gamma = gamma
+        self.lr = lr
+        self.param_index = param_index
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def step_(self, batch: tuple[Tensor, ...], batch_idx, loss_type):
         _ = batch_idx  # unused, but required by pytorch-lightning
 
         # We clip weights at the start of each step. This means that after
         # the last step the weights might be outside of the desired range.
         # They should be also clipped accordingly in the serializer.
-        self._clip_weights()
+        self.model._clip_weights()
 
         (
             us,
@@ -392,7 +404,7 @@ def step_(self, batch: Tuple[Tensor, ...], batch_idx, loss_type):
         ) = batch
 
         scorenet = (
-            self(
+            self.model(
                 us,
                 them,
                 white_indices,
@@ -402,7 +414,7 @@ def step_(self, batch: Tuple[Tensor, ...], batch_idx, loss_type):
                 psqt_indices,
                 layer_stack_indices,
             )
-            * self.nnue2score
+            * self.model.nnue2score
         )
 
         p = self.loss_params
@@ -445,15 +457,15 @@ def test_step(self, batch, batch_idx):
     def configure_optimizers(self):
         LR = self.lr
         train_params = [
-            {"params": get_parameters([self.input]), "lr": LR, "gc_dim": 0},
-            {"params": [self.layer_stacks.l1_fact.weight], "lr": LR},
-            {"params": [self.layer_stacks.l1_fact.bias], "lr": LR},
-            {"params": [self.layer_stacks.l1.weight], "lr": LR},
-            {"params": [self.layer_stacks.l1.bias], "lr": LR},
-            {"params": [self.layer_stacks.l2.weight], "lr": LR},
-            {"params": [self.layer_stacks.l2.bias], "lr": LR},
-            {"params": [self.layer_stacks.output.weight], "lr": LR},
-            {"params": [self.layer_stacks.output.bias], "lr": LR},
+            {"params": get_parameters([self.model.input]), "lr": LR, "gc_dim": 0},
+            {"params": [self.model.layer_stacks.l1_fact.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1_fact.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l1.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.l2.bias], "lr": LR},
+            {"params": [self.model.layer_stacks.output.weight], "lr": LR},
+            {"params": [self.model.layer_stacks.output.bias], "lr": LR},
         ]
 
         optimizer = ranger21.Ranger21(
@@ -479,7 +491,7 @@ def configure_optimizers(self):
         return [optimizer], [scheduler]
 
 
-def coalesce_ft_weights(model: NNUE, layer: BaseFeatureTransformerSlice):
+def coalesce_ft_weights(model: NNUEModel, layer: BaseFeatureTransformerSlice):
     weight = layer.weight.data
     indices = model.feature_set.get_virtual_to_real_features_gather_indices()
     weight_coalesced = weight.new_zeros(
@@ -492,5 +504,5 @@ def coalesce_ft_weights(model: NNUE, layer: BaseFeatureTransformerSlice):
     return weight_coalesced
 
 
-def get_parameters(layers: List[nn.Module]):
+def get_parameters(layers: list[nn.Module]):
     return [p for layer in layers for p in layer.parameters()]
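
Net effect of the model.py changes: NNUEModel is now a plain nn.Module holding the architecture, while the NNUE LightningModule keeps only training concerns (loss, optimizer, scheduling) and delegates forward to it. A minimal usage sketch (get_feature_set_from_name and the feature-set name are assumptions; adjust to the repo's actual API). Because the network now lives under self.model, checkpoint state_dict keys gain a "model." prefix (e.g. "model.layer_stacks.l1.weight"), so older checkpoints would need key remapping:

    from features import get_feature_set_from_name  # assumed helper
    from model import NNUE, NNUEModel, coalesce_ft_weights

    feature_set = get_feature_set_from_name("HalfKAv2_hm^")  # hypothetical name

    lit = NNUE(feature_set=feature_set)  # LightningModule: training wrapper
    net: NNUEModel = lit.model           # pure nn.Module: the actual network

    # Serialization/export paths now take the bare module, e.g.:
    ft_weight = coalesce_ft_weights(net, net.input)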
