Skip to content

Commit b6865d5

Browse files
committed
Refined loss function
This refines the loss function to the form used for the new master net in official-stockfish/Stockfish#4100. The new loss function uses the expected game score to learn, making the learning more sensitive to scores between loss and draw, and between draw and win. This is most visible for smaller values of the scaling parameter, but the current ones have been optimized. It also introduces param_index for simpler exploration of parameters, i.e. simple parameter scans.
1 parent 50eed1c commit b6865d5

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

model.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class NNUE(pl.LightningModule):
132132
133133
lr - the initial learning rate
134134
"""
135-
def __init__(self, feature_set, start_lambda=1.0, end_lambda=1.0, max_epoch=800, gamma=0.992, lr=8.75e-4, num_psqt_buckets=8, num_ls_buckets=8):
135+
def __init__(self, feature_set, start_lambda=1.0, end_lambda=1.0, max_epoch=800, gamma=0.992, lr=8.75e-4, param_index=0, num_psqt_buckets=8, num_ls_buckets=8):
136136
super(NNUE, self).__init__()
137137
self.num_psqt_buckets = num_psqt_buckets
138138
self.num_ls_buckets = num_ls_buckets
@@ -144,6 +144,7 @@ def __init__(self, feature_set, start_lambda=1.0, end_lambda=1.0, max_epoch=800,
144144
self.max_epoch = max_epoch
145145
self.gamma = gamma
146146
self.lr = lr
147+
self.param_index = param_index
147148

148149
self.nnue2score = 600.0
149150
self.weight_scale_hidden = 64.0
@@ -292,19 +293,26 @@ def step_(self, batch, batch_idx, loss_type):
292293

293294
us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch
294295

295-
# 600 is the kPonanzaConstant scaling factor needed to convert the training net output to a score.
296-
# This needs to match the value used in the serializer
297-
in_scaling = 410
298-
out_scaling = 361
296+
# convert the network and search scores to an estimate match result
297+
# based on the win_rate_model, with scalings and offsets optimized
298+
in_scaling = 340
299+
out_scaling = 380
300+
offset = 270
299301

300-
q = (self(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices) * self.nnue2score / out_scaling).sigmoid()
301-
t = outcome
302-
p = (score / in_scaling).sigmoid()
302+
scorenet = self(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices) * self.nnue2score
303+
q = ( scorenet - offset) / in_scaling # used to compute the chance of a win
304+
qm = (-scorenet - offset) / in_scaling # used to compute the chance of a loss
305+
qf = 0.5 * (1.0 + q.sigmoid() - qm.sigmoid()) # estimated match result (using win, loss and draw probs).
306+
307+
p = ( score - offset) / out_scaling
308+
pm = (-score - offset) / out_scaling
309+
pf = 0.5 * (1.0 + p.sigmoid() - pm.sigmoid())
303310

311+
t = outcome
304312
actual_lambda = self.start_lambda + (self.end_lambda - self.start_lambda) * (self.current_epoch / self.max_epoch)
305-
pt = p * actual_lambda + t * (1.0 - actual_lambda)
313+
pt = pf * actual_lambda + t * (1.0 - actual_lambda)
306314

307-
loss = torch.pow(torch.abs(pt - q), 2.6).mean()
315+
loss = torch.pow(torch.abs(pt - qf), 2.6).mean()
308316

309317
self.log(loss_type, loss)
310318

train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def main():
5656
parser.add_argument("--save-last-network", type=str2bool, default=True, dest='save_last_network', help="Whether to always save the last produced network.")
5757
parser.add_argument("--epoch-size", type=int, default=100000000, dest='epoch_size', help="Number of positions per epoch.")
5858
parser.add_argument("--validation-size", type=int, default=1000000, dest='validation_size', help="Number of positions per validation step.")
59+
parser.add_argument("--param-index", type=int, default=0, dest='param_index', help="Indexing for parameter scans.")
5960
features.add_argparse_args(parser)
6061
args = parser.parse_args()
6162

@@ -79,7 +80,8 @@ def main():
7980
max_epoch=max_epoch,
8081
end_lambda=end_lambda,
8182
gamma=args.gamma,
82-
lr=args.lr
83+
lr=args.lr,
84+
param_index=args.param_index
8385
)
8486
else:
8587
nnue = torch.load(args.resume_from_model)
@@ -91,6 +93,7 @@ def main():
9193
# from .pt the optimizer is only created after the training is started
9294
nnue.gamma = args.gamma
9395
nnue.lr = args.lr
96+
nnue.param_index=args.param_index
9497

9598
print("Feature set: {}".format(feature_set.name))
9699
print("Num real features: {}".format(feature_set.num_real_features))
@@ -110,6 +113,7 @@ def main():
110113
print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
111114
print('WLD fen skipping: {}'.format(not args.no_wld_fen_skipping))
112115
print('Random fen skipping: {}'.format(args.random_fen_skipping))
116+
print('Param index: {}'.format(args.param_index))
113117

114118
if args.threads > 0:
115119
print('limiting torch to {} threads.'.format(args.threads))

0 commit comments

Comments
 (0)