Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# 3 layer fully connected network
L1 = 1024
L2 = 8
L2 = 15
L3 = 32

def coalesce_ft_weights(model, layer):
Expand All @@ -27,13 +27,13 @@ def __init__(self, count):
super(LayerStacks, self).__init__()

self.count = count
self.l1 = nn.Linear(2 * L1, L2 * count)
self.l1 = nn.Linear(2 * L1 // 2, (L2 + 1) * count)
# Factorizer only for the first layer because later
# there's a non-linearity and factorization breaks.
# It breaks the min/max weight clipping but hopefully it's not bad.
# TODO: try solving it
# one potential solution would be to coalesce the weights on each step.
self.l1_fact = nn.Linear(2 * L1, L2, bias=False)
self.l1_fact = nn.Linear(2 * L1 // 2, L2 + 1, bias=False)
self.l2 = nn.Linear(L2, L3 * count)
self.output = nn.Linear(L3, 1 * count)

Expand All @@ -56,8 +56,8 @@ def _init_layers(self):
for i in range(1, self.count):
# Make all layer stacks have the same initialization.
# Basically copy the first to all other layer stacks.
l1_weight[i*L2:(i+1)*L2, :] = l1_weight[0:L2, :]
l1_bias[i*L2:(i+1)*L2] = l1_bias[0:L2]
l1_weight[i*(L2+1):(i+1)*(L2+1), :] = l1_weight[0:(L2+1), :]
l1_bias[i*(L2+1):(i+1)*(L2+1)] = l1_bias[0:(L2+1)]
l2_weight[i*L3:(i+1)*L3, :] = l2_weight[0:L3, :]
l2_bias[i*L3:(i+1)*L3] = l2_bias[0:L3]
output_weight[i:i+1, :] = output_weight[0:1, :]
Expand All @@ -77,12 +77,14 @@ def forward(self, x, ls_indices):

indices = ls_indices.flatten() + self.idx_offset

l1s_ = self.l1(x).reshape((-1, self.count, L2))
l1s_ = self.l1(x).reshape((-1, self.count, L2 + 1))
l1f_ = self.l1_fact(x)
# https://stackoverflow.com/questions/55881002/pytorch-tensor-indexing-how-to-gather-rows-by-tensor-containing-indices
# basically we present it as a list of individual results and pick not only based on
# the ls index but also based on batch (they are combined into one index)
l1c_ = l1s_.view(-1, L2)[indices]
l1c_ = l1s_.view(-1, L2 + 1)[indices]
l1c_, l1c_out = l1c_.split(L2, dim=1)
l1f_, l1f_out = l1f_.split(L2, dim=1)
l1x_ = torch.clamp(l1c_ + l1f_, 0.0, 1.0)

l2s_ = self.l2(l1x_).reshape((-1, self.count, L3))
Expand All @@ -91,18 +93,18 @@ def forward(self, x, ls_indices):

l3s_ = self.output(l2x_).reshape((-1, self.count, 1))
l3c_ = l3s_.view(-1, 1)[indices]
l3x_ = l3c_
l3x_ = l3c_ + l1f_out + l1c_out

return l3x_

def get_coalesced_layer_stacks(self):
for i in range(self.count):
with torch.no_grad():
l1 = nn.Linear(2*L1, L2)
l1 = nn.Linear(2*L1 // 2, L2+1)
l2 = nn.Linear(L2, L3)
output = nn.Linear(L3, 1)
l1.weight.data = self.l1.weight[i*L2:(i+1)*L2, :] + self.l1_fact.weight.data
l1.bias.data = self.l1.bias[i*L2:(i+1)*L2]
l1.weight.data = self.l1.weight[i*(L2+1):(i+1)*(L2+1), :] + self.l1_fact.weight.data
l1.bias.data = self.l1.bias[i*(L2+1):(i+1)*(L2+1)]
l2.weight.data = self.l2.weight[i*L3:(i+1)*L3, :]
l2.bias.data = self.l2.bias[i*L3:(i+1)*L3]
output.weight.data = self.output.weight[i:(i+1), :]
Expand All @@ -119,14 +121,16 @@ class NNUE(pl.LightningModule):

It is not ideal for training a Pytorch quantized model directly.
"""
def __init__(self, feature_set, lambda_=1.0):
def __init__(self, feature_set, lambda_=1.0, gamma=0.992, lr=8.75e-4):
super(NNUE, self).__init__()
self.num_psqt_buckets = feature_set.num_psqt_buckets
self.num_ls_buckets = feature_set.num_ls_buckets
self.input = DoubleFeatureTransformerSlice(feature_set.num_features, L1 + self.num_psqt_buckets)
self.feature_set = feature_set
self.layer_stacks = LayerStacks(self.num_ls_buckets)
self.lambda_ = lambda_
self.gamma = gamma
self.lr = lr

self.weight_clipping = [
{'params' : [self.layer_stacks.l1.weight], 'min_weight' : -127/64, 'max_weight' : 127/64, 'virtual_params' : self.layer_stacks.l1_fact.weight },
Expand Down Expand Up @@ -249,6 +253,10 @@ def forward(self, us, them, white_indices, white_values, black_indices, black_va
# clamp here is used as a clipped relu to (0.0, 1.0)
l0_ = torch.clamp(l0_, 0.0, 1.0)

l0_s = torch.split(l0_, L1 // 2, dim=1)
l0_s1 = [l0_s[0] * l0_s[1], l0_s[2] * l0_s[3]]
l0_ = torch.cat(l0_s1, dim=1) * (127/128)

psqt_indices_unsq = psqt_indices.unsqueeze(dim=1)
wpsqt = wpsqt.gather(1, psqt_indices_unsq)
bpsqt = bpsqt.gather(1, psqt_indices_unsq)
Expand Down Expand Up @@ -295,7 +303,7 @@ def test_step(self, batch, batch_idx):

def configure_optimizers(self):
# Train with a lower LR on the output layer
LR = 8.75e-4
LR = self.lr
train_params = [
{'params' : get_parameters([self.input]), 'lr' : LR, 'gc_dim' : 0 },
{'params' : [self.layer_stacks.l1_fact.weight], 'lr' : LR },
Expand All @@ -309,5 +317,5 @@ def configure_optimizers(self):
# increasing the eps leads to less saturated nets with a few dead neurons
optimizer = ranger.Ranger(train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False)
# Drop learning rate after 75 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.992)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.gamma)
return [optimizer], [scheduler]
6 changes: 3 additions & 3 deletions serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,15 @@ def __init__(self, f, feature_set):
self.read_int32(feature_set.hash ^ (M.L1*2)) # Feature transformer hash
self.read_feature_transformer(self.model.input, self.model.num_psqt_buckets)
for i in range(self.model.num_ls_buckets):
l1 = nn.Linear(2*M.L1, M.L2)
l1 = nn.Linear(2*M.L1//2, M.L2+1)
l2 = nn.Linear(M.L2, M.L3)
output = nn.Linear(M.L3, 1)
self.read_int32(fc_hash) # FC layers hash
self.read_fc_layer(l1)
self.read_fc_layer(l2)
self.read_fc_layer(output, is_output=True)
self.model.layer_stacks.l1.weight.data[i*M.L2:(i+1)*M.L2, :] = l1.weight
self.model.layer_stacks.l1.bias.data[i*M.L2:(i+1)*M.L2] = l1.bias
self.model.layer_stacks.l1.weight.data[i*(M.L2+1):(i+1)*(M.L2+1), :] = l1.weight
self.model.layer_stacks.l1.bias.data[i*(M.L2+1):(i+1)*(M.L2+1)] = l1.bias
self.model.layer_stacks.l2.weight.data[i*M.L3:(i+1)*M.L3, :] = l2.weight
self.model.layer_stacks.l2.bias.data[i*M.L3:(i+1)*M.L3] = l2.bias
self.model.layer_stacks.output.weight.data[i:(i+1), :] = output.weight
Expand Down
8 changes: 7 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def main():
parser.add_argument("val", help="Validation data (.bin or .binpack)")
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--lambda", default=1.0, type=float, dest='lambda_', help="lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0).")
parser.add_argument("--gamma", default=0.992, type=float, dest='gamma', help="Multiplicative factor applied to the learning rate after every epoch.")
parser.add_argument("--lr", default=8.75e-4, type=float, dest='lr', help="Initial learning rate.")
parser.add_argument("--num-workers", default=1, type=int, dest='num_workers', help="Number of worker threads to use for data loading. Currently only works well for binpack.")
parser.add_argument("--batch-size", default=-1, type=int, dest='batch_size', help="Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128.")
parser.add_argument("--threads", default=-1, type=int, dest='threads', help="Number of torch threads to use. Default automatic (cores) .")
Expand All @@ -50,11 +52,15 @@ def main():
feature_set = features.get_feature_set_from_name(args.features)

if args.resume_from_model is None:
nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_)
nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_, gamma=args.gamma, lr=args.lr)
else:
nnue = torch.load(args.resume_from_model)
nnue.set_feature_set(feature_set)
nnue.lambda_ = args.lambda_
# we can set the following here just like that because when resuming
# from .pt the optimizer is only created after the training is started
nnue.gamma = args.gamma
nnue.lr = args.lr

print("Feature set: {}".format(feature_set.name))
print("Num real features: {}".format(feature_set.num_real_features))
Expand Down