@@ -1096,6 +1096,19 @@ TEST_F(ModulesTest, CosineSimilarity) {
10961096 ASSERT_EQ (input1.sizes (), input1.grad ().sizes ());
10971097}
10981098
1099+ TEST_F (ModulesTest, SoftMarginLossDefaultOptions) {
1100+ SoftMarginLoss loss;
1101+ auto input = torch::tensor ({2 ., 4 ., 1 ., 3 .}, torch::requires_grad ());
1102+ auto target = torch::tensor ({-1 ., 1 ., 1 ., -1 .}, torch::kFloat );
1103+ auto output = loss->forward (input, target);
1104+ auto expected = torch::tensor ({1.3767317 }, torch::kFloat );
1105+ auto s = output.sum ();
1106+ s.backward ();
1107+
1108+ ASSERT_TRUE (output.allclose (expected));
1109+ ASSERT_EQ (input.sizes (), input.grad ().sizes ());
1110+ }
1111+
10991112TEST_F (ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
11001113 MultiLabelSoftMarginLoss loss;
11011114 auto input = torch::tensor ({{0 ., 2 ., 2 ., 0 .}, {2 ., 1 ., 0 ., 1 .}}, torch::requires_grad ());
@@ -1109,6 +1122,19 @@ TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
11091122 ASSERT_EQ (input.sizes (), input.grad ().sizes ());
11101123}
11111124
1125+ TEST_F (ModulesTest, SoftMarginLossNoReduction) {
1126+ SoftMarginLoss loss (torch::Reduction::None);
1127+ auto input = torch::tensor ({2 ., 4 ., 1 ., 3 .}, torch::requires_grad ());
1128+ auto target = torch::tensor ({-1 ., 1 ., 1 ., -1 .}, torch::kFloat );
1129+ auto output = loss->forward (input, target);
1130+ auto expected = torch::tensor ({2.1269281 , 0.01814993 , 0.3132617 , 3.0485873 }, torch::kFloat );
1131+ auto s = output.sum ();
1132+ s.backward ();
1133+
1134+ ASSERT_TRUE (output.allclose (expected));
1135+ ASSERT_EQ (input.sizes (), input.grad ().sizes ());
1136+ }
1137+
11121138TEST_F (ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) {
11131139 auto input = torch::tensor ({{0 ., 2 ., 2 ., 0 .}, {2 ., 1 ., 0 ., 1 .}}, torch::requires_grad ());
11141140 auto target = torch::tensor ({{0 ., 0 ., 1 ., 0 .}, {1 ., 0 ., 1 ., 1 .}}, torch::kFloat );
@@ -1733,6 +1759,10 @@ TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) {
17331759 ASSERT_EQ (c10::str (MultiLabelSoftMarginLoss ()), " torch::nn::MultiLabelSoftMarginLoss()" );
17341760}
17351761
1762+ TEST_F (ModulesTest, PrettyPrintSoftMarginLoss) {
1763+ ASSERT_EQ (c10::str (SoftMarginLoss ()), " torch::nn::SoftMarginLoss()" );
1764+ }
1765+
17361766TEST_F (ModulesTest, PrettyPrintCosineSimilarity) {
17371767 ASSERT_EQ (
17381768 c10::str (CosineSimilarity ()),
0 commit comments