@@ -31,6 +31,39 @@ def forward(self, batch: Batch):
3131
3232 return torch .sigmoid (self .output_layer (hidden_features ))
3333
34+ def eval (self , fen , device ):
35+ fen = fen .split (" " )[0 ]
36+ stm_features_dense_tensor = torch .zeros (768 , device = device )
37+ nstm_features_dense_tensor = torch .zeros (768 , device = device )
38+
39+ for rank_idx , rank in enumerate (fen .split ('/' )):
40+ file_idx = 0
41+ for char in rank :
42+ if char .isdigit ():
43+ file_idx += int (char )
44+ else :
45+ sq = 8 * (7 - rank_idx ) + file_idx
46+ piece_type = {'p' : 0 , 'n' : 1 , 'b' : 2 , 'r' : 3 , 'q' : 4 , 'k' : 5 }[char .lower ()]
47+
48+ is_black_piece = char .islower ()
49+ piece_color = 1 if is_black_piece else 0
50+
51+ stm_features_dense_tensor [piece_color * 384 + piece_type * 64 + sq ] = 1
52+ nstm_features_dense_tensor [(1 - piece_color ) * 384 + piece_type * 64 + (sq ^ 56 )] = 1
53+
54+ file_idx += 1
55+
56+ board_stm = stm_features_dense_tensor .to_dense ()
57+ board_nstm = nstm_features_dense_tensor .to_dense ()
58+
59+ stm_perspective = self .feature_transformer (board_stm )
60+ nstm_perspective = self .feature_transformer (board_nstm )
61+
62+ hidden_features = torch .cat ((stm_perspective , nstm_perspective ))
63+ hidden_features = self .screlu (hidden_features )
64+
65+ print (int ((torch .special .logit (torch .sigmoid (self .output_layer (hidden_features ))) * 400 ).item ()))
66+
3467 def clamp_weights (self ):
3568 self .feature_transformer .weight .data .clamp_ (- 1.27 , 1.27 )
3669 self .output_layer .weight .data .clamp_ (- 1.27 , 1.27 )
0 commit comments