
Commit 147c473

Add screlu activation function
1 parent cabec10 commit 147c473

1 file changed: +11 −1 lines changed

trainer/model.py

Lines changed: 11 additions & 1 deletion
@@ -2,13 +2,22 @@
 
 from batchloader import Batch
 
+class SCReLU(torch.nn.Module):
+    def __init__(self, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, input):
+        return torch.pow(torch.clamp(input, 0, 1), 2)
+
 
 # (768 -> N) x 2 -> 1
 class PerspectiveNetwork(torch.nn.Module):
     def __init__(self, feature_output_size: int):
         super().__init__()
         self.feature_transformer = torch.nn.Linear(768, feature_output_size)
         self.output_layer = torch.nn.Linear(feature_output_size * 2, 1)
+        self.screlu = SCReLU()
 
     def forward(self, batch: Batch):
         board_stm = batch.stm_sparse.to_dense()
@@ -17,7 +26,8 @@ def forward(self, batch: Batch):
         stm_perspective = self.feature_transformer(board_stm)
         nstm_perspective = self.feature_transformer(board_nstm)
 
-        hidden_features = torch.clamp(torch.cat((stm_perspective, nstm_perspective), dim=1), 0, 1)
+        hidden_features = torch.cat((stm_perspective, nstm_perspective), dim=1)
+        hidden_features = self.screlu(hidden_features)
 
         return torch.sigmoid(self.output_layer(hidden_features))
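
For context: SCReLU is the squared clipped ReLU, clamp(x, 0, 1)², so the hidden features still saturate in [0, 1] as with the previous plain clamp (CReLU), but the curve is flatter near 0 and steeper near 1. Below is a minimal standalone sketch of the new activation, assuming only that PyTorch is installed; the input tensor and the expected-output comments are illustrative and not part of the commit.

import torch

# SCReLU as added in this commit: clamp the input to [0, 1], then square it.
class SCReLU(torch.nn.Module):
    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace  # stored for API symmetry; unused in forward

    def forward(self, input):
        return torch.pow(torch.clamp(input, 0, 1), 2)

x = torch.tensor([-0.5, 0.25, 0.5, 2.0])  # illustrative values only
screlu = SCReLU()

print(screlu(x))             # tensor([0.0000, 0.0625, 0.2500, 1.0000])
print(torch.clamp(x, 0, 1))  # previous CReLU: tensor([0.0000, 0.2500, 0.5000, 1.0000])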
