
Commit b403b10

wolegechu authored and facebook-github-bot committed
Fix #11752: fix numerical issue in log_softmax (#21672)
Summary: #11866 corrected this issue in `host_softmax` (aten/src/ATen/native/SoftMax.cpp), but when I tried the example proposed in #11752, `log_softmax` still failed for big logits. Looking into the source code, I found that the example calls `vec_host_softmax_lastdim`, not `host_softmax`. This change fixes the issue in `_vec_log_softmax_lastdim` and adds a test for `log_softmax`.

Pull Request resolved: #21672
Differential Revision: D15856327
Pulled By: VitalyFedyunin
fbshipit-source-id: 7a1fd3c0a03d366c99eb873e235361e4fcfa7567
1 parent 0f675f9 commit b403b10

File tree

2 files changed: +15 −4 lines changed

aten/src/ATen/native/cpu/SoftMaxKernel.cpp

Lines changed: 10 additions & 4 deletions

@@ -65,19 +65,25 @@ inline void _vec_log_softmax_lastdim(
   }
   // See [Note AVX-SSE transitions] for why this should call the
   // vectorized version (aside from perf improvements).
-  vec256::map2(
-      [](Vec x, Vec y) { return x.log() + y; },
+  vec256::map(
+      [](Vec x) { return x.log(); },
       tmp_sum_scalar,
       tmp_sum_scalar,
-      max_input_arr,
       loop_end);
   for (int64_t j = 0; j < loop_end; j++) {
     int64_t i = ii + j;
     scalar_t* input_data = input_data_base + i * dim_size;
     scalar_t* output_data = output_data_base + i * dim_size;
     scalar_t tmp_sum = tmp_sum_scalar[j];
+    scalar_t max_input = max_input_arr[j];
+
+    // It's necessary to keep the order of the operations below.
+    // In some cases that input is large digits and the difference
+    // is small, if we compute `max_input` plus `tmp_sum` before,
+    // there would be a numerical problem. See an example in
+    // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
     vec256::map(
-        [tmp_sum](Vec x) { return x - Vec(tmp_sum); },
+        [tmp_sum, max_input](Vec x) { return x - Vec(max_input) - Vec(tmp_sum); },
         output_data,
         input_data,
         dim_size);
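The ordering constraint the diff enforces can be demonstrated outside ATen. Below is a minimal pure-Python sketch (hypothetical helper names, not PyTorch API) contrasting the two subtraction orders in float64: adding the large `max` to the small log-sum first rounds the log-sum term away entirely, which is exactly the failure mode from #11752.

```python
import math

def log_softmax_unstable(xs):
    # Pre-commit order: fold max and log-sum together before subtracting.
    # For huge logits, m + s rounds back to m and the log-sum term is lost.
    m = max(xs)
    s = math.log(sum(math.exp(x - m) for x in xs))
    return [x - (m + s) for x in xs]

def log_softmax_stable(xs):
    # Post-commit order: subtract the large max first, then the small
    # log-sum, mirroring `x - Vec(max_input) - Vec(tmp_sum)` in the diff.
    m = max(xs)
    s = math.log(sum(math.exp(x - m) for x in xs))
    return [(x - m) - s for x in xs]

big = [1.0 + 1e16, 1.0 + 1e16]        # large logits, as in the new test
print(log_softmax_unstable(big))      # [0.0, 0.0]: the log(2) term vanished
print(log_softmax_stable(big))        # [-0.693..., -0.693...]: correct
```

Note that at magnitude 1e16 the float64 spacing is 2.0, so `1e16 + log(2)` rounds straight back to `1e16`; only the stable order preserves the `-log(2)` each entry should carry.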

test/test_nn.py

Lines changed: 5 additions & 0 deletions
@@ -8153,6 +8153,11 @@ def test_softmin(self):
         self.assertEqual(F.softmin(x, 1), F.softmax(-x, 1))
         self.assertEqual(F.softmin(x, 0), F.softmax(-x, 0))
 
+    def test_log_softmax(self):
+        x_small = torch.ones(1, 2, dtype=torch.float32)
+        x_big = x_small + 1e16
+        self.assertEqual(F.log_softmax(x_small, -1), F.log_softmax(x_big, -1))
+
     def test_adaptive_log_softmax(self):
         # args validation
         with self.assertRaises(ValueError):
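The new test leans on log-softmax being shift-invariant: adding any constant to every logit must leave the result unchanged. A small pure-Python sketch of that property (hypothetical helper, not part of the test suite), using the same `+ 1e16` shift as `x_big`:

```python
import math

def log_softmax(xs):
    # Stable log-softmax: subtract the max first, then the small
    # log-sum, in the order the fixed kernel uses.
    m = max(xs)
    s = math.log(sum(math.exp(x - m) for x in xs))
    return [(x - m) - s for x in xs]

x_small = [1.0, 1.0]                  # mirrors torch.ones(1, 2)
x_big = [x + 1e16 for x in x_small]   # same shift as the new test
print(log_softmax(x_small))           # [-0.693..., -0.693...]
print(log_softmax(x_big))             # identical: shift invariance holds
```

Because each input is reduced by its own row max before exponentiation, both calls see the same shifted values, which is why the test can demand exact agreement between the two results.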
