Skip to content

Commit cabec10

Browse files
authored
Simplify
1 parent 9d688dc commit cabec10

File tree

3 files changed

+9
-43
lines changed

3 files changed

+9
-43
lines changed

trainer/batchloader.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,17 @@ class Batch:
3030
class BatchLoader:
3131
def __init__(self, lib_path: str, files: list[bytes], batch_size: int, scale: float, wdl: float) -> None:
3232
self.parse_lib = None
33-
if not files:
34-
raise ValueError("The files list cannot be empty.")
35-
36-
try:
37-
self.parse_lib = ctypes.CDLL(lib_path)
38-
except OSError as e:
39-
raise Exception(f"Failed to load the library: {e}")
33+
if not files: raise ValueError("The files list cannot be empty.")
34+
try: self.parse_lib = ctypes.CDLL(lib_path)
35+
except OSError as e: raise Exception(f"Failed to load the library: {e}")
4036
self.load_parse_lib()
4137

42-
self.files = files
43-
self.file_index = 0
38+
self.files, self.file_index = files, 0
4439
self.batch = ctypes.c_void_p(self.parse_lib.batch_new(ctypes.c_uint32(batch_size), ctypes.c_float(scale), ctypes.c_float(wdl)))
45-
46-
if self.batch.value is None:
47-
raise Exception("Failed to create batch")
40+
if self.batch.value is None: raise Exception("Failed to create batch")
4841

4942
self.current_reader = ctypes.c_void_p(self.parse_lib.file_reader_new(ctypes.create_string_buffer(files[0])))
50-
if self.current_reader.value is None:
51-
raise Exception("Failed to create file reader")
43+
if self.current_reader.value is None: raise Exception("Failed to create file reader")
5244

5345
def next_batch(self, device: torch.device) -> tuple[bool, Batch]:
5446
new_epoch = False

trainer/model.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,3 @@ def forward(self, batch: Batch):
2424
def clamp_weights(self):
2525
self.feature_transformer.weight.data.clamp_(-2.0, 2.0)
2626
self.output_layer.weight.data.clamp_(-2.0, 2.0)
27-
28-
29-
# 768 -> N -> 1
30-
class StmPerspectiveNetwork(torch.nn.Module):
31-
def __init__(self, feature_output_size: int):
32-
super().__init__()
33-
self.feature_transformer = torch.nn.Linear(768, feature_output_size)
34-
self.output_layer = torch.nn.Linear(feature_output_size, 1)
35-
36-
def forward(self, batch: Batch):
37-
stm_perspective = self.feature_transformer(batch.stm_sparse.to_dense())
38-
39-
hidden_features = torch.clamp(stm_perspective, 0, 1)
40-
41-
return torch.sigmoid(self.output_layer(hidden_features))
42-
43-
def clamp_weights(self):
44-
self.feature_transformer.weight.data.clamp_(-2.0, 2.0)
45-
self.output_layer.weight.data.clamp_(-2.0, 2.0)

trainer/train.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from time import time
2-
32
import torch
4-
from batchloader import BatchLoader
53

4+
from batchloader import BatchLoader
65
from quantize import quantize
76

87

@@ -12,13 +11,7 @@ def print_epoch_stats(epoch, running_loss, iterations, fens, start_time, current
1211
.format(epoch, epoch_time, running_loss.item() / iterations, fens / epoch_time))
1312
print(message)
1413

15-
def train(
16-
model: torch.nn.Module,
17-
optimizer: torch.optim.Optimizer,
18-
dataloader: BatchLoader,
19-
epochs: int,
20-
device: torch.device,
21-
) -> None:
14+
def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: BatchLoader, epochs: int, device: torch.device):
2215
running_loss = torch.zeros(1, device=device)
2316
epoch_start_time = time()
2417
iterations = 0
@@ -37,7 +30,7 @@ def train(
3730
epoch_start_time = current_time
3831
iterations = 0
3932
fens = 0
40-
33+
4134
quantize(model, f"network/nnue_{epoch}_scaled.bin")
4235

4336
optimizer.zero_grad()

0 commit comments

Comments (0)