
Conversation

@cpuhrsch
Contributor

@cpuhrsch cpuhrsch commented May 8, 2018

This PR uses Vec256 to vectorize the softmax and logsoftmax layers.

This comes in 4 steps:

  1. log_softmax
  2. softmax
  3. log_softmax_backward
  4. softmax_backward

I will remove the SLEEF commit once the corresponding PR lands.
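
For context, here is a minimal scalar sketch of the numerically stable log_softmax that the kernel computes along a contiguous dimension. The function name and the plain loops are illustrative only; the PR replaces the inner loops of these three passes with Vec256-based vectorized ones.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative scalar reference for one row of length dim_size:
//   out[i] = x[i] - max(x) - log(sum_j exp(x[j] - max(x)))
// The PR vectorizes the three passes (max, sum of exp, write) with Vec256.
void log_softmax_row(const float* input, float* output, int64_t dim_size) {
  // Pass 1: row maximum, for numerical stability.
  float max_input = input[0];
  for (int64_t d = 1; d < dim_size; d++) {
    max_input = std::max(max_input, input[d]);
  }
  // Pass 2: sum of shifted exponentials.
  float tmp_sum = 0.0f;
  for (int64_t d = 0; d < dim_size; d++) {
    tmp_sum += std::exp(input[d] - max_input);
  }
  const float log_sum = max_input + std::log(tmp_sum);
  // Pass 3: write the result.
  for (int64_t d = 0; d < dim_size; d++) {
    output[d] = input[d] - log_sum;
  }
}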

@cpuhrsch
Contributor Author

cpuhrsch commented May 8, 2018

First single-core timings

Master

log_softmax     memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.FloatTensor  dim: 3     elapsed:   5.5386
softmax         memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.FloatTensor  dim: 3     elapsed:   3.8424
log_softmax     memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.DoubleTensor dim: 3     elapsed:   9.2648
softmax         memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.DoubleTensor dim: 3     elapsed:   6.4861


This branch

log_softmax     memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.FloatTensor  dim: 3     elapsed:   0.8699
softmax         memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.FloatTensor  dim: 3     elapsed:   3.8344
log_softmax     memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.DoubleTensor dim: 3     elapsed:   1.9706
softmax         memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.DoubleTensor dim: 3     elapsed:   6.4777

@cpuhrsch cpuhrsch force-pushed the vectorsoftmax branch 2 times, most recently from 4695f10 to e25f291 on May 8, 2018 19:46
@cpuhrsch
Contributor Author

cpuhrsch commented May 8, 2018

Updated branch timings for softmax

log_softmax     memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.FloatTensor  dim: 3     elapsed:   0.8706
softmax         memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.FloatTensor  dim: 3     elapsed:   0.9233
log_softmax     memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.DoubleTensor dim: 3     elapsed:   1.8847
softmax         memory: O(10^4)KB  count: 25     size: [215, 215, 215]      stride: ['  46225', '    215', '      1']                            numel: 9938375   type: torch.DoubleTensor dim: 3     elapsed:   1.9499

@ezyang
Contributor

ezyang commented May 8, 2018

This will probably conflict with #7275, which reorganizes the Unary/Reduce op prologues so that the CUDA-only functions live in the cuda/ subdir.

@ajtulloch
Contributor

What's the plan for enabling vectorization for these functions with NEON (128-bit wide) or AVX-512 (512-bit wide)? Vec256 seems pretty AVX2-specific. Is the idea that you'd implement it N different times, once for each bit width?

@apaszke
Contributor

apaszke commented May 9, 2018

@ajtulloch I think in that case we could just keep two 128-bit wide registers as a single value (similar to how Vec256 is implemented when AVX is disabled). AVX-512 is a completely different thing, but it's unclear whether we want to go down that path, considering all the issues with downclocking.
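
For illustration only, a rough sketch of that idea: emulate a 256-bit-wide value on NEON by pairing two 128-bit registers. All names here are hypothetical and are not the actual Vec256 code.

#include <arm_neon.h>

// Hypothetical sketch: a 256-bit float "vector" emulated as two 128-bit
// NEON registers, analogous to how Vec256 falls back to a plain array
// when AVX is disabled.
struct Vec256Neon {
  float32x4_t lo;
  float32x4_t hi;

  static constexpr int size = 8;

  static Vec256Neon loadu(const float* ptr) {
    return {vld1q_f32(ptr), vld1q_f32(ptr + 4)};
  }
  void store(float* ptr) const {
    vst1q_f32(ptr, lo);
    vst1q_f32(ptr + 4, hi);
  }
  Vec256Neon operator+(const Vec256Neon& other) const {
    return {vaddq_f32(lo, other.lo), vaddq_f32(hi, other.hi)};
  }
};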

@cpuhrsch
Contributor Author

cpuhrsch commented May 9, 2018

@ajtulloch The code can definitely handle different bit widths, since it only uses Vec::size and you can ifdef on CPU_CAPABILITY_, but it's not clear whether that would immediately be a good idea without investigating case by case. As for ARM, the plan for now is to only write kernels specific to CPU.

What do you think would be a good strategy?

@cpuhrsch cpuhrsch force-pushed the vectorsoftmax branch 2 times, most recently from 549cdf4 to 2e39a2f on May 9, 2018 18:33
@cpuhrsch
Contributor Author

cpuhrsch commented May 9, 2018

I have a character RNN locally and it produces NaNs (inf loss) when using this diff, yet all test_nn tests succeed. So let's not merge this even if all tests pass.

@ajtulloch
Contributor

ajtulloch commented May 9, 2018

@apaszke AVX2 also has issues with downclocking; it's not safe to just unconditionally enable it, btw. @cpuhrsch ARM CPUs are CPUs too; I assume you meant you only want to handle Intel CPUs?

It seems like the right abstraction is basically the Eigen Packet abstraction, since that's essentially what Vec256 is a special case of (a packet is basically a C++ wrapper around a generic SIMD type, so Vec256 is essentially a Packet<T, 256>). If this is generalized to more than AVX2, then IMO that makes sense.
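
A rough sketch of the kind of generalization being suggested (the names are hypothetical and this is not Eigen's actual Packet class or the ATen API): parameterize the wrapper on element type and register width, so the existing 256-bit wrapper becomes one specialization among several.

#include <cstddef>

// Hypothetical width-parameterized SIMD wrapper in the spirit of Eigen's
// Packet. A scalar fallback is shown; NEON (128-bit), AVX2 (256-bit) and
// AVX-512 (512-bit) backends would specialize the same interface.
template <typename T, int BITS>
struct Packet {
  static constexpr int width = BITS / (8 * static_cast<int>(sizeof(T)));
  T values[width];

  static Packet loadu(const T* ptr) {
    Packet p;
    for (int i = 0; i < width; i++) {
      p.values[i] = ptr[i];
    }
    return p;
  }
  void store(T* ptr) const {
    for (int i = 0; i < width; i++) {
      ptr[i] = values[i];
    }
  }
};

// Under this scheme, today's Vec256<float> would play the role of Packet<float, 256>.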

@cpuhrsch
Contributor Author

cpuhrsch commented May 9, 2018

@ajtulloch Yes, that's correct: x86 CPUs. I'm associating ARM with mobile, and somehow that turned into "not CPU". There's also something to be said for splitting ARM and x86 into two separate categories (of course they can still share code), because you'd probably want the ATen library to be very small for mobile. We can do something as in this PR, but that can probably also be done at the CMake level.

@cpuhrsch cpuhrsch force-pushed the vectorsoftmax branch 4 times, most recently from c0fad45 to 86cd5d1 on May 10, 2018 05:31

Contributor

@apaszke apaszke left a comment

I think quite a bit of this code could be simplified using a higher-level vec_reduce helper. We actually have something like this in a CUDA implementation, and it generally ends up being much more readable.
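
For illustration, a minimal sketch of what such a helper could look like, written against a hypothetical Vec type that exposes a static size constant plus loadu and store. This is not the actual ATen API, just the shape of the suggestion.

#include <cstdint>

// Hypothetical sketch of a higher-level reduce helper. Vec is assumed to
// expose a static constexpr `size`, `loadu(ptr)` and `store(ptr)`; vec_op
// combines two Vec values, scalar_op combines two scalars and is used for
// the horizontal reduction and the tail.
template <typename scalar_t, typename Vec, typename VecOp, typename ScalarOp>
scalar_t vec_reduce_all(VecOp vec_op, ScalarOp scalar_op,
                        const scalar_t* data, int64_t size) {
  int64_t d = 0;
  scalar_t acc;
  if (size >= Vec::size) {
    // Vectorized accumulation over full Vec-width chunks.
    Vec acc_vec = Vec::loadu(data);
    for (d = Vec::size; d + Vec::size <= size; d += Vec::size) {
      acc_vec = vec_op(acc_vec, Vec::loadu(data + d));
    }
    // Horizontal reduction of the vector accumulator.
    scalar_t buf[Vec::size];
    acc_vec.store(buf);
    acc = buf[0];
    for (int64_t i = 1; i < Vec::size; i++) {
      acc = scalar_op(acc, buf[i]);
    }
  } else {
    acc = data[0];
    d = 1;
  }
  // Scalar tail for the leftover elements.
  for (; d < size; d++) {
    acc = scalar_op(acc, data[d]);
  }
  return acc;
}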


@cpuhrsch cpuhrsch force-pushed the vectorsoftmax branch 5 times, most recently from 19d3d9c to 4154a9b on May 11, 2018 21:41
@cpuhrsch
Contributor Author

One key remaining concern is whether or not it's important to accumulate sums into doubles. On my local example (character RNN) this appears to make no difference. If someone needs that extra precision, they could also upcast everything to doubles.
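
For context, a self-contained toy example (not code from this PR) of what accumulating into doubles buys: a float running sum stops growing once it dwarfs the individual terms, while a double accumulator does not.

#include <cstdio>

// Toy illustration of float vs. double accumulation: summing 1.0f one
// hundred million times. The float accumulator stalls at 2^24 = 16777216,
// because adding 1.0f no longer changes the value; the double accumulator
// reaches the exact result.
int main() {
  const int n = 100000000;
  float sum_f = 0.0f;
  double sum_d = 0.0;
  for (int i = 0; i < n; i++) {
    sum_f += 1.0f;
    sum_d += 1.0;
  }
  std::printf("float accumulator:  %.1f\n", sum_f);   // 16777216.0
  std::printf("double accumulator: %.1f\n", sum_d);   // 100000000.0
  return 0;
}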

@cpuhrsch
Contributor Author

cpuhrsch commented May 11, 2018

I'm seeing a 4.5x speedup on a single core and a 1.8x speedup on 10 threads using this code

import torch

if __name__ == "__main__":
    softmax = torch.nn.LogSoftmax(dim=1)
    output = torch.randn(1000, 100).type('torch.FloatTensor')
    for _ in range(10000):
        output = softmax(output)

and these commands

taskset -c 0 perf stat python softmax_comp.py
taskset -c 0-9 perf stat python softmax_comp.py

@cpuhrsch cpuhrsch force-pushed the vectorsoftmax branch 2 times, most recently from 4a87cf0 to 6add1a7 on May 11, 2018 23:44
@cpuhrsch
Contributor Author

On deciding whether to accumulate into doubles or floats from a perf perspective:

I'm using the following analysis to investigate whether this algorithm is compute-bound. If it is, we should accumulate into floats by default and advise users with numerical stability issues to upcast to doubles, as that case will be rare.

My development machine has two NUMA nodes with the following layout

$ numactl --hardware
available: 2 nodes (0-1)
node 0 cpus: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
node 0 size: 257863 MB
node 0 free: 222621 MB
node 1 cpus: 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
node 1 size: 258041 MB
node 1 free: 248925 MB
node distances:
node   0   1
  0:  10  21
  1:  21  10

Any kernel that is to some extent memory bound will see heavy penalties when scheduled on node 0 CPUs while its memory is bound to node 1, since going over QPI is much slower than reading from local memory.

The measurements below trigger the vectorized branches. The size of the input tensor is chosen to be roughly representative of a typical (larger) workload: a batch size of 256 and 20k classes. There are two investigations: 1) set the CPU frequency to different values, and 2) bind memory to a different NUMA node.

I'm using the following code

import torch
import time
import gc

if __name__ == "__main__":
    softmax = torch.nn.LogSoftmax(dim=1)
    output = torch.randn(256, 20000).type('torch.FloatTensor')
    gc.collect()
    tstart = time.time()
    for _ in range(250):
        numbers = softmax(output)
    print("elapsed: " + str(time.time() - tstart))

For 1) we can observe

$ perf stat numactl --membind=0 taskset -c 0,40 python softmax_comp.py
elapsed: 3.24645280838

 Performance counter stats for 'numactl --membind=0 taskset -c 0,40 python softmax_comp.py':

       4034.933398      task-clock (msec)         #    0.940 CPUs utilized
             1,637      context-switches          #    0.406 K/sec
                20      cpu-migrations            #    0.005 K/sec
            61,345      page-faults               #    0.015 M/sec
     8,871,752,268      cycles                    #    2.199 GHz
   <not supported>      stalled-cycles-frontend
   <not supported>      stalled-cycles-backend
    10,145,690,978      instructions              #    1.14  insns per cycle
     1,373,296,620      branches                  #  340.352 M/sec
         4,634,132      branch-misses             #    0.34% of all branches

       4.294467340 seconds time elapsed
$ perf stat numactl --membind=0 taskset -c 0,40 python softmax_comp.py
elapsed: 6.08645606041

 Performance counter stats for 'numactl --membind=0 taskset -c 0,40 python softmax_comp.py':

       7547.804302      task-clock (msec)         #    0.964 CPUs utilized
             1,652      context-switches          #    0.219 K/sec
                20      cpu-migrations            #    0.003 K/sec
            92,522      page-faults               #    0.012 M/sec
     9,048,895,951      cycles                    #    1.199 GHz
   <not supported>      stalled-cycles-frontend
   <not supported>      stalled-cycles-backend
    10,227,923,023      instructions              #    1.13  insns per cycle
     1,387,674,456      branches                  #  183.851 M/sec
         4,650,287      branch-misses             #    0.34% of all branches

       7.829675110 seconds time elapsed

For 2) we can observe

$ perf stat numactl --membind=1 taskset -c 0,40 python softmax_comp.py
elapsed: 3.40651988983

 Performance counter stats for 'numactl --membind=1 taskset -c 0,40 python softmax_comp.py':

       4187.342577      task-clock (msec)         #    0.971 CPUs utilized
               786      context-switches          #    0.188 K/sec
                12      cpu-migrations            #    0.003 K/sec
            93,049      page-faults               #    0.022 M/sec
     9,208,809,385      cycles                    #    2.199 GHz
   <not supported>      stalled-cycles-frontend
   <not supported>      stalled-cycles-backend
    10,216,458,069      instructions              #    1.11  insns per cycle
     1,385,948,756      branches                  #  330.985 M/sec
         4,573,052      branch-misses             #    0.33% of all branches

       4.314352108 seconds time elapsed
$ perf stat numactl --membind=1 taskset -c 0,40 python softmax_comp.py
elapsed: 6.14880990982

 Performance counter stats for 'numactl --membind=1 taskset -c 0,40 python softmax_comp.py':

       7595.598948      task-clock (msec)         #    0.982 CPUs utilized
               798      context-switches          #    0.105 K/sec
                14      cpu-migrations            #    0.002 K/sec
            61,877      page-faults               #    0.008 M/sec
     9,108,585,035      cycles                    #    1.199 GHz
   <not supported>      stalled-cycles-frontend
   <not supported>      stalled-cycles-backend
    10,139,876,776      instructions              #    1.11  insns per cycle
     1,372,300,483      branches                  #  180.670 M/sec
         4,619,897      branch-misses             #    0.34% of all branches

       7.732898021 seconds time elapsed

In conclusion, we see an increase in wall-time (3.2s to 6.0s and 3.4s to 6.1s on the two memory nodes) roughly proportional to the decrease in CPU frequency (2.2GHz to 1.2GHz), but only a very small increase in wall-time when memory is allocated on the NUMA node away from the CPUs in use.

This leads me to believe that the kernel is compute bound on a single core, at least on my particular machine.

I've also compared the wall-time of this kernel with master. Even when using doubles, we still get a 2x improvement in runtime (on a single core) compared to the current implementation. With floats we get around a ~5.5x improvement.

Contributor

@apaszke apaszke left a comment

Looks nice. I have some final minor comments and then it should be good to go. I'm not sure where the grad formula for softmax is coming from, since it doesn't seem to match what we do e.g. for CUDA.
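
For reference, these are the standard backward formulas (not taken from this PR's code), against which any discrepancy with the CUDA path could be checked. With upstream gradient \partial L / \partial y:

For y = \mathrm{softmax}(x):
  \frac{\partial L}{\partial x_i} = y_i \left( \frac{\partial L}{\partial y_i} - \sum_j \frac{\partial L}{\partial y_j}\, y_j \right)

For y = \mathrm{log\_softmax}(x):
  \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i} - e^{y_i} \sum_j \frac{\partial L}{\partial y_j}

An elementwise softmax backward of the form (grad - sum(grad * output)) * output is one equivalent way of writing the first formula.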

@ezyang ezyang merged commit 0585394 into pytorch:master May 15, 2018
cpuhrsch added a commit to cpuhrsch/pytorch that referenced this pull request May 16, 2018
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018
This PR uses Vec256 to vectorize the softmax and logsoftmax Layers.

This comes in 4 steps:

log_softmax
softmax
log_softmax_backward
softmax_backward

* Vectorized Softmax and LogSoftmax

* Abstractions

* Style

* Remove <limits> for Kernel

* Perf investigations

* Last cleanups