
Conversation

@goldsborough
Contributor

@goldsborough goldsborough commented Dec 22, 2017

PyTorch is known to have high start-up times, in large part due to slow tensor initialization/filling when calling normal_. Torch's normal function uses the Box-Muller transform to produce Gaussian-distributed floats from uniform floats. In some benchmarks I found that generating normal samples took around 5 times longer than generating only uniform floats, suggesting that the current normal sampling code is the bottleneck.

This PR addresses this by introducing a vectorized version of the Box-Muller transform that does essentially the same thing, but for 8 values at a time. The vectorized version is called only for floats, only if AVX2 is available, only if there are more than 16 values (due to the implementation), and only if the tensor is contiguous. Still, this should cover 90%-95% of the real-world cases where normal_ is currently slow.
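For reference, the Box-Muller transform maps each pair of uniform samples u1, u2 in (0, 1] to two independent standard-normal samples via radius = sqrt(-2 ln u1) and angle = 2*pi*u2. Below is a minimal scalar sketch in C of what such a fill looks like; it is illustrative only, and the function name and the rand()-based uniform source are placeholders, not the actual TH code.

#include <math.h>
#include <stdlib.h>

/* Illustrative scalar Box-Muller fill (not the actual TH implementation):
 * fills `data` with `n` samples from N(mean, stddev^2), two at a time. */
static void normal_fill_scalar(float *data, long n, float mean, float stddev) {
  const float two_pi = 6.28318530718f;
  for (long i = 0; i < n; i += 2) {
    /* Uniforms in (0, 1]; a real implementation would draw from the
     * generator's own uniform source, not rand(). */
    float u1 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 1.0f);
    float u2 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 1.0f);
    float radius = sqrtf(-2.0f * logf(u1));
    float theta = two_pi * u2;
    data[i] = radius * cosf(theta) * stddev + mean;
    if (i + 1 < n) {
      data[i + 1] = radius * sinf(theta) * stddev + mean;
    }
  }
}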

My initial, small-scale benchmarks show a 5x-6x speed-up:

Before:

In [1]: import torch
In [2]: x = torch.Tensor(10000, 10000)
In [3]: %time _ = x.normal_(0, 1)
CPU times: user 3.45 s, sys: 111 ms, total: 3.57 s
Wall time: 3.57 s

After:

In [1]: import torch
In [2]: x = torch.Tensor(10000, 10000)
In [3]: %time _ = x.normal_(0, 1)
CPU times: user 611 ms, sys: 1.07 ms, total: 612 ms
Wall time: 613 ms

This looks pretty good. I will also see how this affects loading an ImageNet model or similar.

CC @zdevito

Contributor

@apaszke apaszke left a comment

This looks good, I only have two minor comments.

My only concern is that this doesn't just vectorize the function; it also makes it return different results on AVX2 and non-AVX2 platforms for the same random seed. I'm not sure how much we care about cross-platform reproducibility, since that's very hard and constraining, but it's worth noting. normal is quite important because it's used for initializing weights. @soumith thoughts?


@apaszke
Contributor

apaszke commented Dec 22, 2017

Oh, it would also be nice to evaluate some real examples (e.g. the word language model from our repo), just to make sure we won't get a severe downclocking penalty because of AVX2.

@goldsborough
Contributor Author

goldsborough commented Dec 22, 2017

Thanks for the comments. I'm about to board a flight, so I will address them in a bit. Just a note: the non-AVX2 code can be rewritten to lay out the numbers in the same order as the AVX2 version (8 at a time, interleaved). Something like https://gist.github.com/goldsborough/75ee1802110eda71517cc33ea3c59a88. Then the two would produce the same results (and in my benchmarks this "unrolled" version is actually quite a bit faster than a simple loop); see the sketch below.
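For illustration, here is a sketch of what such an interleaved scalar fill could look like, assuming each 8-wide AVX2 iteration writes the 8 cosine halves followed by the 8 sine halves (16 outputs per iteration). The names and the uniform() callback are placeholders, not the gist or the actual TH code.

#include <math.h>

/* Sketch of a scalar fill that mimics an 8-wide layout: for each block of 16
 * outputs, the cosine halves of 8 Box-Muller pairs go to slots 0-7 and the
 * sine halves to slots 8-15. Assumes n is a multiple of 16 and uniform()
 * returns values in (0, 1]. */
static void normal_fill_interleaved(float *data, long n, float (*uniform)(void),
                                    float mean, float stddev) {
  const float two_pi = 6.28318530718f;
  for (long i = 0; i < n; i += 16) {
    float u1[8], u2[8];
    for (int j = 0; j < 8; ++j) {
      u1[j] = uniform();
      u2[j] = uniform();
    }
    for (int j = 0; j < 8; ++j) {
      float radius = sqrtf(-2.0f * logf(u1[j]));
      float theta = two_pi * u2[j];
      data[i + j] = radius * cosf(theta) * stddev + mean;
      data[i + 8 + j] = radius * sinf(theta) * stddev + mean;
    }
  }
}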

Contributor

@zdevito zdevito left a comment

This looks great!

Changing the serial version to mimic the AVX one seems like a good idea, since it improves the performance of the serial version and also makes its output identical to the AVX one.


Contributor

@soumith soumith left a comment

Looks pretty good; needs some runtime dispatch changes (see inline comments).


@goldsborough
Contributor Author

goldsborough commented Jan 1, 2018

Changes

The latest two commits make the following changes:

  • Actually use the mean and standard deviation (the variables were not being used before, so only unit Gaussian samples were generated).
  • Added a scalar normal_fill function that interleaves values just like the vectorized code, so there is no difference in generated samples between AVX and non-AVX platforms. This function is also around 1.5x faster than the old version, so non-AVX platforms get a speedup too (for contiguous tensors with at least 16 values).
  • Implemented all the vector dispatch logic.
  • Using int64_t for the size.
  • Using THAssert instead of assert.
  • The AVX code uses _mm256_loadu_ps instead of _mm256_load_ps, i.e. unaligned loads, so that it also works for misaligned data.
  • Using an explicit FMA _mm256_fmadd_ps instruction instead of multiply + add (clang doesn't do this automatically, GCC does; it's less code anyway, and FMA is available wherever AVX2 is on Intel and AMD chips, so it should be fine). Marginally faster. See the sketch after this list.
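As a rough sketch of the last two points (unaligned load plus a fused multiply-add), assuming 8 standard-normal floats have already been computed into z: this is illustrative only, not the actual AVX2.c code, and needs to be compiled with AVX2/FMA enabled (e.g. -mavx2 -mfma).

#include <immintrin.h>

/* Scale and shift 8 standard-normal floats at once: data[i] = z[i] * stddev + mean.
 * _mm256_loadu_ps/_mm256_storeu_ps tolerate unaligned pointers, and
 * _mm256_fmadd_ps fuses the multiply and add into one instruction. */
static void scale_shift_8(float *data, const float *z, float mean, float stddev) {
  __m256 zv = _mm256_loadu_ps(z);
  __m256 scaled = _mm256_fmadd_ps(zv, _mm256_set1_ps(stddev), _mm256_set1_ps(mean));
  _mm256_storeu_ps(data, scaled);
}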

Benchmarks

This time compiling with GCC 7.

Microbenchmarks

10,000 x 10,000 (float/AVX): 3.3s -> 0.48s (6.875x speedup)
1,000 x 1,000 (float/AVX): 35ms -> 4.9ms (7.1x speedup)
10,000 x 10,000 (double/scalar): 3.2s -> 2.1s (1.5x speedup)

float/AVX here means the vectorized version is used; double/scalar means the new interleaved scalar function is used, since the vectorized version is only called for floats.

Imagenet Startup Times

VGG19: 5.61s -> 1.62s (3.5x speedup)
ResNet101: 2.4s -> 0.65s (3.7x speedup)
ResNet50: 1.21s -> 0.46s (2.6x speedup)

(Please re-review the code @colesbury @zdevito @soumith )

Happy new year 🎆 🎉

@soumith
Contributor

soumith commented Jan 1, 2018

@pytorchbot add to whitelist

Contributor

@apaszke apaszke left a comment

LGTM! I think there's one small bug though


@apaszke
Contributor

apaszke commented Jan 2, 2018

This should be ready to merge, but the builds are failing now (probably because the random seed we used previously has become unlucky after this change).

@goldsborough
Contributor Author

goldsborough commented Jan 2, 2018

Ok, how do we resolve the build failure? It says something about not being able to "get pull request builder trigger".

@apaszke
Contributor

apaszke commented Jan 2, 2018

Oh, this one looks like a CI failure, but the CUDA jobs do manage to build and then fail at test time.

@zdevito
Contributor

zdevito commented Jan 2, 2018

@pytorchbot retest this please

@goldsborough goldsborough force-pushed the master branch 4 times, most recently from 8a719ff to b983dcd on January 3, 2018 08:42
@goldsborough
Contributor Author

No luck with random seeds for the cuDNN builds. I will need a GPU machine to figure out what's wrong locally. Or is there a smarter way of solving these random failures than finding a lucky seed?

@apaszke
Contributor

apaszke commented Jan 3, 2018

Alright, I looked into it and it seems that the test that's failing now is particularly flaky when using half (63 failures / 1000 trials). Reducing the scale of the values used to test it makes the absolute errors smaller, and with that change it succeeded 10,000 times. Here's the patch (just add / 2 in two places):

--- a/test/test_nn.py                                                                    
+++ b/test/test_nn.py                                                                    
@@ -2132,9 +2132,9 @@ class TestNN(NNTestCase):                                          
                 continue                                                                
             for depth_multiplier in [1, 2]:                                             
                 m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).type(tp)
-                i = Variable(torch.randn(2, 2, 6, 6).type(tp), requires_grad=True)      
+                i = Variable(torch.randn(2, 2, 6, 6).type(tp) / 2, requires_grad=True)  
                 output = m(i)                                                           
-                grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4).type(tp)       
+                grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4).type(tp) / 2   
                 output.backward(grad_output)                                            
                                                                                         
                 offset = 1 * depth_multiplier                                                                                      

@goldsborough
Contributor Author

That worked, thanks Adam! Now one of the builds got stuck after a segfault in the dataloader tests. Seems flaky too.

@goldsborough
Contributor Author

😍

@apaszke
Contributor

apaszke commented Jan 3, 2018

There's already another PR open that fixes the data loader issue. The test had a race condition that only started to appear recently.

@apaszke apaszke merged commit 77c792e into pytorch:master Jan 3, 2018
@apaszke
Contributor

apaszke commented Jan 3, 2018

Thanks Peter!

yf225 pushed a commit to yf225/pytorch that referenced this pull request Jan 4, 2018
@yf225 yf225 mentioned this pull request Jan 4, 2018
ezyang pushed a commit that referenced this pull request Jan 4, 2018
@aluo-x

aluo-x commented Jan 25, 2018

Not sure if this warrants a new bug report, but today, while trying to build PyTorch on Windows, I ran into the following error:

"C:\optimae\pytorch\torch\lib\build\ATen\INSTALL.vcxproj" (default target) (1) ->
"C:\optimae\pytorch\torch\lib\build\ATen\ALL_BUILD.vcxproj" (default target) (3) ->
"C:\optimae\pytorch\torch\lib\build\ATen\src\ATen\ATen.vcxproj" (default target) (4) ->
(ClCompile target) ->
  C:\optimae\pytorch\aten\src\TH\vector\AVX2.c(60): error C2440: 'function': cannot convert from
'int' to '__m256' [C:\optimae\pytorch\torch\lib\build\ATen\src\ATen\ATen.vcxproj]

    8925 Warning(s)
    1 Error(s)

@soumith soumith added 0.3.1 and removed 0.3.1 labels Feb 4, 2018