
Conversation

@t-vi (Collaborator) commented Apr 27, 2018

This adds NaN tests for the min and max kernels and makes them return NaNs in the right places.

  inline __device__ T operator()(T a, T b) const {
-   return THCNumerics<T>::lt(a, b) ? a : b;
+   // a != a means a == NaN
+   return (THCNumerics<T>::lt(a, b) ||
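For context, the idea is that the comparator keeps a whenever it already compares lower or is NaN, so a NaN, once in the accumulator, is never overwritten. A minimal sketch of such a NaN-propagating min functor (a reconstruction for illustration only; the exact condition from the PR is truncated above):

    // Sketch of a NaN-propagating "min" functor in the THCNumerics style.
    // Reconstruction, not the literal PR diff.
    template <typename T>
    struct MinNaNOp {
      inline __device__ T operator()(T a, T b) const {
        // "a != a" is true only for NaN, so a NaN accumulator is kept as-is.
        return (THCNumerics<T>::lt(a, b) || THCNumerics<T>::ne(a, a)) ? a : b;
      }
    };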


actual = f(a.cuda()).cpu()
expected = f(a).cpu()
self.assertEqual(torch.isnan(actual), torch.isnan(expected), 'nans for {}'.format(name))
self.assertEqual(actual[~torch.isnan(actual)],


@t-vi (Collaborator, Author) commented Apr 28, 2018 via email

Thank you ngimel and zou3519!
@ezyang (Contributor) commented Apr 30, 2018

Looks correct. @t-vi would you mind running some quick perf numbers just to characterize what the effect of the extra neq test is?

@t-vi (Collaborator, Author) commented Apr 30, 2018

My conclusion: measurable in isolation (on the order of a 10%-20% slowdown), negligible in any other context. Just as max probably wasn't the bottleneck before, doing twice as many comparisons in max doesn't really matter if you do anything nontrivial elsewhere.

So I ran this several times (to avoid initialization effects; the first run takes longer than the others). I lifted this from somewhere in the pytorch/benchmark repository - if you have a better methodology, I'd be happy to apply it.

import gc, torch, time
with torch.no_grad():
    a = torch.empty(100, 1000, 1000, device='cuda')
    a.normal_()
    # time a full reduction (max over all elements), both with CUDA events and wall clock
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    gc.collect()
    torch.cuda.synchronize()
    start.record()
    start_cpu_secs = time.time()
    b = a.max()
    end_cpu_secs = time.time()
    end.record()
    torch.cuda.synchronize()
    gpu_msecs = start.elapsed_time(end)
    print(torch.__version__, "msecs maxall gpu", gpu_msecs, "cpu", (end_cpu_secs - start_cpu_secs)*1000)
    if 1:
        # same measurement for the reduction along dimension 1
        a = torch.empty(100, 1000, 1000, device='cuda')
        a.normal_()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        gc.collect()
        torch.cuda.synchronize()
        start.record()
        start_cpu_secs = time.time()
        b = a.max(1)
        end_cpu_secs = time.time()
        end.record()
        torch.cuda.synchronize()
        gpu_msecs = start.elapsed_time(end)
        print(torch.__version__, "msecs max[1] gpu", gpu_msecs, "cpu", (end_cpu_secs - start_cpu_secs)*1000)

And I get:

0.4.0 msecs maxall gpu 1.6861120462417603 cpu 1.6646385192871094
0.4.0 msecs max[1] gpu 2.239487886428833 cpu 0.07939338684082031

vs.

0.5.0a0+497ff06 msecs maxall gpu 1.8479679822921753 cpu 1.8224716186523438
0.5.0a0+497ff06 msecs max[1] gpu 1.922752022743225 cpu 0.054836273193359375

@apaszke (Contributor) left a comment


Just wanted to clarify our strategy here. Do we want to match the CPU operators? If so, this isn't the right way to go: the way they are implemented is slightly different. Actually, I think the CPU ops mix both the condition I linked and the one you put here. Can you please make them all consistent?

The benchmarks look like noise to me, so that's ok.

@t-vi (Collaborator, Author) commented Apr 30, 2018 via email

@apaszke (Contributor) commented Apr 30, 2018

But what's the actual fix in this PR then? If we want kernels to return NaNs when there are NaNs, then why don't we treat the CPU kernels that way?

@apaszke (Contributor) commented Apr 30, 2018

Ok, never mind my comment. I did the math again and it seems to be ok.

@t-vi (Collaborator, Author) commented Apr 30, 2018

(Sorry, the second mail doesn't seem to have reached the issue log... :( )
Oh, it is roughly the same: the CPU does the comparison first and then checks whether the value was NaN (in that case theMax becomes NaN because of how the ">" is formulated), and if it is, it aborts and theMax stays NaN.
Previously, the CUDA kernel would put the NaN into the accumulator (a in how the CUDA functor is called, theMax on the CPU) and then overwrite it on the next comparison.
With the proposed fix, the CUDA kernel checks whether the left-hand side (a, equivalent to theMax in how it is used) is NaN, and if it is, it keeps a. I think this is result-equivalent to how the CPU works.
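To spell out the claimed equivalence, here is a minimal plain-C++ sketch (not the actual TH/THC code, just an illustration of the two strategies described above); both end up returning NaN once a NaN shows up in the input:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // CPU-style: "!(value <= theMax)" is also true when value is NaN, so the NaN is
    // written into the accumulator and the scan aborts, leaving theMax == NaN.
    float max_cpu_style(const std::vector<float>& v) {
        float theMax = v[0];
        if (std::isnan(theMax)) return theMax;
        for (size_t i = 1; i < v.size(); ++i) {
            if (!(v[i] <= theMax)) {
                theMax = v[i];
                if (std::isnan(theMax)) break;
            }
        }
        return theMax;
    }

    // CUDA-style: the binary reduction op keeps a whenever a is larger or a is NaN
    // (a != a), so a NaN accumulator is never overwritten; a NaN b gets picked up
    // because neither condition holds and the op falls through to b.
    float max_gpu_style(const std::vector<float>& v) {
        float a = v[0];
        for (size_t i = 1; i < v.size(); ++i) {
            float b = v[i];
            a = (a > b || a != a) ? a : b;
        }
        return a;
    }

    int main() {
        std::vector<float> v = {1.0f, std::nanf(""), 3.0f};
        std::printf("cpu-style %f, gpu-style %f\n", max_cpu_style(v), max_gpu_style(v));
        return 0;
    }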

@apaszke apaszke merged commit 20c965f into pytorch:master Apr 30, 2018
Jorghi12 pushed a commit to wsttiger/pytorch that referenced this pull request May 10, 2018
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018