Commit 4b1477a

Add FeatureSet and Model Type Hints (official-stockfish#343)
1 parent: 2789ff3

9 files changed: +82 -90 lines changed


cross_check_eval.py

Lines changed: 2 additions & 1 deletion

@@ -8,9 +8,10 @@
 import serialize
 import data_loader
 from model import NNUE
+from features.feature_set import FeatureSet


-def read_model(nnue_path, feature_set):
+def read_model(nnue_path, feature_set: FeatureSet):
     with open(nnue_path, "rb") as f:
         reader = serialize.NNUEReader(f, feature_set)
         return reader.model
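
As a usage sketch (not part of the commit): resolving a FeatureSet by name and loading a network with the newly annotated helper. It assumes the repo's features package exposes get_feature_set_from_name, and the .nnue path is a placeholder.

import features

# Assumption: get_feature_set_from_name is the repo's name-to-FeatureSet lookup.
feature_set = features.get_feature_set_from_name("HalfKAv2_hm")
model = read_model("nn-example.nnue", feature_set)  # placeholder path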

data_loader/dataset.py

Lines changed: 3 additions & 3 deletions

@@ -57,7 +57,7 @@ def __del__(self):
 class TrainingDataProvider:
     def __init__(
         self,
-        feature_set,
+        feature_set: str,
         create_stream,
         destroy_stream,
         fetch_next,
@@ -113,7 +113,7 @@ def __del__(self):
 class SparseBatchProvider(TrainingDataProvider):
     def __init__(
         self,
-        feature_set,
+        feature_set: str,
         filenames,
         batch_size,
         cyclic=True,
@@ -137,7 +137,7 @@ def __init__(
 class SparseBatchDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
-        feature_set,
+        feature_set: str,
         filenames,
         batch_size,
         cyclic=True,
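
Note that the hint here is str, not FeatureSet: these providers receive the feature set's name, which is forwarded down to the native data loader. A hedged construction sketch, assuming the remaining constructor arguments keep their defaults; the file name is a placeholder.

dataset = SparseBatchDataset(
    "HalfKAv2_hm",      # feature_set: the name string, not a FeatureSet object
    ["train.binpack"],  # placeholder training file list
    batch_size=16384,
    cyclic=True,
)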

data_loader/stream.py

Lines changed: 3 additions & 2 deletions

@@ -2,6 +2,7 @@

 from ._native import c_lib
 from .config import CDataloaderSkipConfig, DataloaderSkipConfig
+from features.feature_set import FeatureSet


 def _to_c_str_array(str_list):
@@ -40,7 +41,7 @@ def destroy_fen_batch(fen_batch):


 def create_sparse_batch_stream(
-    feature_set,
+    feature_set: str,
     concurrency,
     filenames,
     batch_size,
@@ -62,7 +63,7 @@ def destroy_sparse_batch_stream(stream):
     c_lib.dll.destroy_sparse_batch_stream(stream)


-def get_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
+def get_sparse_batch_from_fens(feature_set: FeatureSet, fens, scores, plies, results):
     assert len(fens) == len(scores) == len(plies) == len(results)

     def to_c_int_array(data):
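
The hints make the split explicit: create_sparse_batch_stream takes the feature-set name (str) that is handed to the C library, while get_sparse_batch_from_fens takes a FeatureSet object. A hedged sketch of the latter, assuming the native library is built and get_feature_set_from_name exists; the FEN and zeroed labels are placeholders.

import features

fs = features.get_feature_set_from_name("HalfKAv2_hm")
fens = ["rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"]
batch = get_sparse_batch_from_fens(fs, fens, scores=[0], plies=[0], results=[0])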

feature_transformer.py

Lines changed: 6 additions & 19 deletions

@@ -573,13 +573,14 @@ def backward(ctx, grad_output_0, grad_output_1):
         return None, None, None, None, weight_grad, bias_grad


-class FeatureTransformerSlice(nn.Module):
+class BaseFeatureTransformerSlice(nn.Module):
     def __init__(self, num_inputs, num_outputs):
-        super(FeatureTransformerSlice, self).__init__()
+        super(BaseFeatureTransformerSlice, self).__init__()
         self.num_inputs = num_inputs
         self.num_outputs = num_outputs

         sigma = math.sqrt(1 / num_inputs)
+
         self.weight = nn.Parameter(
             torch.rand(num_inputs, num_outputs, dtype=torch.float32) * (2 * sigma)
             - sigma
@@ -588,27 +589,15 @@ def __init__(self, num_inputs, num_outputs):
             torch.rand(num_outputs, dtype=torch.float32) * (2 * sigma) - sigma
         )

+
+class FeatureTransformerSlice(BaseFeatureTransformerSlice):
     def forward(self, feature_indices, feature_values):
         return FeatureTransformerSliceFunction.apply(
             feature_indices, feature_values, self.weight, self.bias
         )


-class DoubleFeatureTransformerSlice(nn.Module):
-    def __init__(self, num_inputs, num_outputs):
-        super(DoubleFeatureTransformerSlice, self).__init__()
-        self.num_inputs = num_inputs
-        self.num_outputs = num_outputs
-
-        sigma = math.sqrt(1 / num_inputs)
-        self.weight = nn.Parameter(
-            torch.rand(num_inputs, num_outputs, dtype=torch.float32) * (2 * sigma)
-            - sigma
-        )
-        self.bias = nn.Parameter(
-            torch.rand(num_outputs, dtype=torch.float32) * (2 * sigma) - sigma
-        )
-
+class DoubleFeatureTransformerSlice(BaseFeatureTransformerSlice):
     def forward(
         self, feature_indices_0, feature_values_0, feature_indices_1, feature_values_1
     ):
@@ -624,8 +613,6 @@ def forward(

 if __name__ == "__main__":
     import time
-    import sys
-    import os

     def FeatureTransformerSliceFunctionEmulate(
         feature_indices, feature_values, weight, bias
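
The refactor removes the duplicated __init__ of DoubleFeatureTransformerSlice: both slice variants now inherit weight/bias construction from BaseFeatureTransformerSlice and differ only in forward. A quick shape-check sketch with toy sizes (not the real network dimensions):

ft = FeatureTransformerSlice(num_inputs=256, num_outputs=64)
dft = DoubleFeatureTransformerSlice(num_inputs=256, num_outputs=64)
# Both inherit the same uniform(-sigma, sigma) initialization from the base class.
assert ft.weight.shape == dft.weight.shape == (256, 64)
assert ft.bias.shape == dft.bias.shape == (64,)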

ftperm.py

Lines changed: 3 additions & 1 deletion

@@ -44,6 +44,8 @@
 import data_loader
 import model as M
 from model import NNUE
+from features.feature_set import FeatureSet
+

 """
@@ -394,7 +396,7 @@ def find_perm_impl(actmat, use_cupy):
 # -------------------------------------------------------------


-def read_model(nnue_path, feature_set):
+def read_model(nnue_path, feature_set: FeatureSet):
     with open(nnue_path, "rb") as f:
         reader = serialize.NNUEReader(f, feature_set)
         return reader.model
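
ftperm.py gets the same annotation on its copy of read_model as cross_check_eval.py. The practical payoff is static checking: with the hint in place, a checker such as mypy would flag passing the feature-set name instead of the object (illustrative output, not from the commit):

model = read_model("nn-example.nnue", "HalfKAv2_hm")  # placeholder arguments
# error: Argument 2 to "read_model" has incompatible type "str"; expected "FeatureSet"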

model.py

Lines changed: 37 additions & 32 deletions

@@ -1,9 +1,14 @@
 import ranger21
 import torch
-from torch import nn
+from torch import nn, Tensor
 import pytorch_lightning as pl
-from feature_transformer import DoubleFeatureTransformerSlice
+from feature_transformer import (
+    DoubleFeatureTransformerSlice,
+    BaseFeatureTransformerSlice,
+)
 from dataclasses import dataclass
+from features.feature_set import FeatureSet
+from typing import List, Tuple

 # 3 layer fully connected network
 L1 = 3072
@@ -24,25 +29,8 @@ class LossParams:
     qp_asymmetry: float = 0.0


-def coalesce_ft_weights(model, layer):
-    weight = layer.weight.data
-    indices = model.feature_set.get_virtual_to_real_features_gather_indices()
-    weight_coalesced = weight.new_zeros(
-        (model.feature_set.num_real_features, weight.shape[1])
-    )
-    for i_real, is_virtual in enumerate(indices):
-        weight_coalesced[i_real, :] = sum(
-            weight[i_virtual, :] for i_virtual in is_virtual
-        )
-    return weight_coalesced
-
-
-def get_parameters(layers):
-    return [p for layer in layers for p in layer.parameters()]
-
-
 class LayerStacks(nn.Module):
-    def __init__(self, count):
+    def __init__(self, count: int):
         super(LayerStacks, self).__init__()

         self.count = count
@@ -94,7 +82,7 @@ def _init_layers(self):
         self.output.weight = nn.Parameter(output_weight)
         self.output.bias = nn.Parameter(output_bias)

-    def forward(self, x, ls_indices):
+    def forward(self, x: Tensor, ls_indices: Tensor):
         assert self.idx_offset is not None and self.idx_offset.shape[0] == x.shape[0]

         indices = ls_indices.flatten() + self.idx_offset
@@ -162,7 +150,7 @@ class NNUE(pl.LightningModule):

     def __init__(
         self,
-        feature_set,
+        feature_set: FeatureSet,
         max_epoch=800,
         num_batches_per_epoch=int(100_000_000 / 16384),
         gamma=0.992,
@@ -304,7 +292,7 @@ def _clip_weights(self):
     to new_feature_set. Currently only works for adding virtual features.
     """

-    def set_feature_set(self, new_feature_set):
+    def set_feature_set(self, new_feature_set: FeatureSet):
         if self.feature_set.name == new_feature_set.name:
             return

@@ -351,14 +339,14 @@ def set_feature_set(self, new_feature_set):

     def forward(
         self,
-        us,
-        them,
-        white_indices,
-        white_values,
-        black_indices,
-        black_values,
-        psqt_indices,
-        layer_stack_indices,
+        us: Tensor,
+        them: Tensor,
+        white_indices: Tensor,
+        white_values: Tensor,
+        black_indices: Tensor,
+        black_values: Tensor,
+        psqt_indices: Tensor,
+        layer_stack_indices: Tensor,
     ):
         wp, bp = self.input(white_indices, white_values, black_indices, black_values)
         w, wpsqt = torch.split(wp, L1, dim=1)
@@ -382,7 +370,7 @@ def forward(

         return x

-    def step_(self, batch, batch_idx, loss_type):
+    def step_(self, batch: Tuple[Tensor, ...], batch_idx, loss_type):
         _ = batch_idx  # unused, but required by pytorch-lightning

         # We clip weights at the start of each step. This means that after
@@ -489,3 +477,20 @@ def configure_optimizers(self):
             optimizer, step_size=1, gamma=self.gamma
         )
         return [optimizer], [scheduler]
+
+
+def coalesce_ft_weights(model: NNUE, layer: BaseFeatureTransformerSlice):
+    weight = layer.weight.data
+    indices = model.feature_set.get_virtual_to_real_features_gather_indices()
+    weight_coalesced = weight.new_zeros(
+        (model.feature_set.num_real_features, weight.shape[1])
+    )
+    for i_real, is_virtual in enumerate(indices):
+        weight_coalesced[i_real, :] = sum(
+            weight[i_virtual, :] for i_virtual in is_virtual
+        )
+    return weight_coalesced
+
+
+def get_parameters(layers: List[nn.Module]):
+    return [p for layer in layers for p in layer.parameters()]
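
coalesce_ft_weights and get_parameters are moved below the NNUE class unchanged apart from the new hints; relocating them lets the model: NNUE annotation refer to the class without a string forward reference. A hedged usage sketch (the feature-set lookup is an assumption, as above; nnue.input is the model's feature transformer, as used in forward):

import features

fs = features.get_feature_set_from_name("HalfKAv2_hm")
nnue = NNUE(feature_set=fs)  # remaining hyperparameters keep their defaults
coalesced = coalesce_ft_weights(nnue, nnue.input)
print(coalesced.shape)  # (num_real_features, columns of the FT weight matrix)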
