-
Notifications
You must be signed in to change notification settings - Fork 26.3k
implement TripletMarginLoss as a native function #5680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
|
||
| def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False): | ||
| super(TripletMarginLoss, self).__init__() | ||
| def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False, size_average=True, reduce=True): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/test_nn.py
Outdated
| constructor_args=(torch.rand(10),), | ||
| input_fn=lambda: torch.randn(5, 10), | ||
| target_fn=lambda: torch.rand(5, 10).mul(2).floor(), | ||
| reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() / |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Distance.cpp
Outdated
| Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { | ||
| auto diff = abs(x1 - x2); | ||
| auto out = pow(diff + eps, p).sum(1); | ||
| return pow(out, 1 / p); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@pytorchbot retest this please |
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code looks good to me. I had a few minor comments and a concern that we're subtly changing the behavior of pairwise_distance
| margin (float, optional): Default: `1`. | ||
| p (int, optional): The norm degree for pairwise distance. Default: `2`. | ||
| swap (float, optional): The distance swap is described in detail in the paper | ||
| `Learning shallow convolutional feature descriptors with triplet losses` by |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False): | ||
| super(TripletMarginLoss, self).__init__() | ||
| def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False, size_average=True, reduce=True): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| - func: ones_like(Tensor self, *, Type dtype) -> Tensor | ||
| variants: function | ||
|
|
||
| - func: pairwise_distance(Tensor x1, Tensor x2, double p, double eps) -> Tensor |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| - func: transpose_(Tensor self, int64_t dim0, int64_t dim1) -> Tensor | ||
| variants: method | ||
|
|
||
| - func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, double margin, double p, double eps, bool swap, bool size_average, bool reduce) -> Tensor |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Distance.cpp
Outdated
|
|
||
| Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { | ||
| auto diff = abs(x1 - x2 + eps); | ||
| return norm(diff, p, 1); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
aten/src/ATen/native/Distance.cpp
Outdated
| namespace at { namespace native { | ||
|
|
||
| Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { | ||
| auto diff = abs(x1 - x2 + eps); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| assert x1.size() == x2.size(), "Input sizes must be equal." | ||
| assert x1.dim() == 2, "Input must be a 2D matrix." | ||
| diff = torch.abs(x1 - x2) | ||
| out = torch.pow(diff + eps, p).sum(dim=1, keepdim=True) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/functional.py
Outdated
| >>> input1 = torch.randn(100, 128) | ||
| >>> input2 = torch.randn(100, 128) | ||
| >>> output = F.pairwise_distance(input1, input2, p=2) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Minor nit below
aten/src/ATen/native/Distance.cpp
Outdated
|
|
||
| namespace at { namespace native { | ||
|
|
||
| Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yay! Approve approve approve
Benchmarks run on inputs of size (5,500) x 5000 iterations.
old:
forward [1.0334448860958219, 1.0698160571046174, 1.054812144022435]
backward [1.6976923400070518, 1.6186234278138727, 1.617640192154795]
double backward [3.2021269120741636, 3.165953448973596, 3.24088541790843]
new:
forward [0.8313469190616161, 0.8255189431365579, 0.8157028129789978]
backward [1.5887232730165124, 1.6056769208516926, 1.5178304109722376]
double backward [3.154403690015897, 3.1905114939436316, 3.2450901730917394]