Skip to content

Commit 24098d7

Browse files
committed
MPE 2.5
1 parent 1308a12 commit 24098d7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

trainer/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader:
6464
optimizer.zero_grad()
6565
prediction = model(batch)
6666

67-
loss = torch.mean((prediction - batch.target) ** 2)
67+
loss = torch.mean(torch.abs(prediction - batch.target) ** 2.5)
6868
loss.backward()
6969
optimizer.step()
7070
model.clamp_weights()

0 commit comments

Comments
 (0)