Skip to content

Commit a59f506

Browse files
committed
Add print of startpos eval
1 parent cc221fb commit a59f506

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

trainer/model.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,39 @@ def forward(self, batch: Batch):
3131

3232
return torch.sigmoid(self.output_layer(hidden_features))
3333

34+
def eval(self, fen, device):
35+
fen = fen.split(" ")[0]
36+
stm_features_dense_tensor = torch.zeros(768, device=device)
37+
nstm_features_dense_tensor = torch.zeros(768, device=device)
38+
39+
for rank_idx, rank in enumerate(fen.split('/')):
40+
file_idx = 0
41+
for char in rank:
42+
if char.isdigit():
43+
file_idx += int(char)
44+
else:
45+
sq = 8 * (7 - rank_idx) + file_idx
46+
piece_type = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5}[char.lower()]
47+
48+
is_black_piece = char.islower()
49+
piece_color = 1 if is_black_piece else 0
50+
51+
stm_features_dense_tensor[piece_color * 384 + piece_type * 64 + sq] = 1
52+
nstm_features_dense_tensor[(1 - piece_color) * 384 + piece_type * 64 + (sq ^ 56)] = 1
53+
54+
file_idx += 1
55+
56+
board_stm = stm_features_dense_tensor.to_dense()
57+
board_nstm = nstm_features_dense_tensor.to_dense()
58+
59+
stm_perspective = self.feature_transformer(board_stm)
60+
nstm_perspective = self.feature_transformer(board_nstm)
61+
62+
hidden_features = torch.cat((stm_perspective, nstm_perspective))
63+
hidden_features = self.screlu(hidden_features)
64+
65+
print(int((torch.special.logit(torch.sigmoid(self.output_layer(hidden_features))) * 400).item()))
66+
3467
def clamp_weights(self):
3568
self.feature_transformer.weight.data.clamp_(-1.27, 1.27)
3669
self.output_layer.weight.data.clamp_(-1.27, 1.27)

trainer/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
5757
iterations = 0
5858
fens = 0
5959

60+
model.eval("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR", device)
61+
6062
quantize(model, f"network/nnue_{epoch}_scaled.bin")
6163

6264
optimizer.zero_grad()

0 commit comments

Comments
 (0)