
Conversation

@mingfeima
Collaborator

Optimize max and min reduction for the ATen CPU path; the current code path from the TH module runs sequentially on CPU.

@fmassa
Member

fmassa commented Aug 8, 2018

cc @colesbury

@mingfeima
Collaborator Author

Can someone take a look at the build failure?
https://ci.pytorch.org/jenkins/job/caffe2-builds/job/conda2-macos10.13-build/9853/console
It doesn't seem to be a compilation issue.

@ssnl
Collaborator

ssnl commented Aug 8, 2018

Feel free to ignore that one.

Do you have some benchmarks on this?

@mingfeima
Collaborator Author

mingfeima commented Aug 9, 2018

@ssnl Sure, I wrote a small benchmark for max. This piece of code goes from 49 ms to 3.5 ms on a Xeon Skylake 8180:

import torch
from time import time

N = 2000
T = 35820
warmups = 100
count = 200

a = torch.randn(N, T)

def test_max():
    for i in range(warmups):
        b, _ = a.max(dim=1)
    tstart = time()
    for i in range(count):
        b, _ = a.max(dim=1)
    tend = time()
    print("max reduction : %f ms" % ((tend-tstart)/count*1000))

test_max()
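To confirm that the speed-up actually comes from parallelization rather than something else, one can compare a single-threaded run against the default thread count. A minimal sketch (the tensor shape is shrunk here so it runs quickly; `torch.set_num_threads` / `torch.get_num_threads` control and query the intra-op thread pool):

import torch
from time import time

def bench_max(a, warmups=10, iters=50):
    # Warm up, then time a.max(dim=1); returns milliseconds per iteration.
    for _ in range(warmups):
        a.max(dim=1)
    t0 = time()
    for _ in range(iters):
        a.max(dim=1)
    return (time() - t0) / iters * 1000

a = torch.randn(200, 3000)  # smaller than the benchmark above
default_threads = torch.get_num_threads()

torch.set_num_threads(1)
t_single = bench_max(a)
torch.set_num_threads(default_threads)
t_multi = bench_max(a)

print("1 thread  : %.3f ms" % t_single)
print("%d threads: %.3f ms" % (default_threads, t_multi))

On a single-core machine (or with OMP_NUM_THREADS=1) the two numbers will of course match.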

I brought this up because I have been optimizing OpenNMT-py, where max is used in the loss calculation. With max parallelized, the total time dropped by roughly 8%.

Member

@colesbury colesbury left a comment


Nice speed-ups. LGTM with a few code style comments.

I'm working on changing how reductions are implemented and unifying some of the CPU and CUDA code, but it'll probably take a while, so this speed-up is very welcome.


template <>
bool _isnan(float val) {
  return std::isnan(val);
}

#define isnan_break(val) \
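The `_isnan` specialization and the early-break macro reviewed above exist so that NaN values propagate through the reduction: once a NaN is encountered, the result of max/min is NaN, matching PyTorch's documented reduction semantics. A small illustrative snippet (not part of the patch) checking this from Python:

import torch

# One row contains a NaN: max over dim=1 must return NaN for that row,
# because PyTorch's max/min reductions propagate NaNs.
a = torch.tensor([[1.0, float('nan'), 3.0],
                  [4.0, 5.0, 6.0]])
vals, idxs = a.max(dim=1)
print(vals)  # first entry is nan, second is 6.0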


Contributor

@facebook-github-bot facebook-github-bot left a comment


ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 15, 2018
Summary:
Optimize max and min reduction for the ATen CPU path; the current code path from the TH module runs sequentially on CPU.
Pull Request resolved: pytorch/pytorch#10343

Differential Revision: D9330799

Pulled By: ezyang

fbshipit-source-id: 5b8271e0ca3e3e73f88a9075aa541c8756001b7c