use logsigmoid at multilabel_soft_margin_loss, and change output from shape=(N, C) to (N,) #9965
Conversation
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
soumith
left a comment
approved. change comment.
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 1376a8b to 210f88e
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 9ea862e to 5a10812
Why are we changing the behavior of all these losses?
Force-pushed from 5a10812 to c2b71e0
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/nn/functional.py
Outdated
return binary_cross_entropy(input, target, weight, None, None, reduction)

loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))
loss.sum(dim=0)  # only return N loss values
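Editor's aside (not part of the original thread): a minimal sketch of why the logsigmoid form above is numerically safer than the earlier sigmoid + binary_cross_entropy path. The tensor values are made up for illustration; only torch.sigmoid, torch.log, and F.logsigmoid are used, all standard PyTorch calls.

```
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, 0.0, 200.0])  # large-magnitude logits, float32

# Naive form: sigmoid(-200) underflows to 0 in float32, so log() returns -inf
# and the loss (and its gradient) blows up.
naive = torch.log(torch.sigmoid(x))   # tensor([-inf, -0.6931, 0.0])

# logsigmoid evaluates log(1 / (1 + exp(-x))) in a stable way and stays finite.
stable = F.logsigmoid(x)              # roughly tensor([-200.0, -0.6931, 0.0])
```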
@li-roy I made changes accordingly, maybe take a look when you get a chance?
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
li-roy
left a comment
Looks good for the most part, just need to be consistent in the behavior.
loss = loss.sum(dim=1)  # only return N loss values

if reduction == 'none':
    return loss
If you're planning to go with averaging across classes, don't forget to change the formula in the doc in torch/nn/modules/loss.py.
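For reference (an editor's reconstruction, not text from the PR), the per-sample formula that averaging across the C classes would give, in the notation the loss docs usually use (x are the logits of one sample, y the multi-hot targets):

```
\ell(x, y) = -\frac{1}{C} \sum_{i=1}^{C}
  \Big( y_i \log \sigma(x_i) + (1 - y_i) \log\big(1 - \sigma(x_i)\big) \Big),
\qquad \sigma(x_i) = \frac{1}{1 + e^{-x_i}}
```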
facebook-github-bot
left a comment
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
… shape=(N, C) to (N,) (pytorch#9965)

Summary:
- fixes pytorch#9141, pytorch#9301
- use logsigmoid at multilabel_soft_margin_loss to make it more stable (NOT fixing legacy MultiLabelSoftMarginCriterion)
- return (N) instead of (N, C) to match the same behavior as MultiMarginLoss
- Note that with this PR, the following behavior is expected:

```
loss = F.multilabel_soft_margin_loss(outputs, labels, reduction='none')
loss_mean = F.multilabel_soft_margin_loss(outputs, labels, reduction='elementwise_mean')
loss_sum = F.multilabel_soft_margin_loss(outputs, labels, reduction='sum')
loss.sum() == loss_sum    # True
loss.mean() == loss_mean  # True
```

Pull Request resolved: pytorch#9965
Differential Revision: D9038402
Pulled By: weiyangfb
fbshipit-source-id: 0fa94c7b3cd370ea62bd6333f1a0e9bd0b8ccbb9
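Editor's sketch (not the merged implementation): putting the pieces from the summary together, assuming the hypothetical helper name multilabel_soft_margin_loss_sketch and averaging over the C classes as suggested in the review above, the `loss.sum() == loss_sum` / `loss.mean() == loss_mean` relationship follows directly from the (N,) output:

```
import torch
import torch.nn.functional as F

def multilabel_soft_margin_loss_sketch(input, target, reduction='elementwise_mean'):
    # Stable per-element term via logsigmoid, then reduce over the C classes
    # so each sample contributes one value -> output shape (N,).
    loss = -(target * F.logsigmoid(input) + (1 - target) * F.logsigmoid(-input))
    loss = loss.mean(dim=1)  # average over classes; the thread also discusses sum
    if reduction == 'none':
        return loss
    if reduction == 'sum':
        return loss.sum()
    return loss.mean()       # 'elementwise_mean'

outputs = torch.randn(4, 3)
labels = torch.randint(0, 2, (4, 3)).float()

loss = multilabel_soft_margin_loss_sketch(outputs, labels, reduction='none')
loss_mean = multilabel_soft_margin_loss_sketch(outputs, labels, reduction='elementwise_mean')
loss_sum = multilabel_soft_margin_loss_sketch(outputs, labels, reduction='sum')

assert loss.shape == (4,)                    # one value per sample
assert torch.allclose(loss.sum(), loss_sum)  # matches the summary's claim
assert torch.allclose(loss.mean(), loss_mean)
```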
input = torch.sigmoid(input)
return binary_cross_entropy(input, target, weight, None, None, reduction)

loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))
This could probably use torch.lerp (though AFAIK lerp fails to export to ONNX and has NaN problems: #71701).
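For completeness, a hedged sketch of the torch.lerp rewrite hinted at above. Since torch.lerp(start, end, weight) = start + weight * (end - start), taking start = logsigmoid(-input), end = logsigmoid(input), and weight = target reproduces the same expression element-wise (the tensors below are made-up illustration data, and the ONNX/NaN caveats from the comment still apply):

```
import torch
import torch.nn.functional as F

input = torch.randn(4, 3)
target = torch.randint(0, 2, (4, 3)).float()

explicit = -(target * F.logsigmoid(input) + (1 - target) * F.logsigmoid(-input))

# lerp interpolates between the two logsigmoid terms; with target in {0, 1}
# it selects one of them per element, giving the same loss matrix.
via_lerp = -torch.lerp(F.logsigmoid(-input), F.logsigmoid(input), target)

assert torch.allclose(explicit, via_lerp)
```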