Skip to content

Commit cc44eea

Browse files
authored
Backport changes needed for master nets (#375)
This introduces two changes needed for training SF master nets — in particular, for executing the recipes referenced in official-stockfish/Stockfish#6452 and official-stockfish/Stockfish#6457 — adding additional flexibility for shaping the piece count distribution and for weighting the individual configurations in the loss, respectively.
1 parent d3d0b07 commit cc44eea

File tree

5 files changed

+92
-24
lines changed

5 files changed

+92
-24
lines changed

data_loader/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ class DataloaderSkipConfig:
1010
early_fen_skipping: int = -1
1111
simple_eval_skipping: int = -1
1212
param_index: int = 0
13+
pc_y1: float = 1.0
14+
pc_y2: float = 2.0
15+
pc_y3: float = 1.0
1316

1417

1518
class CDataloaderSkipConfig(ctypes.Structure):
@@ -20,6 +23,9 @@ class CDataloaderSkipConfig(ctypes.Structure):
2023
("early_fen_skipping", ctypes.c_int),
2124
("simple_eval_skipping", ctypes.c_int),
2225
("param_index", ctypes.c_int),
26+
("pc_y1", ctypes.c_double),
27+
("pc_y2", ctypes.c_double),
28+
("pc_y3", ctypes.c_double),
2329
]
2430

2531
def __init__(self, config: DataloaderSkipConfig):
@@ -30,4 +36,7 @@ def __init__(self, config: DataloaderSkipConfig):
3036
early_fen_skipping=config.early_fen_skipping,
3137
simple_eval_skipping=config.simple_eval_skipping,
3238
param_index=config.param_index,
39+
pc_y1=config.pc_y1,
40+
pc_y2=config.pc_y2,
41+
pc_y3=config.pc_y3,
3342
)

model/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ class LossParams:
2020
end_lambda: float = 1.0
2121
pow_exp: float = 2.5
2222
qp_asymmetry: float = 0.0
23+
w1: float = 0.0
24+
w2: float = 0.5

model/lightning_module.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def step_(self, batch: tuple[Tensor, ...], batch_idx, loss_type):
106106
loss = torch.pow(torch.abs(pt - qf), p.pow_exp)
107107
if p.qp_asymmetry != 0.0:
108108
loss = loss * ((qf > pt) * p.qp_asymmetry + 1)
109-
loss = loss.mean()
109+
110+
weights = 1 + (2.0**p.w1 - 1) * torch.pow((pf - 0.5) ** 2 * pf * (1 - pf), p.w2)
111+
loss = (loss * weights).sum() / weights.sum()
110112

111113
self.log(loss_type, loss, prog_bar=True)
112114

train.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,42 @@ def main():
333333
dest="simple_eval_skipping",
334334
help="Skip positions that have abs(simple_eval(pos)) < n",
335335
)
336+
parser.add_argument(
337+
"--pc-y1",
338+
type=float,
339+
default=1.0,
340+
dest="pc_y1",
341+
help="piece count parameter y1 (default=1.0)",
342+
)
343+
parser.add_argument(
344+
"--pc-y2",
345+
type=float,
346+
default=2.0,
347+
dest="pc_y2",
348+
help="piece count parameter y2 (default=2.0)",
349+
)
350+
parser.add_argument(
351+
"--pc-y3",
352+
type=float,
353+
default=1.0,
354+
dest="pc_y3",
355+
help="piece count parameter y3 (default=1.0)",
356+
)
357+
parser.add_argument(
358+
"--w1",
359+
type=float,
360+
default=0.0,
361+
dest="w1",
362+
help="weight boost parameter 1 (default=0.0)",
363+
)
364+
parser.add_argument(
365+
"--w2",
366+
type=float,
367+
default=0.5,
368+
dest="w2",
369+
help="weight boost parameter 2 (default=0.5)",
370+
)
371+
336372
parser.add_argument("--l1", type=int, default=M.ModelConfig().L1)
337373
M.add_feature_args(parser)
338374
args = parser.parse_args()
@@ -377,6 +413,8 @@ def main():
377413
end_lambda=args.end_lambda or args.lambda_,
378414
pow_exp=args.pow_exp,
379415
qp_asymmetry=args.qp_asymmetry,
416+
w1=args.w1,
417+
w2=args.w2,
380418
)
381419
print("Loss parameters:")
382420
print(loss_params)
@@ -429,6 +467,11 @@ def main():
429467
print("Skip early plies: {}".format(args.early_fen_skipping))
430468
print("Skip simple eval : {}".format(args.simple_eval_skipping))
431469
print("Param index: {}".format(args.param_index))
470+
print("piececount param y1 : {}".format(args.pc_y1))
471+
print("piececount param y2 : {}".format(args.pc_y2))
472+
print("piececount param y3 : {}".format(args.pc_y3))
473+
print("Weighting param w1 : {}".format(args.w1))
474+
print("Weighting param w2 : {}".format(args.w2))
432475

433476
if args.threads > 0:
434477
print("limiting torch to {} threads.".format(args.threads))
@@ -481,6 +524,9 @@ def main():
481524
early_fen_skipping=args.early_fen_skipping,
482525
simple_eval_skipping=args.simple_eval_skipping,
483526
param_index=args.param_index,
527+
pc_y1=args.pc_y1,
528+
pc_y2=args.pc_y2,
529+
pc_y3=args.pc_y3,
484530
),
485531
args.epoch_size,
486532
args.validation_size,

