
Commit 8e33451

ssnl authored and facebook-github-bot committed
Make torch.cuda.* take device objects; Update distributed docs (#10833)
Summary:
Commits:
1. Make `torch.cuda.*` take device objects
2. Update `torch.distributed` docs to emphasize calling `torch.cuda.set_device` before `init_process_group`

Pull Request resolved: #10833
Differential Revision: D9514241
Pulled By: SsnL
fbshipit-source-id: 2497464305fb1e63d6c495291a5744aaa7e2696e
1 parent 58b145f commit 8e33451
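With this change, `torch.cuda.*` functions that previously accepted only integer device indices also accept `torch.device` objects (and, where documented, ``None`` for the current device). A minimal sketch of the updated calling convention; the device index used here is illustrative:

    import torch

    dev = torch.device('cuda:0')
    torch.cuda.set_device(dev)                      # previously required an int index
    name = torch.cuda.get_device_name(dev)
    major, minor = torch.cuda.get_device_capability(dev)
    allocated = torch.cuda.memory_allocated(dev)    # same as torch.cuda.memory_allocated(0)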

File tree: 8 files changed (+107, −62 lines)


docs/source/distributed.rst

Lines changed: 6 additions & 6 deletions
@@ -88,8 +88,8 @@ TCP initialization

 There are two ways to initialize using TCP, both requiring a network address
 reachable from all processes and a desired ``world_size``. The first way
-requires specifying an address that belongs to the rank 0 process. This first way of
-initialization requires that all processes have manually specified ranks.
+requires specifying an address that belongs to the rank 0 process. This
+initialization method requires that all processes have manually specified ranks.

 Alternatively, the address has to be a valid IP multicast address, in which case
 ranks can be assigned automatically. Multicast initialization also supports
@@ -101,10 +101,10 @@ jobs, as long as they use different group names.
 import torch.distributed as dist

 # Use address of one of the machines
-dist.init_process_group(init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)
+dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)

 # or a multicast address - rank will be assigned automatically if unspecified
-dist.init_process_group(init_method='tcp://[ff15:1e18:5d4c:4cf0:d02d:b659:53ba:b0a7]:23456',
+dist.init_process_group(backend, init_method='tcp://[ff15:1e18:5d4c:4cf0:d02d:b659:53ba:b0a7]:23456',
                         world_size=4)

 Shared file-system initialization
@@ -126,8 +126,8 @@ multiple jobs, as long as they use different group names.
 import torch.distributed as dist

 # Rank will be assigned automatically if unspecified
-dist.init_process_group(init_method='file:///mnt/nfs/sharedfile', world_size=4,
-                        group_name=args.group)
+dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile',
+                        world_size=4, group_name=args.group)

 Environment variable initialization
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
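Taken together with the commit summary, the recommended pattern is to pass the backend explicitly and to pin each process to its GPU with `torch.cuda.set_device` before `init_process_group`. A hedged sketch of that ordering, assuming one GPU per process and an illustrative backend, address, and `args` namespace:

    import torch
    import torch.distributed as dist

    # select this process's GPU *before* initializing the process group
    torch.cuda.set_device(torch.device('cuda', args.local_rank))

    dist.init_process_group('gloo',                   # backend is now an explicit argument
                            init_method='tcp://10.1.1.20:23456',
                            rank=args.rank, world_size=4)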

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
@@ -31,9 +31,9 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
 nn
 optim
 torch.autograd <autograd>
+torch.distributed <distributed>
 torch.distributions <distributions>
 torch.multiprocessing <multiprocessing>
-torch.distributed <distributed>
 bottleneck
 checkpoint
 cpp_extension

test/test_cuda.py

Lines changed: 2 additions & 2 deletions
@@ -773,7 +773,7 @@ def advance(gen, end):
         # interlace
         torch.cuda.empty_cache()
         gen0 = self._test_memory_stats_generator(self, device=0, N=35)
-        gen1 = self._test_memory_stats_generator(self, device=1, N=35)
+        gen1 = self._test_memory_stats_generator(self, device=torch.device('cuda:1'), N=35)
         end0 = end1 = False
         while not (end0 and end1):
             end0 = advance(gen0, end0)
@@ -782,7 +782,7 @@ def advance(gen, end):
         # semi-random order
         torch.cuda.empty_cache()
         gen0 = self._test_memory_stats_generator(self, device=0, N=35)
-        gen1 = self._test_memory_stats_generator(self, device=1, N=35)
+        gen1 = self._test_memory_stats_generator(self, device=torch.device('cuda:1'), N=35)
         end0 = end1 = False

         while not (end0 and end1):
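The test change exercises the point that an integer index and the corresponding `torch.device` now behave interchangeably. A rough standalone check, assuming at least two visible GPUs:

    import torch

    a = torch.cuda.memory_allocated(1)
    b = torch.cuda.memory_allocated(torch.device('cuda:1'))
    assert a == b   # both forms report stats for the same physical device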

torch/cuda/__init__.py

Lines changed: 35 additions & 34 deletions
@@ -18,6 +18,7 @@
 from torch._six import raise_from
 from subprocess import Popen, PIPE
 from multiprocessing.util import register_after_fork as _register_after_fork
+from ._utils import _get_device_index

 _initialized = False
 _queued_calls = []  # don't invoke these until initialization occurs
@@ -211,12 +212,12 @@ class device(object):
     r"""Context-manager that changes the selected device.

     Arguments:
-        idx (int): device index to select. It's a no-op if this argument
-            is negative.
+        device (torch.device or int): device index to select. It's a no-op if
+            this argument is a negative integer or ``None``.
     """

-    def __init__(self, idx):
-        self.idx = int(idx)
+    def __init__(self, device):
+        self.idx = _get_device_index(device, optional=True)
         self.prev_idx = -1

     def __enter__(self):
@@ -255,9 +256,10 @@ def set_device(device):
     cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.

     Arguments:
-        device (int): selected device. This function is a no-op if this
-            argument is negative.
+        device (torch.device or int): selected device. This function is a no-op
+            if this argument is negative.
     """
+    device = _get_device_index(device)
     if device >= 0:
         torch._C._cuda_setDevice(device)
@@ -266,8 +268,10 @@ def get_device_name(device):
     r"""Gets the name of a device.

     Arguments:
-        device (int): device for which to return the name. This function is a
-            no-op if this argument is negative.
+        device (torch.device or int, optional): device for which to return the
+            name. This function is a no-op if this argument is a negative
+            integer. Uses the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
     """
     return get_device_properties(device).name
@@ -276,8 +280,12 @@ def get_device_capability(device):
     r"""Gets the cuda capability of a device.

     Arguments:
-        device (int): device for which to return the name. This function is a
-            no-op if this argument is negative.
+        device (torch.device or int, optional): device for which to return the
+            device capability. This function is a no-op if this argument is
+            a negative integer. Uses the current device, given by
+            :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+            (default).
+
     Returns:
         tuple(int, int): the major and minor cuda capability of the device
     """
@@ -288,6 +296,7 @@ def get_device_capability(device):
 def get_device_properties(device):
     if not _initialized:
         init()  # will define _get_device_properties and _CudaDeviceProperties
+    device = _get_device_index(device, optional=True)
     if device < 0 or device >= device_count():
         raise AssertionError("Invalid device id")
     return _get_device_properties(device)
@@ -370,19 +379,17 @@ def memory_allocated(device=None):
     device.

     Arguments:
-        device (int, optional): selected device. Returns statistic for the
-            current device, given by
-            :meth:`~torch.cuda.current_device`, if
-            :attr:`device` is ``None`` (default).
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).

     .. note::
         This is likely less than the amount shown in `nvidia-smi` since some
         unused memory can be held by the caching allocator and some context
         needs to be created on GPU. See :ref:`cuda-memory-management` for more
         details about GPU memory management.
     """
-    if device is None:
-        device = current_device()
+    device = _get_device_index(device, optional=True)
     return torch._C._cuda_memoryAllocated(device)
@@ -391,17 +398,15 @@ def max_memory_allocated(device=None):
     device.

     Arguments:
-        device (int, optional): selected device. Returns statistic for the
-            current device, given by
-            :meth:`~torch.cuda.current_device`, if
-            :attr:`device` is ``None`` (default).
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).

     .. note::
         See :ref:`cuda-memory-management` for more details about GPU memory
         management.
     """
-    if device is None:
-        device = current_device()
+    device = _get_device_index(device, optional=True)
     return torch._C._cuda_maxMemoryAllocated(device)
@@ -410,17 +415,15 @@ def memory_cached(device=None):
     for a given device.

     Arguments:
-        device (int, optional): selected device. Returns statistic for the
-            current device, given by
-            :meth:`~torch.cuda.current_device`, if
-            :attr:`device` is ``None`` (default).
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).

     .. note::
         See :ref:`cuda-memory-management` for more details about GPU memory
         management.
     """
-    if device is None:
-        device = current_device()
+    device = _get_device_index(device, optional=True)
     return torch._C._cuda_memoryCached(device)
@@ -429,17 +432,15 @@ def max_memory_cached(device=None):
     for a given device.

     Arguments:
-        device (int, optional): selected device. Returns statistic for the
-            current device, given by
-            :meth:`~torch.cuda.current_device`, if
-            :attr:`device` is ``None`` (default).
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).

     .. note::
         See :ref:`cuda-memory-management` for more details about GPU memory
         management.
     """
-    if device is None:
-        device = current_device()
+    device = _get_device_index(device, optional=True)
     return torch._C._cuda_maxMemoryCached(device)
445446

torch/cuda/_utils.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import torch
+
+
+def _get_device_index(device, optional=False):
+    r"""Gets the device index from :attr:`device`, which can be a torch.device
+    object, a Python integer, or ``None``.
+
+    If :attr:`device` is a torch.device object, returns the device index if it
+    is a CUDA device. Note that for a CUDA device without a specified index, i.e.,
+    ``torch.device('cuda')``, this will return the current default CUDA device if
+    :attr:`optional` is ``True``.
+
+    If :attr:`device` is a Python integer, it is returned as is.
+
+    If :attr:`device` is ``None``, this will return the current default CUDA
+    device if :attr:`optional` is ``True``.
+    """
+    if isinstance(device, torch.device):
+        dev_type = device.type
+        if device.type != 'cuda':
+            raise ValueError('Expected a cuda device, but got: {}'.format(device))
+        device_idx = device.index
+    else:
+        device_idx = device
+    if device_idx is None:
+        if optional:
+            # default cuda device index
+            return torch.cuda.current_device()
+        else:
+            raise ValueError('Expected a cuda device with a specified index or '
+                             'an integer, but got: {}'.format(device))
+    return device_idx
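The helper is private, but a few illustrative calls show the resolution rules implemented above (device indices are examples only):

    import torch
    from torch.cuda._utils import _get_device_index

    _get_device_index(torch.device('cuda:1'))                # -> 1
    _get_device_index(3)                                     # integers pass through -> 3
    _get_device_index(None, optional=True)                   # -> torch.cuda.current_device()
    _get_device_index(torch.device('cuda'), optional=True)   # unindexed cuda device -> current device
    # _get_device_index(torch.device('cpu'))                 # raises ValueError: expected a cuda device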

torch/cuda/streams.py

Lines changed: 18 additions & 15 deletions
@@ -1,27 +1,30 @@
 import ctypes
 import torch
 from . import cudart, check_error, cudaStatus
+from ._utils import _get_device_index


 class Stream(torch._C._CudaStreamBase):
-    """Wrapper around a CUDA stream.
+    r"""Wrapper around a CUDA stream.

     A CUDA stream is a linear sequence of execution that belongs to a specific
     device, independent from other streams. See :ref:`cuda-semantics` for
     details.

     Arguments:
-        device(int, optional): a device on which to allocate the Stream.
+        device(torch.device or int, optional): a device on which to allocate
+            the stream. If :attr:`device` is ``None`` (default) or a negative
+            integer, this will use the current device.
         priority(int, optional): priority of the stream. Lower numbers
                                  represent higher priorities.
     """

-    def __new__(cls, device=-1, priority=0, **kwargs):
+    def __new__(cls, device=None, priority=0, **kwargs):
         with torch.cuda.device(device):
             return super(Stream, cls).__new__(cls, priority=priority, **kwargs)

     def wait_event(self, event):
-        """Makes all future work submitted to the stream wait for an event.
+        r"""Makes all future work submitted to the stream wait for an event.

         Arguments:
             event (Event): an event to wait for.
@@ -38,7 +41,7 @@ def wait_event(self, event):
         check_error(cudart().cudaStreamWaitEvent(self, event, ctypes.c_int(0)))

     def wait_stream(self, stream):
-        """Synchronizes with another stream.
+        r"""Synchronizes with another stream.

         All future work submitted to this stream will wait until all kernels
         submitted to a given stream at the time of call complete.
@@ -52,7 +55,7 @@ def wait_stream(self, stream):
         self.wait_event(stream.record_event())

     def record_event(self, event=None):
-        """Records an event.
+        r"""Records an event.

         Arguments:
             event (Event, optional): event to record. If not given, a new one
@@ -67,7 +70,7 @@ def record_event(self, event=None):
         return event

     def query(self):
-        """Checks if all the work submitted has been completed.
+        r"""Checks if all the work submitted has been completed.

         Returns:
             A boolean indicating if all kernels in this stream are completed.
@@ -79,7 +82,7 @@ def query(self):
         return True

     def synchronize(self):
-        """Wait for all the kernels in this stream to complete.
+        r"""Wait for all the kernels in this stream to complete.

         .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
            `CUDA documentation`_ for more info.
@@ -126,7 +129,7 @@ class EventHandle(ctypes.Structure):


 class Event(object):
-    """Wrapper around CUDA event.
+    r"""Wrapper around CUDA event.

     Arguments:
         enable_timing (bool): indicates if the event should measure time
@@ -165,19 +168,19 @@ def __del__(self):
         del self._as_parameter_

     def record(self, stream=None):
-        """Records the event in a given stream."""
+        r"""Records the event in a given stream."""
         if stream is None:
             stream = torch.cuda.current_stream()
         stream.record_event(self)

     def wait(self, stream=None):
-        """Makes a given stream wait for the event."""
+        r"""Makes a given stream wait for the event."""
         if stream is None:
             stream = torch.cuda.current_stream()
         stream.wait_event(self)

     def query(self):
-        """Checks if the event has been recorded.
+        r"""Checks if the event has been recorded.

         Returns:
             A boolean indicating if the event has been recorded.
@@ -189,18 +192,18 @@ def query(self):
         return True

     def elapsed_time(self, end_event):
-        """Returns the time elapsed before the event was recorded."""
+        r"""Returns the time elapsed before the event was recorded."""
         time_ms = ctypes.c_float()
         check_error(cudart().cudaEventElapsedTime(
             ctypes.byref(time_ms), self, end_event))
         return time_ms.value

     def synchronize(self):
-        """Synchronizes with the event."""
+        r"""Synchronizes with the event."""
         check_error(cudart().cudaEventSynchronize(self))

     def ipc_handle(self):
-        """Returns an IPC handle of this event."""
+        r"""Returns an IPC handle of this event."""
         handle = EventHandle()
         check_error(cudart().cudaIpcGetEventHandle(ctypes.byref(handle), self))
         return handle
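`Stream` now defaults to ``device=None`` and accepts a `torch.device` for its device argument; the rest of the changes here only switch docstrings to raw strings. A brief usage sketch (device index illustrative):

    import torch

    s = torch.cuda.Stream(device=torch.device('cuda:0'))   # an int or None also works
    with torch.cuda.stream(s):
        y = torch.randn(3, device='cuda')                  # work enqueued on stream s
    e = s.record_event()                                   # Event recorded on s
    e.synchronize()                                        # block until that work completes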
