Skip to content

Commit ac8d6e1

Browse files
committed
New NNUE architecture, (768->1024)x2->1
1 parent 51409b5 commit ac8d6e1

File tree

2 files changed

+39
-76
lines changed

2 files changed

+39
-76
lines changed

model.py

Lines changed: 38 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,38 @@
1-
import torch
2-
import struct
3-
4-
from torch import nn
5-
6-
7-
# Number of inputs
8-
NUM_SQ = 64
9-
NUM_PT = 12
10-
NUM_INPUTS = NUM_SQ*NUM_PT
11-
12-
# 3 layer fully connected network
13-
L1 = 384
14-
L2 = 8
15-
L3 = 16
16-
17-
class NNUE(nn.Module):
18-
def __init__(self):
19-
super(NNUE, self).__init__()
20-
self.input = nn.Linear(NUM_INPUTS, L1)
21-
self.l1 = nn.Linear(2 * L1, L2)
22-
self.l2 = nn.Linear(L2, L3)
23-
self.output = nn.Linear(L3, 1)
24-
25-
26-
def forward(self, us, them, w_in, b_in):
27-
w = self.input(w_in)
28-
b = self.input(b_in)
29-
l0_ = (us*torch.cat([w, b], dim=1)) + (them*torch.cat([b, w], dim=1))
30-
l0_ = torch.clamp(l0_, 0.0, 1.0)
31-
l1_ = torch.clamp(self.l1(l0_), 0.0, 1.0)
32-
l2_ = torch.clamp(self.l2(l1_), 0.0, 1.0)
33-
x = self.output(l2_)
34-
return x
35-
36-
37-
def clamp_weights(self):
38-
# L1
39-
data = self.l1.weight.data
40-
data.clamp(-127.0/64.0, 127.0/64.0)
41-
self.l1.weight.data.copy_(data)
42-
43-
# L2
44-
data = self.l2.weight.data
45-
data.clamp(-127.0/64.0, 127.0/64.0)
46-
self.l2.weight.data.copy_(data)
47-
48-
# Output
49-
data = self.output.weight.data
50-
data.clamp(-127.0*127.0/64.0, 127.0*127.0/64.0)
51-
self.output.weight.data.copy_(data)
52-
53-
54-
def loss_function(wdl, pred, batch):
55-
us, them, white, black, outcome, score = batch
56-
57-
wdl_eval_model = (pred*600.0/361).sigmoid()
58-
wdl_eval_target = (score/410).sigmoid()
59-
60-
wdl_value_target = wdl_eval_target * (1.0 - wdl) + outcome * wdl
61-
62-
return torch.abs(wdl_value_target - wdl_eval_model).square().mean()
1+
import torch
2+
import struct
3+
4+
from torch import nn
5+
6+
7+
# Number of inputs
8+
NUM_SQ = 64
9+
NUM_PT = 12
10+
NUM_INPUTS = NUM_SQ*NUM_PT
11+
12+
L1 = 1024
13+
14+
class NNUE(nn.Module):
15+
def __init__(self):
16+
super(NNUE, self).__init__()
17+
self.input = nn.Linear(NUM_INPUTS, L1)
18+
self.output = nn.Linear(2*L1, 1)
19+
20+
21+
def forward(self, us, them, w_in, b_in):
22+
w = self.input(w_in)
23+
b = self.input(b_in)
24+
l0_ = (us*torch.cat([w, b], dim=1)) + (them*torch.cat([b, w], dim=1))
25+
l0_ = torch.clamp(l0_, 0.0, 1.0)
26+
x = self.output(l0_)
27+
return x
28+
29+
30+
def loss_function(wdl, pred, batch):
31+
us, them, white, black, outcome, score = batch
32+
33+
wdl_eval_model = (pred*600.0/361).sigmoid()
34+
wdl_eval_target = (score/410).sigmoid()
35+
36+
wdl_value_target = wdl_eval_target * (1.0 - wdl) + outcome * wdl
37+
38+
return torch.abs(wdl_value_target - wdl_eval_model).square().mean()

quantize.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,11 @@
1212

1313
HALFKX_WEIGHT_SCALE = MAX_QUANTIZED_ACTIVATION
1414
HALFKX_BIAS_SCALE = MAX_QUANTIZED_ACTIVATION
15-
HIDDEN_WEIGHT_SCALE = (1<<WEIGHT_SCALE_BITS)
16-
HIDDEN_BIAS_SCALE = (1<<WEIGHT_SCALE_BITS)*MAX_QUANTIZED_ACTIVATION
1715
OUTPUT_WEIGHT_SCALE = (OUTPUT_SCALE*NNUE2SCORE/MAX_QUANTIZED_ACTIVATION)
1816
OUTPUT_BIAS_SCALE = OUTPUT_SCALE*NNUE2SCORE
19-
MAX_HIDDEN_WEIGHT = MAX_QUANTIZED_ACTIVATION/HIDDEN_WEIGHT_SCALE
2017
MAX_OUTPUT_WEIGHT = MAX_QUANTIZED_ACTIVATION/OUTPUT_WEIGHT_SCALE
2118

22-
NNUE_FORMAT_VERSION = 0x00000009
19+
NNUE_FORMAT_VERSION = 0x0000000A
2320

2421
def write_header(buf, version):
2522
buf.extend(struct.pack('<I', version))
@@ -41,12 +38,6 @@ def quant_input(biases, weights):
4138
return (biases, weights)
4239

4340

44-
def quant_linear(biases, weights):
45-
biases = biases.data.mul(HIDDEN_BIAS_SCALE).round().to(torch.int32)
46-
weights = weights.data.clamp(-MAX_HIDDEN_WEIGHT, MAX_HIDDEN_WEIGHT).mul(HIDDEN_WEIGHT_SCALE).round().to(torch.int8)
47-
return (biases, weights)
48-
49-
5041
def quant_output(biases, weights):
5142
biases = biases.data.mul(OUTPUT_BIAS_SCALE).round().to(torch.int32)
5243
weights = weights.data.clamp(-MAX_OUTPUT_WEIGHT, MAX_OUTPUT_WEIGHT).mul(OUTPUT_WEIGHT_SCALE).round().to(torch.int8)
@@ -63,16 +54,12 @@ def quantization(source, target):
6354

6455
# Perform quantization
6556
input = quant_input(nnue.input.bias, nnue.input.weight)
66-
linear1 = quant_linear(nnue.l1.bias, nnue.l1.weight)
67-
linear2 = quant_linear(nnue.l2.bias, nnue.l2.weight)
6857
output = quant_output(nnue.output.bias, nnue.output.weight)
6958

7059
# Write quantized layers
7160
outbuffer = bytearray()
7261
write_header(outbuffer, NNUE_FORMAT_VERSION)
7362
write_input(outbuffer, input[0], input[1])
74-
write_layer(outbuffer, linear1[0], linear1[1])
75-
write_layer(outbuffer, linear2[0], linear2[1])
7663
write_layer(outbuffer, output[0], output[1])
7764
with open(target, 'wb') as f:
7865
f.write(outbuffer)

0 commit comments

Comments
 (0)