Add pos_weight argument to nn.BCEWithLogitsLoss (#5660) #6856
Conversation
…tropy_with_logits (pytorch#5660)
- Add an option to control precision/recall in imbalanced datasets
- Add tests (but new_criterion_tests)
torch/nn/functional.py
Outdated
    target: Tensor of the same shape as input
    weight (Tensor, optional): a manual rescaling weight
        if provided it's repeated to match input tensor shape
    pos_weight (Tensor, optional): a weight of positive examples.
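For context, a minimal sketch of how this argument is used through the functional API (the tensor values here are illustrative, not from the PR):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([0.8, -1.2, 2.0])
    targets = torch.tensor([1.0, 0.0, 1.0])

    # pos_weight > 1 penalizes missed positives more heavily, trading
    # precision for recall on imbalanced data.
    loss = F.binary_cross_entropy_with_logits(
        logits, targets, pos_weight=torch.tensor([3.0]))
    print(loss)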
torch/nn/modules/loss.py
Outdated
    weight (Tensor, optional): a manual rescaling weight given to the loss
        of each batch element. If given, has to be a Tensor of size
        "nbatch".
    pos_weight (Tensor, optional): a weight of positive examples.
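For contrast with pos_weight, a minimal sketch of the per-batch-element weight (the tensor values are illustrative):

    import torch
    import torch.nn as nn

    logits = torch.randn(4)              # nbatch = 4
    targets = torch.empty(4).random_(2)

    # weight rescales each batch element's entire loss term, while
    # pos_weight rescales only the positive-class term.
    loss_fn = nn.BCEWithLogitsLoss(weight=torch.tensor([1.0, 2.0, 1.0, 0.5]))
    print(loss_fn(logits, targets))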
ezyang left a comment:
Accept assuming the doc issues are fixed.
`pos_weight` was moved to the end of the argument documentation because it is the last argument in both `nn.BCEWithLogitsLoss` and `binary_cross_entropy_with_logits`.
        is ``False``, returns a loss per input/target element instead and ignores
        :attr:`size_average`. Default: ``True``
    pos_weight (Tensor, optional): a weight of positive examples.
        Must be a vector with length equal to the number of classes.
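To illustrate the vector requirement, a small sketch for a multi-label setup (the class count and weight values are made up):

    import torch
    import torch.nn as nn

    num_classes = 4
    logits = torch.randn(8, num_classes)   # batch of 8 samples
    targets = torch.empty(8, num_classes).random_(2)

    # One weight per class; upweighting the rare last class boosts its recall.
    pos_weight = torch.tensor([1.0, 1.0, 1.0, 5.0])
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    print(loss_fn(logits, targets))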
    def __init__(self, weight=None, size_average=True, reduce=True, pos_weight=None):
        super(BCEWithLogitsLoss, self).__init__(size_average, reduce)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
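A note on this design: registering weight and pos_weight as buffers (rather than storing them as plain attributes) means they move with the module across devices and dtypes and are included in its state_dict. A quick check, assuming current PyTorch behavior:

    import torch
    import torch.nn as nn

    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.0]))
    print('pos_weight' in loss_fn.state_dict())  # True: buffers are serialized
    # loss_fn.cuda() would move pos_weight to the GPU together with the module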
@pytorchbot retest this please
4 similar comments
* upstream/master: (42 commits)
- [c10d] No default device for ProcessGroupGloo (pytorch#8888)
- Fix default values for affine= in the docstrings of InstanceNormXd (pytorch#8895)
- Stop making dynamic allocations of PinnedMemoryAllocator. (pytorch#8896)
- [C++ API] Rework optimization package (pytorch#8815)
- Mention MPICH_MAX_THREAD_SAFETY=multiple. (pytorch#8580)
- Unify isViewable, handle n-dimensional empty tensors. (pytorch#8883)
- Add pos_weight argument to nn.BCEWithLogitsLoss (pytorch#5660) (pytorch#6856)
- [build] Enable clang-specific warnings only when using clang (pytorch#8869)
- Fix cmake cudnn autodetection (pytorch#8891)
- [c10d] Fix link order for building C++ tests (pytorch#8889)
- directly add_subdirectory(nanopb) from torch CMakeLists (pytorch#8870)
- [C++ API] Bag of fixes (pytorch#8843)
- [build] Raise in cmake when seeing NVCC{9/9.1} + GCC6 combo (pytorch#8863)
- Create avg_pool1d in ATen (pytorch#8880)
- throw error when grid_sample is passed unsupported mode (pytorch#8884)
- Allow autograd to work even when the shape of values cannot be determined (pytorch#8641)
- Make at::Tensor::to() const (pytorch#8839)
- [auto] Update onnx to 458c521 - Fix typo (onnx/onnx#1143) onnx/onnx@458c521
- [Caffe2] Fix gradient_check on in-place ops (pytorch#8828)
- Fix as_strided_backward (pytorch#8721)
- ...
The multiplier `1 + (pos_weight - 1) * target` is the only significant difference.

Notes:
- There is no simple way to share the implementation with F.binary_cross_entropy/nn.BCELoss (it uses an implementation from torch._C). I can add a straightforward but numerically unstable implementation to this function, but I'm not sure it is really needed.
- I've added pos_weight as the last argument to prevent errors in code that doesn't use names for keyword arguments, but it looks quite ugly.

P.S. I proposed these changes in #5660
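As a sanity check of that multiplier (a sketch against the current API, where reduction='none' replaces the old reduce=False): when t = 1 the multiplier equals pos_weight and when t = 0 it equals 1, so scaling the -log σ(x) term by it yields the per-element loss -(pos_weight * t * log σ(x) + (1 - t) * log(1 - σ(x))).

    import torch
    import torch.nn.functional as F

    x = torch.tensor([0.5, -1.0, 2.0])   # logits
    t = torch.tensor([1.0, 0.0, 1.0])    # targets
    pw = torch.tensor([3.0])             # pos_weight

    # F.logsigmoid(-x) == log(1 - sigmoid(x)), so this is the formula above.
    manual = -(pw * t * F.logsigmoid(x) + (1 - t) * F.logsigmoid(-x))
    builtin = F.binary_cross_entropy_with_logits(
        x, t, pos_weight=pw, reduction='none')
    print(torch.allclose(manual, builtin))  # True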