10 changes: 10 additions & 0 deletions aten/src/ATen/native/Distance.cpp
@@ -0,0 +1,10 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"


namespace at { namespace native {

Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
return norm(x1 - x2 + eps, p, 1, keepdim);
}
}} // namespace at::native
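For orientation, the new native op reduces the p-norm over the vector dimension (dim 1); eps is added before the norm so the gradient stays finite when x1 == x2. A minimal Python sketch of the same computation (an illustration, not the generated binding):

import torch

def pairwise_distance_sketch(x1, x2, p=2, eps=1e-6, keepdim=False):
    # Mirror of the C++ above: norm(x1 - x2 + eps, p, 1, keepdim)
    return (x1 - x2 + eps).norm(p, dim=1, keepdim=keepdim)

x1, x2 = torch.randn(100, 128), torch.randn(100, 128)
print(pairwise_distance_sketch(x1, x2).shape)                # torch.Size([100])
print(pairwise_distance_sketch(x1, x2, keepdim=True).shape)  # torch.Size([100, 1])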
17 changes: 17 additions & 0 deletions aten/src/ATen/native/Loss.cpp
@@ -43,4 +43,21 @@ Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double mar
return output;
}

Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const Tensor& negative, double margin,
double p, double eps, bool swap, bool size_average, bool reduce) {
auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
if (swap) {
auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
dist_neg = at::min(dist_neg, dist_swap);
}
auto output = at::clamp_min(margin + dist_pos - dist_neg, 0);

if (reduce && size_average) {
return output.sum() / output.numel();
} else if (reduce) {
return output.sum();
}
return output;
}
}} // namespace at::native
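A small worked example of the hinge and the swap branch (hypothetical values; the distances are approximate because pairwise_distance adds eps before taking the norm):

import torch
import torch.nn.functional as F

# One triplet in R^2: d(a, p) ~= 2.0, d(a, n) ~= 2.5, margin = 1.
anchor   = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[2.0, 0.0]])
negative = torch.tensor([[2.5, 0.0]])

# max(1 + 2.0 - 2.5, 0) = 0.5
print(F.triplet_margin_loss(anchor, positive, negative, margin=1.0))

# With swap=True, d(p, n) ~= 0.5 replaces d(a, n) because it is smaller,
# so the loss becomes max(1 + 2.0 - 0.5, 0) = 2.5.
print(F.triplet_margin_loss(anchor, positive, negative, margin=1.0, swap=True))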
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -271,6 +271,9 @@
- func: ones_like(Tensor self, *, Type dtype) -> Tensor
variants: function

- func: pairwise_distance(Tensor x1, Tensor x2, double p=2, double eps=1e-6, bool keepdim=false) -> Tensor
variants: function

- func: permute(Tensor self, IntList dims) -> Tensor
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.

@@ -360,6 +363,9 @@
- func: transpose_(Tensor self, int64_t dim0, int64_t dim1) -> Tensor
variants: method

- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, double margin=1.0, double p=2, double eps=1e-6, bool swap=false, bool size_average=true, bool reduce=true) -> Tensor
variants: function

- func: t_(Tensor self) -> Tensor
variants: method

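These declarations generate both the C++ entry points (at::pairwise_distance, at::triplet_margin_loss) and the Python bindings under torch._C._VariableFunctions that the updated functional wrappers below dispatch to. A quick smoke test of the declared defaults, assuming the bindings have been regenerated:

import torch

x1, x2 = torch.randn(4, 8), torch.randn(4, 8)
# Declared defaults: p=2, eps=1e-6, keepdim=False -> output shape (4,)
print(torch.pairwise_distance(x1, x2).shape)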
17 changes: 17 additions & 0 deletions test/common_nn.py
@@ -423,6 +423,22 @@ def _cos(a, b):
return output


def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
size_average=True, reduce=True):
d_p = torch.pairwise_distance(anchor, positive, p, eps)
d_n = torch.pairwise_distance(anchor, negative, p, eps)
if swap:
d_s = torch.pairwise_distance(positive, negative, p, eps)
d_n = torch.min(d_n, d_s)

output = torch.clamp(margin + d_p - d_n, min=0.0)
if reduce and size_average:
return output.mean()
elif reduce:
return output.sum()
return output


loss_reference_fns = {
'KLDivLoss': kldivloss_reference,
'NLLLoss': nllloss_reference,
@@ -433,6 +449,7 @@ def _cos(a, b):
'SoftMarginLoss': softmarginloss_reference,
'MultiMarginLoss': multimarginloss_reference,
'CosineEmbeddingLoss': cosineembeddingloss_reference,
'TripletMarginLoss': tripletmarginloss_reference,
}


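As a sanity check, the reference should agree with the native path up to the shared eps handling; a quick sketch (the test suite itself wires this through loss_reference_fns instead):

import torch
import torch.nn.functional as F

a, p, n = torch.randn(5, 10), torch.randn(5, 10), torch.randn(5, 10)
ref = tripletmarginloss_reference(a, p, n, swap=True, reduce=False)
out = F.triplet_margin_loss(a, p, n, swap=True, reduce=False)
print((ref - out).abs().max())  # expected ~0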
36 changes: 29 additions & 7 deletions test/test_nn.py
@@ -3803,18 +3803,40 @@ def test_cosine_embedding_loss_margin_no_reduce(self):
loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, margin=0.5))

def test_triplet_margin_loss(self):
input1 = Variable(torch.randn(4, 4), requires_grad=True)
input2 = Variable(torch.randn(4, 4), requires_grad=True)
input3 = Variable(torch.randn(4, 4), requires_grad=True)
input1 = Variable(torch.randn(5, 10), requires_grad=True)
input2 = Variable(torch.randn(5, 10), requires_grad=True)
input3 = Variable(torch.randn(5, 10), requires_grad=True)
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
x1, x2, x3), (input1, input2, input3)))
self.assertEqual(F.triplet_margin_loss(input1, input2, input3),
loss_reference_fns['TripletMarginLoss'](input1, input2, input3))

def test_triplet_margin_swap_loss(self):
input1 = Variable(torch.randn(4, 4), requires_grad=True)
input2 = Variable(torch.randn(4, 4), requires_grad=True)
input3 = Variable(torch.randn(4, 4), requires_grad=True)
def test_triplet_margin_loss_swap(self):
input1 = Variable(torch.randn(5, 10), requires_grad=True)
input2 = Variable(torch.randn(5, 10), requires_grad=True)
input3 = Variable(torch.randn(5, 10), requires_grad=True)
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
x1, x2, x3, swap=True), (input1, input2, input3)))
self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True),
loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True))

def test_triplet_margin_loss_no_reduce(self):
input1 = Variable(torch.randn(5, 10), requires_grad=True)
input2 = Variable(torch.randn(5, 10), requires_grad=True)
input3 = Variable(torch.randn(5, 10), requires_grad=True)
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
x1, x2, x3, reduce=False), (input1, input2, input3)))
self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduce=False),
loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduce=False))

def test_triplet_margin_loss_swap_no_reduce(self):
input1 = Variable(torch.randn(5, 10), requires_grad=True)
input2 = Variable(torch.randn(5, 10), requires_grad=True)
input3 = Variable(torch.randn(5, 10), requires_grad=True)
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
x1, x2, x3, swap=True, reduce=False), (input1, input2, input3)))
self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduce=False),
loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduce=False))

def test_cosine_similarity(self):
input1 = Variable(torch.randn(4, 4), requires_grad=True)
92 changes: 10 additions & 82 deletions torch/nn/functional.py
@@ -1936,35 +1936,11 @@ def pad(input, pad, mode='constant', value=0):

# distance

def pairwise_distance(x1, x2, p=2, eps=1e-6):
def pairwise_distance(x1, x2, p=2, eps=1e-6, keepdim=False):
r"""
Computes the batchwise pairwise distance between vectors v1,v2:

.. math ::
\Vert x \Vert _p := \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}

Args:
x1: first input tensor
x2: second input tensor
p: the norm degree. Default: 2
eps (float, optional): Small value to avoid division by zero. Default: 1e-6

Shape:
- Input: :math:`(N, D)` where `D = vector dimension`
- Output: :math:`(N, 1)`

Example::

>>> input1 = torch.randn(100, 128)
>>> input2 = torch.randn(100, 128)
>>> output = F.pairwise_distance(input1, input2, p=2)
>>> output.backward()
See :class:`torch.nn.PairwiseDistance` for details
"""
assert x1.size() == x2.size(), "Input sizes must be equal."
assert x1.dim() == 2, "Input must be a 2D matrix."
diff = torch.abs(x1 - x2)
out = torch.pow(diff + eps, p).sum(dim=1, keepdim=True)

return torch.pow(out, 1. / p)
return torch._C._VariableFunctions.pairwise_distance(x1, x2, p, eps, keepdim)


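A quick functional-level check that the delegation preserves gradients; note that backward() needs a scalar, so the per-pair distances are summed first (a sketch):

import torch
import torch.nn.functional as F

x1 = torch.randn(100, 128, requires_grad=True)
x2 = torch.randn(100, 128, requires_grad=True)
out = F.pairwise_distance(x1, x2, p=2)  # shape (100,) with the default keepdim=False
out.sum().backward()                    # reduce to a scalar before backward
print(x1.grad.shape)                    # torch.Size([100, 128])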
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
@@ -1997,61 +1973,13 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
return w12 / (w1 * w2).clamp(min=eps)


def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False):
r"""Creates a criterion that measures the triplet loss given an input
tensors x1, x2, x3 and a margin with a value greater than 0.
This is used for measuring a relative similarity between samples. A triplet
is composed by `a`, `p` and `n`: anchor, positive examples and negative
example respectively. The shape of all input variables should be
:math:`(N, D)`.

The distance swap is described in detail in the paper `Learning shallow
convolutional feature descriptors with triplet losses`_ by
V. Balntas, E. Riba et al.

.. math::
L(a, p, n) = \frac{1}{N} \left( \sum_{i=1}^N \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} \right)

where :math:`d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p`.

Args:
anchor: anchor input tensor
positive: positive input tensor
negative: negative input tensor
margin: the margin value. Default: 1
p: the norm degree. Default: 2
eps: small epsilon value to avoid numerical issues. Default: 1e-6
swap: compute distance swap. Default: ``False``

Shape:
- Input: :math:`(N, D)` where `D = vector dimension`
- Output: :math:`(N, 1)`

Example::

>>> input1 = torch.randn(100, 128)
>>> input2 = torch.randn(100, 128)
>>> input3 = torch.randn(100, 128)
>>> output = F.triplet_margin_loss(input1, input2, input3, p=2)
>>> output.backward()

.. _Learning shallow convolutional feature descriptors with triplet losses:
http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf
"""
assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal."
assert anchor.size() == negative.size(), "Input sizes between anchor and negative must be equal."
assert positive.size() == negative.size(), "Input sizes between positive and negative must be equal."
assert anchor.dim() == 2, "Input must be a 2D matrix."
assert margin > 0.0, 'Margin should be positive value.'
d_p = pairwise_distance(anchor, positive, p, eps)
d_n = pairwise_distance(anchor, negative, p, eps)
if swap:
d_s = pairwise_distance(positive, negative, p, eps)
d_n = torch.min(d_n, d_s)

dist_hinge = torch.clamp(margin + d_p - d_n, min=0.0)
loss = torch.mean(dist_hinge)
return loss
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=True,
reduce=True):
r"""
See :class:`torch.nn.TripletMarginLoss` for details
"""
return torch._C._VariableFunctions.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap,
size_average, reduce)


def normalize(input, p=2, dim=1, eps=1e-12):
9 changes: 6 additions & 3 deletions torch/nn/modules/distance.py
@@ -14,11 +14,13 @@ class PairwiseDistance(Module):
p (real): the norm degree. Default: 2
eps (float, optional): Small value to avoid division by zero.
Default: 1e-6
keepdim (bool, optional): Determines whether or not to keep the vector dimension.
Default: False

Shape:
- Input1: :math:`(N, D)` where `D = vector dimension`
- Input2: :math:`(N, D)`, same shape as the Input1
- Output: :math:`(N, 1)`
- Output: :math:`(N)`. If :attr:`keepdim` is ``True``, then :math:`(N, 1)`.

Examples::

@@ -27,13 +29,14 @@ class PairwiseDistance(Module):
>>> input2 = torch.randn(100, 128)
>>> output = pdist(input1, input2)
"""
def __init__(self, p=2, eps=1e-6):
def __init__(self, p=2, eps=1e-6, keepdim=False):
super(PairwiseDistance, self).__init__()
self.norm = p
self.eps = eps
self.keepdim = keepdim

def forward(self, x1, x2):
return F.pairwise_distance(x1, x2, self.norm, self.eps)
return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)


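A short shape-focused sketch of the new keepdim flag:

import torch
import torch.nn as nn

x1, x2 = torch.randn(100, 128), torch.randn(100, 128)
print(nn.PairwiseDistance(p=2)(x1, x2).shape)                # torch.Size([100])
print(nn.PairwiseDistance(p=2, keepdim=True)(x1, x2).shape)  # torch.Size([100, 1])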
class CosineSimilarity(Module):
36 changes: 24 additions & 12 deletions torch/nn/modules/loss.py
@@ -921,7 +921,7 @@ def forward(self, input, target):
self.size_average, self.reduce)


class TripletMarginLoss(Module):
class TripletMarginLoss(_Loss):
r"""Creates a criterion that measures the triplet loss given an input
tensors x1, x2, x3 and a margin with a value greater than 0.
This is used for measuring a relative similarity between samples. A triplet
@@ -933,20 +933,31 @@ class TripletMarginLoss(Module):
convolutional feature descriptors with triplet losses`_ by
V. Balntas, E. Riba et al.

The loss function for each sample in the mini-batch is:

.. math::
L(a, p, n) = \frac{1}{N} \left( \sum_{i=1}^N \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} \right)
L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}

where :math:`d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p`.

Args:
anchor: anchor input tensor
positive: positive input tensor
negative: negative input tensor
p: the norm degree. Default: 2
margin (float, optional): The margin value. Default: `1`.
p (int, optional): The norm degree for pairwise distance. Default: `2`.
swap (bool, optional): The distance swap is described in detail in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. Default: ``False``.
size_average (bool, optional): By default, the losses are averaged over
observations for each minibatch. However, if the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch.
Default: ``True``
reduce (bool, optional): By default, the losses are averaged or summed over
observations for each minibatch depending on :attr:`size_average`. When
:attr:`reduce` is ``False``, returns a loss per batch element instead and
ignores :attr:`size_average`. Default: ``True``

Shape:
- Input: :math:`(N, D)` where `D = vector dimension`
- Output: :math:`(N, 1)`
- Input: :math:`(N, D)` where `D` is the vector dimension.
- Output: scalar. If :attr:`reduce` is ``False``, then :math:`(N)`.

>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> input1 = torch.randn(100, 128, requires_grad=True)
@@ -959,16 +970,17 @@ class TripletMarginLoss(Module):
http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf
"""

def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False):
super(TripletMarginLoss, self).__init__()
def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False, size_average=True, reduce=True):
super(TripletMarginLoss, self).__init__(size_average)
self.margin = margin
self.p = p
self.eps = eps
self.swap = swap
self.reduce = reduce

def forward(self, anchor, positive, negative):
return F.triplet_margin_loss(anchor, positive, negative, self.margin,
self.p, self.eps, self.swap)
return F.triplet_margin_loss(anchor, positive, negative, self.margin, self.p,
self.eps, self.swap, self.size_average, self.reduce)

# TODO: L1HingeEmbeddingCriterion
# TODO: MSECriterion weight
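Finally, a usage sketch of the new reduction flags on the module, mirroring the other loss modules:

import torch
import torch.nn as nn

anchor   = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)

# reduce=False returns one hinge value per triplet, shape (100,)
loss_fn = nn.TripletMarginLoss(margin=1.0, p=2, reduce=False)
per_sample = loss_fn(anchor, positive, negative)
per_sample.sum().backward()  # reduce manually when custom weighting is needed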