We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1308a12 commit 24098d7Copy full SHA for 24098d7
trainer/train.py
@@ -64,7 +64,7 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
64
optimizer.zero_grad()
65
prediction = model(batch)
66
67
- loss = torch.mean((prediction - batch.target) ** 2)
+ loss = torch.mean(torch.abs(prediction - batch.target) ** 2.5)
68
loss.backward()
69
optimizer.step()
70
model.clamp_weights()
0 commit comments