
Commit 142d17b

Add lr drop
1 parent 586058f commit 142d17b

3 files changed: 7 additions, 2 deletions

trainer/config.json

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
     "epochs": 10,
     "batch_size": 16384,
     "wdl": 0.5,
+    "lr_drop_steps" : 4,
     "scale": 400,
     "hidden_layer_size": 16
 }
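Taken together with the loop change in trainer/train.py below, this new key drops the learning rate by a factor of 10 every 4 epochs. A minimal sketch of the schedule it implies, assuming a placeholder base learning rate of 1e-3 (the base rate is set elsewhere and is not part of this commit):

# Sketch only: base_lr is a hypothetical placeholder, not a value from this commit.
base_lr = 1e-3
epochs = 10          # from "epochs" in config.json
lr_drop_steps = 4    # from "lr_drop_steps" in config.json

lr = base_lr
for epoch in range(1, epochs + 1):
    if epoch % lr_drop_steps == 0:  # same check that train.py adds
        lr *= 0.1
    print(f"epoch {epoch}: lr = {lr:g}")
# Drops land at epochs 4 and 8: 1e-3 -> 1e-4 -> 1e-5.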

trainer/main.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@ def load_config(config_path="config.json"):
     epochs = config.get("epochs")
     batch_size = config.get("batch_size", 16384)
     wdl = config.get("wdl", 0.5)
+    lr_drop_steps = config.get("lr_drop_steps", 10)
     scale = config.get("scale")
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -38,7 +39,7 @@ def load_config(config_path="config.json"):
 
     start_time = time()
 
-    train(model, optimizer, dataloader, epochs, device)
+    train(model, optimizer, dataloader, epochs, lr_drop_steps, device)
 
     end_time = time()
     elapsed_time = end_time - start_time
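Note that the fallback default here is 10 even though the committed config sets 4, so a config without the key would only drop once, at the end of a 10-epoch run. A minimal sketch of how that .get() fallback behaves, assuming the same JSON-based config loading used above:

import json

# Sketch only: a stripped-down read of the new key, not the full load_config().
with open("config.json") as f:
    config = json.load(f)

lr_drop_steps = config.get("lr_drop_steps", 10)  # 4 with the committed config, 10 if the key is absent
print(lr_drop_steps)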

trainer/train.py

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,7 @@ def print_epoch_stats(epoch, running_loss, iterations, fens, start_time, current
             .format(epoch, epoch_time, running_loss.item() / iterations, fens / epoch_time))
     print(message)
 
-def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: BatchLoader, epochs: int, device: torch.device):
+def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader: BatchLoader, epochs: int, lr_drop_steps: int, device: torch.device):
     running_loss = torch.zeros(1, device=device)
     epoch_start_time = time()
     iterations = 0
@@ -23,6 +23,9 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
         if new_epoch:
             epoch += 1
 
+            if epoch % lr_drop_steps == 0:
+                optimizer.param_groups[0]["lr"] *= 0.1
+
             current_time = time()
             print_epoch_stats(epoch, running_loss, iterations, fens, epoch_start_time, current_time)
 
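For comparison, the same cadence (scale the learning rate by 0.1 every lr_drop_steps epochs) can be expressed with PyTorch's built-in StepLR scheduler; a minimal sketch under placeholder model/optimizer assumptions, not the committed implementation. One difference worth noting: the in-place update above only touches param_groups[0], while StepLR scales every parameter group.

import torch

# Sketch only: the model and optimizer below are placeholders, not the repo's network.
model = torch.nn.Linear(768, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# StepLR multiplies the LR of every param group by gamma once per step_size epochs,
# matching the `epoch % lr_drop_steps == 0` check added in train().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

for epoch in range(10):
    # ... one epoch of training batches would run here ...
    scheduler.step()  # call once per epoch, after the epoch's optimizer steps
    print(f"after epoch {epoch + 1}: lr = {optimizer.param_groups[0]['lr']:g}")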