Skip to content

Commit 9540f6c

Browse files
CarMiranda authored and facebook-github-bot committed
Soft Margin loss (#27660)
Summary: In accordance with #25883, I added the `SoftMarginLoss` module and `soft_margin_loss` functional. Pull Request resolved: #27660 Differential Revision: D17958325 Pulled By: yf225 fbshipit-source-id: c14422765e6e1fdabf6c9687080e6d5ff490d300
1 parent c67d353 commit 9540f6c

File tree

6 files changed

+118
-1
lines changed

6 files changed

+118
-1
lines changed

test/cpp/api/functional.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ TEST_F(FunctionalTest, CosineSimilarity) {
7373
ASSERT_TRUE(output.allclose(expected, 1e-04));
7474
}
7575

76+
TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) {
77+
auto input = torch::tensor({2., 4., 1., 3.}, torch::requires_grad());
78+
auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
79+
auto output =
80+
F::soft_margin_loss(input, target);
81+
auto expected = torch::tensor({1.3767317}, torch::kFloat);
82+
auto s = output.sum();
83+
s.backward();
84+
85+
ASSERT_TRUE(output.allclose(expected));
86+
ASSERT_EQ(input.sizes(), input.grad().sizes());
87+
}
88+
7689
TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) {
7790
auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::requires_grad());
7891
auto target = torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
@@ -86,6 +99,19 @@ TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) {
8699
ASSERT_EQ(input.sizes(), input.grad().sizes());
87100
}
88101

102+
TEST_F(FunctionalTest, SoftMarginLossNoReduction) {
103+
auto input = torch::tensor({2., 4., 1., 3.}, torch::requires_grad());
104+
auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
105+
auto output =
106+
F::soft_margin_loss(input, target, torch::Reduction::None);
107+
auto expected = torch::tensor({2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
108+
auto s = output.sum();
109+
s.backward();
110+
111+
ASSERT_TRUE(output.allclose(expected));
112+
ASSERT_EQ(input.sizes(), input.grad().sizes());
113+
}
114+
89115
TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) {
90116
auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::requires_grad());
91117
auto target = torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);

test/cpp/api/modules.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,19 @@ TEST_F(ModulesTest, CosineSimilarity) {
10961096
ASSERT_EQ(input1.sizes(), input1.grad().sizes());
10971097
}
10981098

1099+
TEST_F(ModulesTest, SoftMarginLossDefaultOptions) {
1100+
SoftMarginLoss loss;
1101+
auto input = torch::tensor({2., 4., 1., 3.}, torch::requires_grad());
1102+
auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
1103+
auto output = loss->forward(input, target);
1104+
auto expected = torch::tensor({1.3767317}, torch::kFloat);
1105+
auto s = output.sum();
1106+
s.backward();
1107+
1108+
ASSERT_TRUE(output.allclose(expected));
1109+
ASSERT_EQ(input.sizes(), input.grad().sizes());
1110+
}
1111+
10991112
TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
11001113
MultiLabelSoftMarginLoss loss;
11011114
auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::requires_grad());
@@ -1109,6 +1122,19 @@ TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
11091122
ASSERT_EQ(input.sizes(), input.grad().sizes());
11101123
}
11111124

