
Commit 4655d19

Merge pull request #164 from Sopel97/sf15a_trainer
SFNNv4 trainer
2 parents 04f80bd + 985a627

3 files changed: +32 -18 lines changed

model.py

Lines changed: 22 additions & 14 deletions
@@ -8,7 +8,7 @@
 
 # 3 layer fully connected network
 L1 = 1024
-L2 = 8
+L2 = 15
 L3 = 32
 
 def coalesce_ft_weights(model, layer):
@@ -27,13 +27,13 @@ def __init__(self, count):
         super(LayerStacks, self).__init__()
 
         self.count = count
-        self.l1 = nn.Linear(2 * L1, L2 * count)
+        self.l1 = nn.Linear(2 * L1 // 2, (L2 + 1) * count)
         # Factorizer only for the first layer because later
         # there's a non-linearity and factorization breaks.
         # It breaks the min/max weight clipping but hopefully it's not bad.
         # TODO: try solving it
         # one potential solution would be to coalesce the weights on each step.
-        self.l1_fact = nn.Linear(2 * L1, L2, bias=False)
+        self.l1_fact = nn.Linear(2 * L1 // 2, L2 + 1, bias=False)
         self.l2 = nn.Linear(L2, L3 * count)
         self.output = nn.Linear(L3, 1 * count)
 
@@ -56,8 +56,8 @@ def _init_layers(self):
             for i in range(1, self.count):
                 # Make all layer stacks have the same initialization.
                 # Basically copy the first to all other layer stacks.
-                l1_weight[i*L2:(i+1)*L2, :] = l1_weight[0:L2, :]
-                l1_bias[i*L2:(i+1)*L2] = l1_bias[0:L2]
+                l1_weight[i*(L2+1):(i+1)*(L2+1), :] = l1_weight[0:(L2+1), :]
+                l1_bias[i*(L2+1):(i+1)*(L2+1)] = l1_bias[0:(L2+1)]
                 l2_weight[i*L3:(i+1)*L3, :] = l2_weight[0:L3, :]
                 l2_bias[i*L3:(i+1)*L3] = l2_bias[0:L3]
                 output_weight[i:i+1, :] = output_weight[0:1, :]
@@ -77,12 +77,14 @@ def forward(self, x, ls_indices):
 
         indices = ls_indices.flatten() + self.idx_offset
 
-        l1s_ = self.l1(x).reshape((-1, self.count, L2))
+        l1s_ = self.l1(x).reshape((-1, self.count, L2 + 1))
         l1f_ = self.l1_fact(x)
         # https://stackoverflow.com/questions/55881002/pytorch-tensor-indexing-how-to-gather-rows-by-tensor-containing-indices
         # basically we present it as a list of individual results and pick not only based on
         # the ls index but also based on batch (they are combined into one index)
-        l1c_ = l1s_.view(-1, L2)[indices]
+        l1c_ = l1s_.view(-1, L2 + 1)[indices]
+        l1c_, l1c_out = l1c_.split(L2, dim=1)
+        l1f_, l1f_out = l1f_.split(L2, dim=1)
         l1x_ = torch.clamp(l1c_ + l1f_, 0.0, 1.0)
 
         l2s_ = self.l2(l1x_).reshape((-1, self.count, L3))
@@ -91,18 +93,18 @@ def forward(self, x, ls_indices):
 
         l3s_ = self.output(l2x_).reshape((-1, self.count, 1))
         l3c_ = l3s_.view(-1, 1)[indices]
-        l3x_ = l3c_
+        l3x_ = l3c_ + l1f_out + l1c_out
 
         return l3x_
 
     def get_coalesced_layer_stacks(self):
         for i in range(self.count):
             with torch.no_grad():
-                l1 = nn.Linear(2*L1, L2)
+                l1 = nn.Linear(2*L1 // 2, L2+1)
                 l2 = nn.Linear(L2, L3)
                 output = nn.Linear(L3, 1)
-                l1.weight.data = self.l1.weight[i*L2:(i+1)*L2, :] + self.l1_fact.weight.data
-                l1.bias.data = self.l1.bias[i*L2:(i+1)*L2]
+                l1.weight.data = self.l1.weight[i*(L2+1):(i+1)*(L2+1), :] + self.l1_fact.weight.data
+                l1.bias.data = self.l1.bias[i*(L2+1):(i+1)*(L2+1)]
                 l2.weight.data = self.l2.weight[i*L3:(i+1)*L3, :]
                 l2.bias.data = self.l2.bias[i*L3:(i+1)*L3]
                 output.weight.data = self.output.weight[i:(i+1), :]
@@ -119,14 +121,16 @@ class NNUE(pl.LightningModule):
 
     It is not ideal for training a Pytorch quantized model directly.
     """
-    def __init__(self, feature_set, lambda_=1.0):
+    def __init__(self, feature_set, lambda_=1.0, gamma=0.992, lr=8.75e-4):
         super(NNUE, self).__init__()
         self.num_psqt_buckets = feature_set.num_psqt_buckets
         self.num_ls_buckets = feature_set.num_ls_buckets
         self.input = DoubleFeatureTransformerSlice(feature_set.num_features, L1 + self.num_psqt_buckets)
         self.feature_set = feature_set
         self.layer_stacks = LayerStacks(self.num_ls_buckets)
         self.lambda_ = lambda_
+        self.gamma = gamma
+        self.lr = lr
 
         self.weight_clipping = [
             {'params' : [self.layer_stacks.l1.weight], 'min_weight' : -127/64, 'max_weight' : 127/64, 'virtual_params' : self.layer_stacks.l1_fact.weight },
@@ -249,6 +253,10 @@ def forward(self, us, them, white_indices, white_values, black_indices, black_va
         # clamp here is used as a clipped relu to (0.0, 1.0)
         l0_ = torch.clamp(l0_, 0.0, 1.0)
 
+        l0_s = torch.split(l0_, L1 // 2, dim=1)
+        l0_s1 = [l0_s[0] * l0_s[1], l0_s[2] * l0_s[3]]
+        l0_ = torch.cat(l0_s1, dim=1) * (127/128)
+
         psqt_indices_unsq = psqt_indices.unsqueeze(dim=1)
         wpsqt = wpsqt.gather(1, psqt_indices_unsq)
         bpsqt = bpsqt.gather(1, psqt_indices_unsq)
@@ -295,7 +303,7 @@ def test_step(self, batch, batch_idx):
 
     def configure_optimizers(self):
         # Train with a lower LR on the output layer
-        LR = 8.75e-4
+        LR = self.lr
         train_params = [
             {'params' : get_parameters([self.input]), 'lr' : LR, 'gc_dim' : 0 },
             {'params' : [self.layer_stacks.l1_fact.weight], 'lr' : LR },
@@ -309,5 +317,5 @@ def configure_optimizers(self):
         # increasing the eps leads to less saturated nets with a few dead neurons
         optimizer = ranger.Ranger(train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False)
         # Drop learning rate after 75 epochs
-        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.992)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.gamma)
         return [optimizer], [scheduler]
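
Note on the model.py changes above: the first hidden layer now produces L2 + 1 = 16 outputs per bucket, with the extra column split off and added straight onto the final output (l3x_ = l3c_ + l1f_out + l1c_out), and the clipped feature-transformer output is squashed to half its width by multiplying the two halves of each perspective together. A minimal standalone sketch of those two ideas, outside the bucketed LayerStacks/factorizer machinery (the helper names squash_ft and first_layer_with_skip are only for illustration):

import torch

L1, L2 = 1024, 15

def squash_ft(l0_):
    # l0_ has shape (batch, 2 * L1) and is already clamped to [0, 1];
    # multiply the two halves of each perspective, halving the width,
    # and rescale by 127/128 as in the diff.
    a, b, c, d = torch.split(l0_, L1 // 2, dim=1)
    return torch.cat([a * b, c * d], dim=1) * (127 / 128)

# First hidden layer with L2 + 1 outputs: L2 of them go through the usual
# clipped-ReLU path, the extra one is routed straight to the final output.
l1 = torch.nn.Linear(2 * L1 // 2, L2 + 1)

def first_layer_with_skip(x):
    y = l1(x)
    hidden, skip = y.split(L2, dim=1)       # (batch, 15) and (batch, 1)
    return torch.clamp(hidden, 0.0, 1.0), skip

x = squash_ft(torch.rand(4, 2 * L1))        # stand-in for the clipped accumulator
hidden, skip = first_layer_with_skip(x)
print(hidden.shape, skip.shape)             # torch.Size([4, 15]) torch.Size([4, 1])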

serialize.py

Lines changed: 3 additions & 3 deletions
@@ -140,15 +140,15 @@ def __init__(self, f, feature_set):
         self.read_int32(feature_set.hash ^ (M.L1*2)) # Feature transformer hash
         self.read_feature_transformer(self.model.input, self.model.num_psqt_buckets)
         for i in range(self.model.num_ls_buckets):
-            l1 = nn.Linear(2*M.L1, M.L2)
+            l1 = nn.Linear(2*M.L1//2, M.L2+1)
             l2 = nn.Linear(M.L2, M.L3)
             output = nn.Linear(M.L3, 1)
             self.read_int32(fc_hash) # FC layers hash
             self.read_fc_layer(l1)
             self.read_fc_layer(l2)
             self.read_fc_layer(output, is_output=True)
-            self.model.layer_stacks.l1.weight.data[i*M.L2:(i+1)*M.L2, :] = l1.weight
-            self.model.layer_stacks.l1.bias.data[i*M.L2:(i+1)*M.L2] = l1.bias
+            self.model.layer_stacks.l1.weight.data[i*(M.L2+1):(i+1)*(M.L2+1), :] = l1.weight
+            self.model.layer_stacks.l1.bias.data[i*(M.L2+1):(i+1)*(M.L2+1)] = l1.bias
             self.model.layer_stacks.l2.weight.data[i*M.L3:(i+1)*M.L3, :] = l2.weight
             self.model.layer_stacks.l2.bias.data[i*M.L3:(i+1)*M.L3] = l2.bias
             self.model.layer_stacks.output.weight.data[i:(i+1), :] = output.weight
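
With L2 bumped to 15, every layer-stack bucket in the serialized net now occupies L2 + 1 = 16 consecutive rows of the stacked l1 weight and bias tensors, which is what the new slicing above copies into. A quick worked example of that row arithmetic (the helper is hypothetical, not part of serialize.py):

def l1_rows_for_bucket(i, L2=15):
    # each bucket owns L2 + 1 = 16 consecutive rows of layer_stacks.l1
    rows = L2 + 1
    return slice(i * rows, (i + 1) * rows)

print(l1_rows_for_bucket(0))   # slice(0, 16, None)
print(l1_rows_for_bucket(7))   # slice(112, 128, None)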

train.py

Lines changed: 7 additions & 1 deletion
@@ -30,6 +30,8 @@ def main():
     parser.add_argument("val", help="Validation data (.bin or .binpack)")
     parser = pl.Trainer.add_argparse_args(parser)
     parser.add_argument("--lambda", default=1.0, type=float, dest='lambda_', help="lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0).")
+    parser.add_argument("--gamma", default=0.992, type=float, dest='gamma', help="Multiplicative factor applied to the learning rate after every epoch.")
+    parser.add_argument("--lr", default=8.75e-4, type=float, dest='lr', help="Initial learning rate.")
     parser.add_argument("--num-workers", default=1, type=int, dest='num_workers', help="Number of worker threads to use for data loading. Currently only works well for binpack.")
     parser.add_argument("--batch-size", default=-1, type=int, dest='batch_size', help="Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128.")
     parser.add_argument("--threads", default=-1, type=int, dest='threads', help="Number of torch threads to use. Default automatic (cores) .")
@@ -50,11 +52,15 @@ def main():
     feature_set = features.get_feature_set_from_name(args.features)
 
     if args.resume_from_model is None:
-        nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_)
+        nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_, gamma=args.gamma, lr=args.lr)
     else:
         nnue = torch.load(args.resume_from_model)
         nnue.set_feature_set(feature_set)
         nnue.lambda_ = args.lambda_
+        # we can set the following here just like that because when resuming
+        # from .pt the optimizer is only created after the training is started
+        nnue.gamma = args.gamma
+        nnue.lr = args.lr
 
     print("Feature set: {}".format(feature_set.name))
     print("Num real features: {}".format(feature_set.num_real_features))
