14 changes: 10 additions & 4 deletions aten/src/ATen/native/cpu/SoftMaxKernel.cpp
@@ -65,19 +65,25 @@ inline void _vec_log_softmax_lastdim(
   }
   // See [Note AVX-SSE transitions] for why this should call the
   // vectorized version (aside from perf improvements).
-  vec256::map2(
-      [](Vec x, Vec y) { return x.log() + y; },
+  vec256::map(
+      [](Vec x) { return x.log(); },
       tmp_sum_scalar,
       tmp_sum_scalar,
-      max_input_arr,
       loop_end);
   for (int64_t j = 0; j < loop_end; j++) {
     int64_t i = ii + j;
     scalar_t* input_data = input_data_base + i * dim_size;
     scalar_t* output_data = output_data_base + i * dim_size;
     scalar_t tmp_sum = tmp_sum_scalar[j];
+    scalar_t max_input = max_input_arr[j];
+
+    // It's necessary to keep the order of the operations below.
+    // When the inputs are large and their differences are small,
+    // adding `max_input` to `tmp_sum` first would lose the smaller
+    // term to rounding. See the example in
+    // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
     vec256::map(
-        [tmp_sum](Vec x) { return x - Vec(tmp_sum); },
+        [tmp_sum, max_input](Vec x) { return x - Vec(max_input) - Vec(tmp_sum); },
Collaborator:
Does the compiler optimize this into computing max_input + tmp_sum before the map?

Contributor Author (@wolegechu, Jun 12, 2019):
It shouldn't be optimized to compute that first.

In some cases the inputs are large and their differences are small, so the small term would be lost to rounding if max_input and tmp_sum were added first (the old way). You can see the example here: #11752 (comment)

I have also looked at the log_softmax implementations in MXNet and TensorFlow; they likewise write x - Vec(max_input) - Vec(tmp_sum) together to ensure the order of evaluation.

Collaborator:
Could you add a comment here so people won't "optimize" this away in the future? Thanks!

Contributor Author:
It's done.

         output_data,
         input_data,
         dim_size);
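For concreteness, here is a minimal float32 sketch (an editorial illustration, not part of the PR) of the cancellation the new comment guards against. With large, nearly equal inputs, max_x + log_sum rounds back to max_x, so subtracting the combined term yields 0.0 instead of the correct -log(2):

    import torch

    x = torch.full((2,), 1e16, dtype=torch.float32)  # large, equal entries
    max_x = x.max()
    log_sum = (x - max_x).exp().sum().log()          # log(2) ~= 0.6931

    bad = x - (max_x + log_sum)    # old order: max_x + log_sum rounds back to 1e16, so bad == 0.0
    good = (x - max_x) - log_sum   # fixed order: subtract the large term first, good == -0.6931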
5 changes: 5 additions & 0 deletions test/test_nn.py
@@ -7847,6 +7847,11 @@ def test_softmin(self):
         self.assertEqual(F.softmin(x, 1), F.softmax(-x, 1))
         self.assertEqual(F.softmin(x, 0), F.softmax(-x, 0))
 
+    def test_log_softmax(self):
+        x_small = torch.ones(1, 2, dtype=torch.float32)
+        x_big = x_small + 1e16
+        self.assertEqual(F.log_softmax(x_small, -1), F.log_softmax(x_big, -1))
+
     def test_adaptive_log_softmax(self):
         # args validation
         with self.assertRaises(ValueError):
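A brief editorial note on why this test exercises the fix: in float32, 1.0 + 1e16 rounds to 1e16, so both entries of x_big are equal and log_softmax of either tensor should be -log(2) per entry. Under the old evaluation order the x_big case would return 0.0 instead:

    import torch
    import torch.nn.functional as F

    x_big = torch.ones(1, 2, dtype=torch.float32) + 1e16
    print(x_big)                     # tensor([[1.0000e+16, 1.0000e+16]]) -- the +1 is absorbed
    print(F.log_softmax(x_big, -1))  # tensor([[-0.6931, -0.6931]]) with the fixed order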