
Conversation

@Stonesjtu
Contributor

I'm trying to write a multi-GPU network by pipelining some layers onto different GPUs. However, the current gradient clipping requires all the parameters to be on the same device.

The CUDA launch overhead is reduced since the scalar calculation is performed on the CPU, but this introduces extra data transfers.
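For illustration, the approach boils down to something like the following sketch (the function name and details are illustrative, not the exact diff): each gradient's norm is moved to the CPU before the reduction, so parameters spread across several GPUs can be clipped together.

```python
import torch


def clip_grads_across_devices(parameters, max_norm, norm_type=2.0):
    # Illustrative sketch only, not the actual patch: reduce per-parameter
    # gradient norms on the CPU so gradients may live on different devices.
    parameters = [p for p in parameters if p.grad is not None]
    if not parameters:
        return torch.tensor(0.0)
    # One small device-to-host copy per parameter, but no assumption that
    # all gradients share a device.
    norms = torch.stack([p.grad.detach().norm(norm_type).cpu() for p in parameters])
    total_norm = norms.norm(norm_type)
    clip_coef = float(max_norm) / (float(total_norm) + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)
    return total_norm
```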

No performance regression is observed when running the following snippet:

import time

import torch

module = torch.nn.Sequential(
    torch.nn.LSTM(1024, 1024),
    torch.nn.LSTM(256, 256),
    torch.nn.Linear(100, 10000),
).cuda()

# warming-up
torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
time_elapse = time.time() - start
# total elapsed seconds over 1000 iterations is numerically equal to ms per clip
print('{} ms per clip'.format(time_elapse))


@facebook-github-bot left a comment


@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@soumith
Contributor

soumith commented Jul 10, 2018

thank you, this looks good!

@apaszke
Contributor

apaszke commented Jul 10, 2018

Wouldn’t it be much faster to just use a defaultdict to accumulate norms on different devices and only then transfer them all to CPU? That would at least cover the common case of all params on a single GPU.
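Roughly along these lines (just a sketch, the names are illustrative):

```python
from collections import defaultdict

import torch


def clip_grads_per_device(parameters, max_norm):
    # Sketch of the suggestion: accumulate squared norms per device first,
    # then move only one scalar per device to the CPU.
    parameters = [p for p in parameters if p.grad is not None]
    per_device = defaultdict(float)
    for p in parameters:
        g = p.grad.detach()
        per_device[g.device] = per_device[g.device] + g.norm(2) ** 2
    # One device-to-host copy per device instead of one per parameter.
    total_norm = sum(float(n) for n in per_device.values()) ** 0.5
    clip_coef = float(max_norm) / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)
    return total_norm
```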

@Stonesjtu
Contributor Author

@apaszke I've thought about that, but counter-intuitively, in my environment the scalar addition on a single device does not run as fast as expected.

My environment:
1080 + Xeon E5-2620 v4.
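A rough micro-benchmark sketch of that comparison (sizes and iteration counts are arbitrary, not measurements from this thread):

```python
import time

import torch

# Compare accumulating squared norms on the GPU against pulling each
# scalar to the CPU first.
grads = [torch.randn(1024, 1024, device='cuda') for _ in range(50)]


def accumulate_on_gpu():
    total = torch.zeros((), device='cuda')
    for g in grads:
        total = total + g.norm(2) ** 2
    return float(total)


def accumulate_on_cpu():
    total = 0.0
    for g in grads:
        total += float(g.norm(2) ** 2)
    return total


for fn in (accumulate_on_gpu, accumulate_on_cpu):
    fn()  # warm-up
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        fn()
    torch.cuda.synchronize()
    print('{}: {:.3f} s for 100 runs'.format(fn.__name__, time.time() - start))
```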

goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Pull Request resolved: pytorch#9302

Differential Revision: D8781551

Pulled By: soumith

fbshipit-source-id: 9d76d01fe0531927f770a16b9523872a7e08e927
