
Slowdown in distributions log_prob methods #12190

Description

@neerajprad

🐛 Bug

One of Pyro's branches (pyro-ppl/pyro#1431) runs against PyTorch's nightly build, and we noticed that the unit test stage in CI is almost twice as slow as it was with the 0.4.0 release. Many of the slow tests turn out to be HMC tests (pyro-ppl/pyro#1421), and the slowdown seems to come mostly from the distributions' log_prob methods. I am pasting the results below for the normal distribution, but I am seeing the same behavior for other distributions too.

To Reproduce

version: 0.4.0

>>> import torch
>>> import torch.distributions as dist
>>> torch.__version__
 '0.4.0'
>>> d = dist.Normal(torch.zeros(1000, 2), torch.ones(1000, 2))

>>> %timeit torch.randn(1000, 2)
16.2 µs ± 342 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit [d.log_prob(torch.randn(1000, 2)) for _ in range(1000)]
45.8 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

>>> %lprun -f d.log_prob [d.log_prob(torch.randn(1000, 2)) for _ in range(10000)]
Timer unit: 1e-06 s

Total time: 0.476309 s
File: /Users/npradhan/miniconda2/envs/pytorch-36/lib/python3.6/site-packages/torch/distributions/normal.py
Function: log_prob at line 62

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    62                                               def log_prob(self, value):
    63     10000       5511.0      0.6      1.2          if self._validate_args:
    64                                                       self._validate_sample(value)
    65                                                   # compute the variance
    66     10000      23427.0      2.3      4.9          var = (self.scale ** 2)
    67     10000     180867.0     18.1     38.0          log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
    68     10000     266504.0     26.7     56.0          return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

version: master (1.0.0a0+ab9a597)

>>> import torch
>>> import torch.distributions as dist
>>> torch.__version__
 '1.0.0a0+ab9a597'

>>> d = dist.Normal(torch.zeros(1000, 2), torch.ones(1000, 2))

>>> %timeit torch.randn(1000, 2)
17.8 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit [d.log_prob(torch.randn(1000, 2)) for _ in range(1000)]
72.1 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

>>> %lprun -f d.log_prob [d.log_prob(torch.randn(1000, 2)) for _ in range(10000)]
Timer unit: 1e-06 s

Total time: 0.782675 s
File: /Users/npradhan/miniconda2/envs/pytorch-master/lib/python3.6/site-packages/torch/distributions/normal.py
Function: log_prob at line 70

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    70                                               def log_prob(self, value):
    71     10000       8708.0      0.9      1.1          if self._validate_args:
    72                                                       self._validate_sample(value)
    73                                                   # compute the variance
    74     10000      46140.0      4.6      5.9          var = (self.scale ** 2)
    75     10000     135462.0     13.5     17.3          log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
    76     10000     592365.0     59.2     75.7          return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

Note that the log_prob method itself hasn't changed, but its last line now takes almost twice as long. The only thing I can think of is that the broadcast_all call in the constructor now behaves differently, and the expanded tensors it returns are somehow slower to operate on. I am still investigating this.
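To test that hypothesis, a minimal sketch along the following lines could compare the same arithmetic on materialized (contiguous) parameters against parameters expanded from a smaller base. The log_prob_core helper here is just for illustration and is not part of torch.distributions:

import math
import timeit
import torch

value = torch.randn(1000, 2)

# Same arithmetic as the final lines of Normal.log_prob, with validation skipped.
def log_prob_core(loc, scale, value):
    var = scale ** 2
    return -((value - loc) ** 2) / (2 * var) - scale.log() - math.log(math.sqrt(2 * math.pi))

# Materialized (contiguous) parameters vs. stride-0 views produced by expand.
loc_c, scale_c = torch.zeros(1000, 2), torch.ones(1000, 2)
loc_e, scale_e = torch.zeros(1).expand(1000, 2), torch.ones(1).expand(1000, 2)

print("contiguous:", timeit.timeit(lambda: log_prob_core(loc_c, scale_c, value), number=10000))
print("expanded:  ", timeit.timeit(lambda: log_prob_core(loc_e, scale_e, value), number=10000))

If the expanded variant is noticeably slower, that would point at broadcast_all; if both are comparable, the regression is probably in the elementwise ops themselves.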

Environment

  $ python collect_env.py
Collecting environment information...
PyTorch version: 1.0.0a0+ab9a597
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.13.3
GCC version: Could not collect
CMake version: version 3.12.0

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] numpy (1.15.0)
[pip] torch (1.0.0a0+ab9a597, /Users/npradhan/miniconda2/envs/pytorch-master/lib/python3.6/site-packages)
[pip] torchfile (0.1.0)
[pip] torchvision (0.2.1)
[conda] torch                     1.0.0a0+ab9a597           <pip>
[conda] torch                     0.5.0a0+2431eac           <pip>
[conda] torch                     1.0.0a0+6ff568d           <pip>
[conda] torch                     0.5.0a0+6660a12           <pip>
[conda] torch                     0.5.0a0+35d52db           <pip>
[conda] torch                     0.5.0a0+6c3792b           <pip>
[conda] torchfile                 0.1.0                     <pip>
[conda] torchvision               0.2.1                     <pip>

Additional context

I also noticed that certain functions like torch.randn are almost 2X slower in the nightly build than in a build compiled from source locally on my system. That is why I am benchmarking against the local build rather than the PyTorch nightly build.
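For reference, a standalone check along these lines (plain timeit, no IPython magics) can be run under each build to compare:

import timeit
import torch

# Rough timing of torch.randn under whichever build is installed;
# run the same snippet against the nightly wheel and a local source build.
print(torch.__version__)
print(timeit.timeit(lambda: torch.randn(1000, 2), number=100000), "s for 100k calls")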


Labels

module: performance (Issues related to performance, either of kernel code or framework glue)
