Commit 9da251d

Merge pull request #143 from Sopel97/exp135_master_pr
Update trainer for 2021-08-09 new master net
2 parents f7f77c0 + 6545a64 commit 9da251d

5 files changed: +215 -16 lines changed

README.md

Lines changed: 5 additions & 5 deletions
@@ -54,11 +54,11 @@ python train.py --resume_from_checkpoint <path> ...
 python train.py --gpus 1 ...
 ```
 ## Feature set selection
-By default the trainer uses a factorized HalfKAv2 feature set (named "HalfKAv2^")
+By default the trainer uses a factorized HalfKAv2_hm feature set (named "HalfKAv2_hm^")
 If you wish to change the feature set used then you can use the `--features=NAME` option. For the list of available features see `--help`
 The default is:
 ```
-python train.py ... --features="HalfKAv2^"
+python train.py ... --features="HalfKAv2_hm^"
 ```
 
 ## Skipping certain fens in the training
@@ -69,7 +69,7 @@ python train.py ... --features="HalfKAv2^"
 ## Current recommended training invocation
 
 ```
-python train.py --smart-fen-skipping --random-fen-skipping 3 --batch-size 16384 --threads 8 --num-workers 8 --gpus 1 trainingdata validationdata
+python train.py --smart-fen-skipping --random-fen-skipping 3 --batch-size 16384 --threads 2 --num-workers 2 --gpus 1 trainingdata validationdata
 ```
 best nets have been trained with 16B d9-scored nets, training runs >200 epochs
 
@@ -96,13 +96,13 @@ python serialize.py nn.nnue converted.pt
 Visualize a network from either a checkpoint (`.ckpt`), a serialized model (`.pt`)
 or a SF NNUE file (`.nnue`).
 ```
-python visualize.py nn.nnue --features="HalfKAv2"
+python visualize.py nn.nnue --features="HalfKAv2_hm"
 ```
 
 Visualize the difference between two networks from either a checkpoint (`.ckpt`), a serialized model (`.pt`)
 or a SF NNUE file (`.nnue`).
 ```
-python visualize.py nn.nnue --features="HalfKAv2" --ref-model nn.cpkt --ref-features="HalfKAv2^"
+python visualize.py nn.nnue --features="HalfKAv2_hm" --ref-model nn.cpkt --ref-features="HalfKAv2_hm^"
 ```
 
 # Logging
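
A note that is not part of the commit itself: the "_hm" suffix stands for the horizontally mirrored king, so king squares on the a-d files are reflected onto the e-h files. This halves the number of king buckets and therefore the input size relative to plain HalfKAv2. A quick back-of-the-envelope check using the constants from halfka_v2_hm.py below:

```
# Sketch only: input sizes implied by the constants in halfka_v2_hm.py.
NUM_SQ = 64
NUM_PT_REAL = 11
NUM_PLANES_REAL = NUM_SQ * NUM_PT_REAL               # 704 piece-square planes
halfkav2_inputs = NUM_PLANES_REAL * NUM_SQ           # 45056, one bucket per king square
halfkav2_hm_inputs = NUM_PLANES_REAL * NUM_SQ // 2   # 22528, 32 mirrored king buckets
```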

features.py

Lines changed: 3 additions & 2 deletions
@@ -12,8 +12,9 @@
 import halfkp
 import halfka
 import halfka_v2
+import halfka_v2_hm
 
-_feature_modules = [halfkp, halfka, halfka_v2]
+_feature_modules = [halfkp, halfka, halfka_v2, halfka_v2_hm]
 
 _feature_blocks_by_name = dict()
 
@@ -41,7 +42,7 @@ def get_available_feature_blocks_names():
   return list(iter(_feature_blocks_by_name))
 
 def add_argparse_args(parser):
-  _default_feature_set_name = 'HalfKAv2^'
+  _default_feature_set_name = 'HalfKAv2_hm^'
   parser.add_argument("--features", dest='features', default=_default_feature_set_name, help="The feature set to use. Can be a union of feature blocks (for example P+HalfKP). \"^\" denotes a factorized block. Currently available feature blocks are: " + ', '.join(get_available_feature_blocks_names()))
 
 def _init():
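
For orientation, a small sketch of my own (not part of the commit): with halfka_v2_hm added to `_feature_modules`, the new block becomes selectable by name through the existing argparse plumbing, and it is now the default.

```
# Sketch: the new feature set is discoverable and is the default.
import argparse
import features

parser = argparse.ArgumentParser()
features.add_argparse_args(parser)   # default --features is now 'HalfKAv2_hm^'
args = parser.parse_args([])

print(args.features)                                   # 'HalfKAv2_hm^'
print(features.get_available_feature_blocks_names())   # should now include 'HalfKAv2_hm' and 'HalfKAv2_hm^'
```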

halfka_v2_hm.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import chess
+import torch
+import feature_block
+from collections import OrderedDict
+from feature_block import *
+
+NUM_SQ = 64
+NUM_PT_REAL = 11
+NUM_PT_VIRTUAL = 12
+NUM_PLANES_REAL = NUM_SQ * NUM_PT_REAL
+NUM_PLANES_VIRTUAL = NUM_SQ * NUM_PT_VIRTUAL
+NUM_INPUTS = NUM_PLANES_REAL * NUM_SQ // 2
+
+KingBuckets = [
+  -1, -1, -1, -1, 31, 30, 29, 28,
+  -1, -1, -1, -1, 27, 26, 25, 24,
+  -1, -1, -1, -1, 23, 22, 21, 20,
+  -1, -1, -1, -1, 19, 18, 17, 16,
+  -1, -1, -1, -1, 15, 14, 13, 12,
+  -1, -1, -1, -1, 11, 10, 9, 8,
+  -1, -1, -1, -1, 7, 6, 5, 4,
+  -1, -1, -1, -1, 3, 2, 1, 0
+]
+
+def orient(is_white_pov: bool, sq: int, ksq: int):
+  # ksq must not be oriented
+  kfile = (ksq % 8)
+  return (7 * (kfile < 4)) ^ (56 * (not is_white_pov)) ^ sq
+
+def halfka_idx(is_white_pov: bool, king_sq: int, sq: int, p: chess.Piece):
+  p_idx = (p.piece_type - 1) * 2 + (p.color != is_white_pov)
+  o_ksq = orient(is_white_pov, king_sq, king_sq)
+  if p_idx == 11:
+    p_idx -= 1
+  return orient(is_white_pov, sq, king_sq) + p_idx * NUM_SQ + KingBuckets[o_ksq] * NUM_PLANES_REAL
+
+def halfka_psqts():
+  # values copied from stockfish, in stockfish internal units
+  piece_values = {
+    chess.PAWN : 126,
+    chess.KNIGHT : 781,
+    chess.BISHOP : 825,
+    chess.ROOK : 1276,
+    chess.QUEEN : 2538
+  }
+
+  values = [0] * NUM_INPUTS
+
+  for ksq in range(64):
+    for s in range(64):
+      for pt, val in piece_values.items():
+        idxw = halfka_idx(True, ksq, s, chess.Piece(pt, chess.WHITE))
+        idxb = halfka_idx(True, ksq, s, chess.Piece(pt, chess.BLACK))
+        values[idxw] = val
+        values[idxb] = -val
+
+  return values
+
+class Features(FeatureBlock):
+  def __init__(self):
+    super(Features, self).__init__('HalfKAv2_hm', 0x7f234cb8, OrderedDict([('HalfKAv2_hm', NUM_INPUTS)]))
+
+  def get_active_features(self, board: chess.Board):
+    raise Exception('Not supported yet, you must use the c++ data loader for support during training')
+
+  def get_initial_psqt_features(self):
+    return halfka_psqts()
+
+class FactorizedFeatures(FeatureBlock):
+  def __init__(self):
+    super(FactorizedFeatures, self).__init__('HalfKAv2_hm^', 0x7f234cb8, OrderedDict([('HalfKAv2_hm', NUM_INPUTS), ('A', NUM_PLANES_VIRTUAL)]))
+
+  def get_active_features(self, board: chess.Board):
+    raise Exception('Not supported yet, you must use the c++ data loader for factorizer support during training')
+
+  def get_feature_factors(self, idx):
+    if idx >= self.num_real_features:
+      raise Exception('Feature must be real')
+
+    a_idx = idx % NUM_PLANES_REAL
+    k_idx = idx // NUM_PLANES_REAL
+
+    if a_idx // NUM_SQ == 10 and k_idx != KingBuckets[a_idx % NUM_SQ]:
+      a_idx += NUM_SQ
+
+    return [idx, self.get_factor_base_feature('A') + a_idx]
+
+  def get_initial_psqt_features(self):
+    return halfka_psqts() + [0] * NUM_PLANES_VIRTUAL
+
+'''
+This is used by the features module for discovery of feature blocks.
+'''
+def get_feature_block_clss():
+  return [Features, FactorizedFeatures]
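
To make the indexing above concrete, here is a small sanity check of my own (not part of the commit); it assumes python-chess square numbering with A1 = 0.

```
# Sketch: a white pawn on e4, white king on g1 (a king on the e-h files is not mirrored).
import chess
import halfka_v2_hm as hm

ksq = chess.G1                        # square 6; KingBuckets[6] == 29
sq = chess.E4                         # square 28
pawn = chess.Piece(chess.PAWN, chess.WHITE)

idx = hm.halfka_idx(True, ksq, sq, pawn)
assert idx == 28 + 0 * hm.NUM_SQ + 29 * hm.NUM_PLANES_REAL   # 20444
assert 0 <= idx < hm.NUM_INPUTS                              # NUM_INPUTS == 22528
```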

model.py

Lines changed: 9 additions & 9 deletions
@@ -7,8 +7,8 @@
 from feature_transformer import DoubleFeatureTransformerSlice
 
 # 3 layer fully connected network
-L1 = 512
-L2 = 16
+L1 = 1024
+L2 = 8
 L3 = 32
 
 def coalesce_ft_weights(model, layer):
@@ -271,9 +271,9 @@ def step_(self, batch, batch_idx, loss_type):
     t = outcome
     p = (score / in_scaling).sigmoid()
 
-    loss_eval = (p - q).square().mean()
-    loss_result = (q - t).square().mean()
-    loss = self.lambda_ * loss_eval + (1.0 - self.lambda_) * loss_result
+    pt = p * self.lambda_ + t * (1.0 - self.lambda_)
+
+    loss = torch.pow(torch.abs(pt - q), 2.6).mean()
 
     self.log(loss_type, loss)
 
@@ -295,19 +295,19 @@ def test_step(self, batch, batch_idx):
 
   def configure_optimizers(self):
     # Train with a lower LR on the output layer
-    LR = 1.5e-3
+    LR = 8.75e-4
     train_params = [
       {'params' : get_parameters([self.input]), 'lr' : LR, 'gc_dim' : 0 },
       {'params' : [self.layer_stacks.l1_fact.weight], 'lr' : LR },
       {'params' : [self.layer_stacks.l1.weight], 'lr' : LR },
       {'params' : [self.layer_stacks.l1.bias], 'lr' : LR },
       {'params' : [self.layer_stacks.l2.weight], 'lr' : LR },
       {'params' : [self.layer_stacks.l2.bias], 'lr' : LR },
-      {'params' : [self.layer_stacks.output.weight], 'lr' : LR / 10 },
-      {'params' : [self.layer_stacks.output.bias], 'lr' : LR / 10 },
+      {'params' : [self.layer_stacks.output.weight], 'lr' : LR },
+      {'params' : [self.layer_stacks.output.bias], 'lr' : LR },
     ]
     # increasing the eps leads to less saturated nets with a few dead neurons
     optimizer = ranger.Ranger(train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False)
     # Drop learning rate after 75 epochs
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.987)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.992)
     return [optimizer], [scheduler]
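
The loss change above replaces the two separate squared-error terms with a single error against a blended target. A minimal standalone restatement of my own (variable names as in `step_` above):

```
# Sketch: the new training loss. q is the model output and p the teacher score,
# both mapped to WDL space via sigmoid; t is the game outcome in [0, 1].
import torch

def blended_loss(q, p, t, lambda_, exponent=2.6):
    pt = p * lambda_ + t * (1.0 - lambda_)            # blend score and outcome targets
    return torch.pow(torch.abs(pt - q), exponent).mean()
```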

training_data_loader.cpp

Lines changed: 103 additions & 0 deletions
@@ -258,6 +258,93 @@ struct HalfKAv2Factorized {
     }
 };
 
+// ksq must not be oriented
+static Square orient_flip_2(Color color, Square sq, Square ksq)
+{
+    bool h = ksq.file() < fileE;
+    if (color == Color::Black)
+        sq = sq.flippedVertically();
+    if (h)
+        sq = sq.flippedHorizontally();
+    return sq;
+}
+
+struct HalfKAv2_hm {
+    static constexpr int NUM_SQ = 64;
+    static constexpr int NUM_PT = 11;
+    static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
+    static constexpr int INPUTS = NUM_PLANES * NUM_SQ / 2;
+
+    static constexpr int MAX_ACTIVE_FEATURES = 32;
+
+    static constexpr int KingBuckets[64] = {
+        -1, -1, -1, -1, 31, 30, 29, 28,
+        -1, -1, -1, -1, 27, 26, 25, 24,
+        -1, -1, -1, -1, 23, 22, 21, 20,
+        -1, -1, -1, -1, 19, 18, 17, 16,
+        -1, -1, -1, -1, 15, 14, 13, 12,
+        -1, -1, -1, -1, 11, 10, 9, 8,
+        -1, -1, -1, -1, 7, 6, 5, 4,
+        -1, -1, -1, -1, 3, 2, 1, 0
+    };
+
+    static int feature_index(Color color, Square ksq, Square sq, Piece p)
+    {
+        Square o_ksq = orient_flip_2(color, ksq, ksq);
+        auto p_idx = static_cast<int>(p.type()) * 2 + (p.color() != color);
+        if (p_idx == 11)
+            --p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ
+        return static_cast<int>(orient_flip_2(color, sq, ksq)) + p_idx * NUM_SQ + KingBuckets[static_cast<int>(o_ksq)] * NUM_PLANES;
+    }
+
+    static std::pair<int, int> fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color)
+    {
+        auto& pos = e.pos;
+        auto pieces = pos.piecesBB();
+        auto ksq = pos.kingSquare(color);
+
+        int j = 0;
+        for(Square sq : pieces)
+        {
+            auto p = pos.pieceAt(sq);
+            values[j] = 1.0f;
+            features[j] = feature_index(color, ksq, sq, p);
+            ++j;
+        }
+
+        return { j, INPUTS };
+    }
+};
+
+struct HalfKAv2_hmFactorized {
+    // Factorized features
+    static constexpr int PIECE_INPUTS = HalfKAv2_hm::NUM_SQ * HalfKAv2_hm::NUM_PT;
+    static constexpr int INPUTS = HalfKAv2_hm::INPUTS + PIECE_INPUTS;
+
+    static constexpr int MAX_PIECE_FEATURES = 32;
+    static constexpr int MAX_ACTIVE_FEATURES = HalfKAv2_hm::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES;
+
+    static std::pair<int, int> fill_features_sparse(const TrainingDataEntry& e, int* features, float* values, Color color)
+    {
+        const auto [start_j, offset] = HalfKAv2_hm::fill_features_sparse(e, features, values, color);
+        auto& pos = e.pos;
+        auto pieces = pos.piecesBB();
+        auto ksq = pos.kingSquare(color);
+
+        int j = start_j;
+        for(Square sq : pieces)
+        {
+            auto p = pos.pieceAt(sq);
+            auto p_idx = static_cast<int>(p.type()) * 2 + (p.color() != color);
+            values[j] = 1.0f;
+            features[j] = offset + (p_idx * HalfKAv2_hm::NUM_SQ) + static_cast<int>(orient_flip_2(color, sq, ksq));
+            ++j;
+        }
+
+        return { j, INPUTS };
+    }
+};
+
 template <typename T, typename... Ts>
 struct FeatureSet
 {
@@ -797,6 +884,14 @@ extern "C" {
         {
             return new SparseBatch(FeatureSet<HalfKAv2Factorized>{}, entries);
         }
+        else if (feature_set == "HalfKAv2_hm")
+        {
+            return new SparseBatch(FeatureSet<HalfKAv2_hm>{}, entries);
+        }
+        else if (feature_set == "HalfKAv2_hm^")
+        {
+            return new SparseBatch(FeatureSet<HalfKAv2_hmFactorized>{}, entries);
+        }
         fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
         return nullptr;
     }
@@ -842,6 +937,14 @@ extern "C" {
         {
            return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
         }
+        else if (feature_set == "HalfKAv2_hm")
+        {
+            return new FeaturedBatchStream<FeatureSet<HalfKAv2_hm>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
+        }
+        else if (feature_set == "HalfKAv2_hm^")
+        {
+            return new FeaturedBatchStream<FeatureSet<HalfKAv2_hmFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
+        }
         fprintf(stderr, "Unknown feature_set %s\n", feature_set_c);
         return nullptr;
     }
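
The Python-side FactorizedFeatures.get_feature_factors mirrors what HalfKAv2_hmFactorized::fill_features_sparse does here: every real feature shares gradient with a virtual piece-square ("A") feature. A small illustration of my own (not from the commit), reusing the index from the earlier sketch:

```
# Sketch: map a real HalfKAv2_hm feature to [real index, virtual 'A' index].
import halfka_v2_hm as hm

block = hm.FactorizedFeatures()
real_idx = 20444                  # white pawn on e4, king bucket 29 (see above)
print(block.get_feature_factors(real_idx))
# expected: [20444, base_of_A + 28], where 28 is the oriented pawn-on-e4 plane entry
```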
