
Conversation

@mingfeima
Collaborator

This PR parallelizes `masked_fill` on CPU; currently it runs sequentially.

The following script is used to benchmark and verify this PR. On a Xeon Skylake 8180 (2 sockets * 28 cores),
it runs in 4.20 sec without the PR and 0.11 sec with the PR.

```python
import torch
import random
from time import time

size = 10 * 1000 * 1000
count = 100

def test_masked_fill():
    dst = torch.randn(size)
    dst_ = dst.clone()
    # random 0/1 mask with p = 0.5 (byte tensor, as masked_fill_ expected at the time)
    mask = torch.rand(size).mul(2).floor().byte()
    val = random.random()

    tstart = time()
    for i in range(count):
        dst.masked_fill_(mask, val)
    tend = time()
    print("masked_fill_: %f sec" % (tend - tstart))

    # vectorized verification: masked elements hold val, unmasked ones are unchanged
    assert (dst[mask != 0] == val).all()
    assert torch.equal(dst[mask == 0], dst_[mask == 0])
    print("test_masked_fill: PASS")

test_masked_fill()
```
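For readers unfamiliar with the op, here is a toy illustration of `masked_fill_` semantics (using a modern bool mask for clarity; the 2018-era API expected a byte mask, as in the benchmark script above):

```python
import torch

# Positions where the mask is set receive the fill value;
# all other positions keep their original contents.
dst = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([True, False, True, False])
dst.masked_fill_(mask, -1.0)  # dst is now [-1.0, 2.0, -1.0, 4.0]

# Equivalent out-of-place formulation with torch.where
orig = torch.tensor([1.0, 2.0, 3.0, 4.0])
assert torch.equal(dst, torch.where(mask, torch.full_like(orig, -1.0), orig))
```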


@mingfeima
Collaborator Author

Some Caffe2 CI jobs failed with `could not create cache path /usr/local/caffe2/lib/python2.7/dist-packages/caffe2/python/.pytest_cache/v/cache/lastfailed`.
Some PyTorch CI jobs failed on `test_all_reduce_product` from `test_distributed.py`.
I can't reproduce the failures locally; can someone give me some guidance?

@ssnl
Collaborator

ssnl commented Sep 7, 2018

@mingfeima Ignore the CircleCI ones; they are experimental.

Contributor

@facebook-github-bot facebook-github-bot left a comment


ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 9, 2018
Summary:
This PR parallelizes `masked_fill` on CPU; currently it runs sequentially.

The following script is used to benchmark and verify this PR. On a Xeon Skylake 8180 (2 sockets * 28 cores),
it runs in `4.20` sec without the PR and `0.11` sec with the PR.

```python
import torch
import random
from time import time

size = 10 * 1000 * 1000
count = 100

def test_masked_fill():
    dst = torch.randn(size)
    dst_ = dst.clone()
    mask = torch.rand(size).mul(2).floor().byte()
    val = random.random()

    tstart = time()
    for i in range(count):
        dst.masked_fill_(mask, val)
    tend = time()
    print("masked_fill_: %f" % (tend-tstart))

    for i in range(size):
        if mask[i]:
            if dst[i] != val:
                print("fail")
        else:
            if dst[i] != dst_[i]:
                print("fail1")
    print("test_masked_fill: PASS")

test_masked_fill()
```
Pull Request resolved: pytorch/pytorch#11359

Differential Revision: D9735578

Pulled By: ezyang

fbshipit-source-id: d437ad7c6dace1910d0c18d6d9ede80efb44fae4
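As background for the change itself, ATen-style CPU parallelism splits an element range into contiguous chunks that threads process independently. A minimal Python sketch of that partitioning idea follows; the function name, signature, and grain-size policy are illustrative assumptions, not the actual ATen code:

```python
def chunk_ranges(n, num_threads, grain):
    """Split [0, n) into contiguous half-open chunks of at least `grain`
    elements, capped at roughly one chunk per thread (illustrative only)."""
    chunk = max(grain, -(-n // num_threads))  # ceil(n / num_threads), floored at grain
    return [(begin, min(begin + chunk, n)) for begin in range(0, n, chunk)]

# 10 elements over 4 threads with grain 1: one slice per thread
print(chunk_ranges(10, 4, 1))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

Each thread then applies the masked fill only to its own slice, which is safe because the slices are disjoint and `masked_fill_` writes each element independently.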
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018