Merged
33 changes: 33 additions & 0 deletions test/test_nn.py
@@ -3925,6 +3925,39 @@ def test_bce_with_logits_broadcasts_weights(self):

        self.assertEqual(out1, out2)

    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
        target = torch.rand(64, 4)
        output = torch.rand(64, 4) - 0.5
        pos_weight = torch.ones(64, 4)

        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))

    def test_bce_with_logits_broadcasts_pos_weights(self):
        target = torch.rand(64, 4)
        output = torch.rand(64, 4) - 0.5
        pos_weight = torch.rand(4)
        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)

        pos_weight1 = pos_weight.expand(1, 4)
        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)

        pos_weight2 = pos_weight.expand(64, 4)
        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)

        self.assertEqual(out1, out2)
        self.assertEqual(out1, out3)

    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True)
        target = torch.zeros(3, 1)
        pos_weight = torch.ones(3, 1)
        nn.BCEWithLogitsLoss(pos_weight=pos_weight, size_average=False)(output, target).backward()
        expected_grad = torch.empty(3, 1).fill_(0.5)
        grad = output.grad
        self.assertEqual(grad, expected_grad)

    def test_bce_loss_broadcasts_weights(self):
        sigmoid = nn.Sigmoid()
        target = torch.rand(16, 4)
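The gradient-at-zero test above relies on a small identity worth spelling out: with `pos_weight` equal to one, the per-element gradient of BCE-with-logits with respect to the logit is `sigmoid(x) - t`, so at `x = 0`, `t = 0` each element contributes `sigmoid(0) = 0.5` under sum reduction. A minimal standalone sketch of the same check (not part of the PR, using the `size_average=False` API of this PR's era):

```python
import torch
import torch.nn as nn

output = torch.zeros(3, 1, requires_grad=True)
target = torch.zeros(3, 1)
pos_weight = torch.ones(3, 1)

# sum-reduced loss; each element's gradient is sigmoid(0) - 0 = 0.5
loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight, size_average=False)(output, target)
loss.backward()
print(output.grad)  # expect a (3, 1) tensor filled with 0.5
```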
11 changes: 9 additions & 2 deletions torch/nn/functional.py
@@ -1496,7 +1496,7 @@ def binary_cross_entropy(input, target, weight=None, size_average=True, reduce=T
    return torch._C._nn.binary_cross_entropy(input, target, weight, size_average, reduce)


def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True, reduce=True):
def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True, reduce=True, pos_weight=None):
r"""Function that measures Binary Cross Entropy between target and output
logits.

@@ -1515,6 +1515,8 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=Tr
observations for each minibatch depending on :attr:`size_average`. When :attr:`reduce`
is ``False``, returns a loss per input/target element instead and ignores
:attr:`size_average`. Default: ``True``
pos_weight (Tensor, optional): a weight of positive examples.
Must be a vector with length equal to the number of classes.

Examples::

@@ -1527,7 +1529,12 @@ def binary_cross_entropy_with_logits(input, target, weight=None, size_average=Tr
        raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

    max_val = (-input).clamp(min=0)
    loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()

    if pos_weight is None:
        loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
    else:
        log_weight = 1 + (pos_weight - 1) * target
        loss = input - input * target + log_weight * (max_val + ((-max_val).exp() + (-input - max_val).exp()).log())

    if weight is not None:
        loss = loss * weight
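The new branch above keeps the numerically stable log-sum-exp form and folds the positive weight in via `log_weight = 1 + (pos_weight - 1) * target`, so `sigmoid` is never evaluated directly. A quick sketch (not part of the PR; tensor names are illustrative) checking that it matches the naive `-[p*t*log(sigmoid(x)) + (1-t)*log(1 - sigmoid(x))]` definition on moderate inputs:

```python
import torch

x = torch.randn(64, 4)      # logits
t = torch.rand(64, 4)       # targets in [0, 1]
p = torch.rand(4) + 0.5     # per-class positive weights, broadcast over the batch

# naive formulation (can overflow/underflow for large |x|)
naive = -(p * t * torch.log(torch.sigmoid(x)) + (1 - t) * torch.log(1 - torch.sigmoid(x)))

# stable formulation used in the PR
max_val = (-x).clamp(min=0)
log_weight = 1 + (p - 1) * t
stable = x - x * t + log_weight * (max_val + ((-max_val).exp() + (-x - max_val).exp()).log())

print(torch.allclose(naive, stable, atol=1e-6))  # expected: True
```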
33 changes: 23 additions & 10 deletions torch/nn/modules/loss.py
@@ -458,6 +458,20 @@ class BCEWithLogitsLoss(_Loss):
an auto-encoder. Note that the targets `t[i]` should be numbers
between 0 and 1.

It's possible to trade off recall and precision by adding weights to positive examples.
In this case the loss can be described as:

.. math::
    \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
    l_n = - w_n \left[ p_n t_n \cdot \log \sigma(x_n)
    + (1 - t_n) \cdot \log (1 - \sigma(x_n)) \right],

where :math:`p_n` is the positive weight of class :math:`n`.
:math:`p_n > 1` increases the recall, and :math:`p_n < 1` increases the precision.

For example, if a dataset contains 100 positive and 300 negative examples of a single class,
then `pos_weight` for the class should be equal to :math:`\frac{300}{100}=3`.
The loss would act as if the dataset contained :math:`3\times 100=300` positive examples.

Args:
weight (Tensor, optional): a manual rescaling weight given to the loss
of each batch element. If given, has to be a Tensor of size
@@ -470,6 +484,8 @@ class BCEWithLogitsLoss(_Loss):
observations for each minibatch depending on size_average. When reduce
is False, returns a loss per input/target element instead and ignores
size_average. Default: True
pos_weight (Tensor, optional): a weight of positive examples.
Must be a vector with length equal to the number of classes.

Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
@@ -484,20 +500,17 @@
>>> output = loss(input, target)
>>> output.backward()
"""
    def __init__(self, weight=None, size_average=True, reduce=True):
    def __init__(self, weight=None, size_average=True, reduce=True, pos_weight=None):
        super(BCEWithLogitsLoss, self).__init__(size_average, reduce)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    def forward(self, input, target):
        if self.weight is not None:
            return F.binary_cross_entropy_with_logits(input, target,
                                                      self.weight,
                                                      self.size_average,
                                                      reduce=self.reduce)
        else:
            return F.binary_cross_entropy_with_logits(input, target,
                                                      size_average=self.size_average,
                                                      reduce=self.reduce)
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight,
                                                  size_average=self.size_average,
                                                  reduce=self.reduce)


class HingeEmbeddingLoss(_Loss):
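For completeness, a hedged usage sketch (not part of the PR) of the new `pos_weight` argument, following the docstring's 100-positive / 300-negative example where `pos_weight = 300/100 = 3`:

```python
import torch
import torch.nn as nn

# One weight per class; a single class here, so positive examples
# count roughly three times as much in the loss.
pos_weight = torch.tensor([3.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(8, 1, requires_grad=True)   # raw scores, no sigmoid applied
targets = torch.empty(8, 1).random_(2)           # 0/1 labels
loss = criterion(logits, targets)
loss.backward()
```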