training_data_loader.cpp

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,12 +1029,13 @@ struct FenBatchStream: Stream<FenBatch> {
10291029
};
10301030

10311031
struct DataloaderSkipConfig {
1032-
bool filtered;
1033-
int random_fen_skipping;
1034-
bool wld_filtered;
1035-
int early_fen_skipping;
1036-
int simple_eval_skipping;
1037-
int param_index;
1032+
bool filtered;
1033+
int random_fen_skipping;
1034+
bool wld_filtered;
1035+
int early_fen_skipping;
1036+
int simple_eval_skipping;
1037+
int param_index;
1038+
double pc_y1, pc_y2, pc_y3;
10381039
};
10391040

10401041
std::function<bool(const TrainingDataEntry&)> make_skip_predicate(DataloaderSkipConfig config) {
@@ -1049,19 +1050,17 @@ std::function<bool(const TrainingDataEntry&)> make_skip_predicate(DataloaderSkip
10491050
// compression ability.
10501051
static constexpr int VALUE_NONE = 32002;
10511052

1052-
static constexpr double desired_piece_count_weights[33] = {
1053-
1.000000, 1.121094, 1.234375, 1.339844, 1.437500, 1.527344, 1.609375,
1054-
1.683594, 1.750000, 1.808594, 1.859375, 1.902344, 1.937500, 1.964844,
1055-
1.984375, 1.996094, 2.000000, 1.996094, 1.984375, 1.964844, 1.937500,
1056-
1.902344, 1.859375, 1.808594, 1.750000, 1.683594, 1.609375, 1.527344,
1057-
1.437500, 1.339844, 1.234375, 1.121094, 1.000000};
1058-
1059-
static constexpr double desired_piece_count_weights_total = []() {
1060-
double tot = 0;
1061-
for (auto w : desired_piece_count_weights)
1062-
tot += w;
1063-
return tot;
1064-
}();
1053+
// lagrange interpolation weights for desired piece count distribution
1054+
auto desired_piece_count_weights = [&config](int pc) -> double {
1055+
double x = pc;
1056+
double x1 = 0, y1 = config.pc_y1;
1057+
double x2 = 16, y2 = config.pc_y2;
1058+
double x3 = 32, y3 = config.pc_y3;
1059+
double l1 = (x - x2) * (x - x3) / ((x1 - x2) * (x1 - x3));
1060+
double l2 = (x - x1) * (x - x3) / ((x2 - x1) * (x2 - x3));
1061+
double l3 = (x - x1) * (x - x2) / ((x3 - x1) * (x3 - x2));
1062+
return l1 * y1 + l2 * y2 + l3 * y3;
1063+
};
10651064

10661065
// keep stats on passing pieces
10671066
static thread_local double alpha = 1;
@@ -1123,16 +1122,23 @@ std::function<bool(const TrainingDataEntry&)> make_skip_predicate(DataloaderSkip
11231122
piece_count_history_all[pc] += 1;
11241123
piece_count_history_all_total += 1;
11251124

1125+
double desired_piece_count_weights_total = [&desired_piece_count_weights]() {
1126+
double tot = 0;
1127+
for (int i = 0; i < 33; i++)
1128+
tot += desired_piece_count_weights(i);
1129+
return tot;
1130+
}();
1131+
11261132
// update alpha, which scales the filtering probability, to a maximum rate.
11271133
if (uint64_t(piece_count_history_all_total) % 10000 == 0)
11281134
{
11291135
double pass = piece_count_history_all_total * desired_piece_count_weights_total;
11301136
for (int i = 0; i < 33; ++i)
11311137
{
1132-
if (desired_piece_count_weights[pc] > 0)
1138+
if (desired_piece_count_weights(pc) > 0)
11331139
{
11341140
double tmp =
1135-
piece_count_history_all_total * desired_piece_count_weights[pc]
1141+
piece_count_history_all_total * desired_piece_count_weights(pc)
11361142
/ (desired_piece_count_weights_total * piece_count_history_all[pc]);
11371143
if (tmp < pass)
11381144
pass = tmp;
@@ -1141,7 +1147,7 @@ std::function<bool(const TrainingDataEntry&)> make_skip_predicate(DataloaderSkip
11411147
alpha = 1.0 / (pass * max_skipping_rate);
11421148
}
11431149

1144-
double tmp = alpha * piece_count_history_all_total * desired_piece_count_weights[pc]
1150+
double tmp = alpha * piece_count_history_all_total * desired_piece_count_weights(pc)
11451151
/ (desired_piece_count_weights_total * piece_count_history_all[pc]);
11461152
tmp = std::min(1.0, tmp);
11471153
std::bernoulli_distribution distrib(1.0 - tmp);
@@ -1366,7 +1372,10 @@ int main(int argc, char** argv) {
13661372
.wld_filtered = true,
13671373
.early_fen_skipping = 5,
13681374
.simple_eval_skipping = 0,
1369-
.param_index = 0};
1375+
.param_index = 0,
1376+
.pc_y1 = 1.0,
1377+
.pc_y2 = 2.0,
1378+
.pc_y3 = 1.0};
13701379
auto stream = create_sparse_batch_stream("Full_Threats^", concurrency, file_count, files,
13711380
batch_size, cyclic, config);
13721381

0 commit comments

Comments (0)