1125+
TEST_F(ModulesTest, SoftMarginLossNoReduction) {
1126+
SoftMarginLoss loss(torch::Reduction::None);
1127+
auto input = torch::tensor({2., 4., 1., 3.}, torch::requires_grad());
1128+
auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
1129+
auto output = loss->forward(input, target);
1130+
auto expected = torch::tensor({2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
1131+
auto s = output.sum();
1132+
s.backward();
1133+
1134+
ASSERT_TRUE(output.allclose(expected));
1135+
ASSERT_EQ(input.sizes(), input.grad().sizes());
1136+
}
1137+
11121138
TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) {
11131139
auto input = torch::tensor({{0., 2., 2., 0.}, {2., 1., 0., 1.}}, torch::requires_grad());
11141140
auto target = torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
@@ -1733,6 +1759,10 @@ TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) {
17331759
ASSERT_EQ(c10::str(MultiLabelSoftMarginLoss()), "torch::nn::MultiLabelSoftMarginLoss()");
17341760
}
17351761

1762+
TEST_F(ModulesTest, PrettyPrintSoftMarginLoss) {
1763+
ASSERT_EQ(c10::str(SoftMarginLoss()), "torch::nn::SoftMarginLoss()");
1764+
}
1765+
17361766
TEST_F(ModulesTest, PrettyPrintCosineSimilarity) {
17371767
ASSERT_EQ(
17381768
c10::str(CosineSimilarity()),

torch/csrc/api/include/torch/nn/functional/loss.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ inline Tensor cosine_embedding_loss(
4545
input1, input2, target, options.margin(), options.reduction());
4646
}
4747

48+
inline Tensor soft_margin_loss(
49+
const Tensor& input,
50+
const Tensor& target,
51+
const SoftMarginLossOptions& options = {}) {
52+
return torch::soft_margin_loss(input, target, options.reduction());
53+
}
54+
4855
inline Tensor multilabel_soft_margin_loss(
4956
const Tensor& input,
5057
const Tensor& target,

torch/csrc/api/include/torch/nn/modules/loss.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,39 @@ TORCH_MODULE(CosineEmbeddingLoss);
125125

126126
// ============================================================================
127127

128+
/// Creates a criterion that optimizes a two-class classification
129+
/// logistic loss between input tensor :math:`x` and target tensor :math:`y`
130+
/// (containing 1 or -1).
131+
struct TORCH_API SoftMarginLossImpl : public Cloneable<SoftMarginLossImpl> {
132+
explicit SoftMarginLossImpl(const SoftMarginLossOptions& options_ = {});
133+
134+
/// Pretty prints the `SoftMarginLoss` module into the given `stream`.
135+
void pretty_print(std::ostream& stream) const override;
136+
137+
void reset() override;
138+
139+
Tensor forward(const Tensor& input, const Tensor& target);
140+
141+
/// The options with which this `Module` was constructed.
142+
SoftMarginLossOptions options;
143+
};
144+
145+
/// A `ModuleHolder` subclass for `SoftMarginLossImpl`.
146+
/// See the documentation for `SoftMarginLossImpl` class to learn what methods it
147+
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
148+
/// module storage semantics.
149+
TORCH_MODULE(SoftMarginLoss);
150+
151+
// ============================================================================
152+
128153
/// Creates a criterion that optimizes a multi-label one-versus-all
129154
/// loss based on max-entropy, between input :math:`x` and target :math:`y` of size
130155
/// :math:`(N, C)`.
131156
struct TORCH_API MultiLabelSoftMarginLossImpl : public Cloneable<MultiLabelSoftMarginLossImpl> {
132157
explicit MultiLabelSoftMarginLossImpl(
133158
const MultiLabelSoftMarginLossOptions& options_ = {});
134159

135-
/// Pretty prints the `L1Loss` module into the given `stream`.
160+
/// Pretty prints the `MultiLabelSoftMarginLoss` module into the given `stream`.
136161
void pretty_print(std::ostream& stream) const override;
137162

138163
void reset() override;

torch/csrc/api/include/torch/nn/options/loss.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,20 @@ struct TORCH_API CosineEmbeddingLossOptions {
6161

6262
// ============================================================================
6363

64+
/// Options for a soft margin loss functional and module.
65+
struct TORCH_API SoftMarginLossOptions {
66+
SoftMarginLossOptions(torch::Reduction::Reduction reduction = torch::Reduction::Mean)
67+
: reduction_(reduction) {}
68+
69+
/// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
70+
/// 'none': no reduction will be applied, 'mean': the sum of the output will
71+
/// be divided by the number of elements in the output, 'sum': the output will
72+
/// be summed. Default: 'mean'
73+
TORCH_ARG(torch::Reduction::Reduction, reduction);
74+
};
75+
76+
// ============================================================================
77+
6478
/// Options for a multi-label soft margin loss functional and module.
6579
struct TORCH_API MultiLabelSoftMarginLossOptions {
6680
/// A manual rescaling weight given to each

torch/csrc/api/src/nn/modules/loss.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,20 @@ Tensor TripletMarginLossImpl::forward(
126126
return F::triplet_margin_loss(anchor, positive, negative, options);
127127
}
128128

129+
// ============================================================================
130+
131+
SoftMarginLossImpl::SoftMarginLossImpl(
132+
const torch::nn::SoftMarginLossOptions& options_) : options(options_) {}
133+
134+
void SoftMarginLossImpl::reset() {}
135+
136+
void SoftMarginLossImpl::pretty_print(std::ostream& stream) const {
137+
stream << "torch::nn::SoftMarginLoss()";
138+
}
139+
140+
Tensor SoftMarginLossImpl::forward(const Tensor& input, const Tensor& target) {
141+
return F::soft_margin_loss(input, target, options);
142+
}
143+
129144
} // namespace nn
130145
} // namespace torch

0 commit comments

Comments
 (0)