1818from torch ._six import raise_from
1919from subprocess import Popen , PIPE
2020from multiprocessing .util import register_after_fork as _register_after_fork
21+ from ._utils import _get_device_index
2122
2223_initialized = False
2324_queued_calls = [] # don't invoke these until initialization occurs
@@ -211,12 +212,12 @@ class device(object):
211212 r"""Context-manager that changes the selected device.
212213
213214 Arguments:
214- idx ( int): device index to select. It's a no-op if this argument
215- is negative.
215+ device (torch.device or int): device index to select. It's a no-op if
216+ this argument is a negative integer or ``None``.
216217 """
217218
218- def __init__ (self , idx ):
219- self .idx = int ( idx )
219+ def __init__(self, device):
220+ self.idx = _get_device_index(device, optional=True)
220221 self .prev_idx = - 1
221222
222223 def __enter__ (self ):
@@ -255,9 +256,10 @@ def set_device(device):
255256 cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable.
256257
257258 Arguments:
258- device (int): selected device. This function is a no-op if this
259- argument is negative.
259+ device (torch.device or int): selected device. This function is a no-op
260+ if this argument is negative.
260261 """
262+ device = _get_device_index(device)
261263 if device >= 0 :
262264 torch ._C ._cuda_setDevice (device )
263265
@@ -266,8 +268,10 @@ def get_device_name(device):
266268 r"""Gets the name of a device.
267269
268270 Arguments:
269- device (int): device for which to return the name. This function is a
270- no-op if this argument is negative.
271+ device (torch.device or int, optional): device for which to return the
272+ name. This function is a no-op if this argument is a negative
273+ integer. Uses the current device, given by :meth:`~torch.cuda.current_device`,
274+ if :attr:`device` is ``None`` (default).
271275 """
272276 return get_device_properties (device ).name
273277
@@ -276,8 +280,12 @@ def get_device_capability(device):
276280 r"""Gets the cuda capability of a device.
277281
278282 Arguments:
279- device (int): device for which to return the name. This function is a
280- no-op if this argument is negative.
283+ device (torch.device or int, optional): device for which to return the
284+ device capability. This function is a no-op if this argument is
285+ a negative integer. Uses the current device, given by
286+ :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
287+ (default).
288+
281289 Returns:
282290 tuple(int, int): the major and minor cuda capability of the device
283291 """
@@ -288,6 +296,7 @@ def get_device_capability(device):
288296def get_device_properties (device ):
289297 if not _initialized :
290298 init () # will define _get_device_properties and _CudaDeviceProperties
299+ device = _get_device_index(device, optional=True)
291300 if device < 0 or device >= device_count ():
292301 raise AssertionError ("Invalid device id" )
293302 return _get_device_properties (device )
@@ -370,19 +379,17 @@ def memory_allocated(device=None):
370379 device.
371380
372381 Arguments:
373- device (int, optional): selected device. Returns statistic for the
374- current device, given by
375- :meth:`~torch.cuda.current_device`, if
376- :attr:`device` is ``None`` (default).
382+ device (torch.device or int, optional): selected device. Returns
383+ statistic for the current device, given by :meth:`~torch.cuda.current_device`,
384+ if :attr:`device` is ``None`` (default).
377385
378386 .. note::
379387 This is likely less than the amount shown in `nvidia-smi` since some
380388 unused memory can be held by the caching allocator and some context
381389 needs to be created on GPU. See :ref:`cuda-memory-management` for more
382390 details about GPU memory management.
383391 """
384- if device is None :
385- device = current_device ()
392+ device = _get_device_index(device, optional=True)
386393 return torch ._C ._cuda_memoryAllocated (device )
387394
388395
@@ -391,17 +398,15 @@ def max_memory_allocated(device=None):
391398 device.
392399
393400 Arguments:
394- device (int, optional): selected device. Returns statistic for the
395- current device, given by
396- :meth:`~torch.cuda.current_device`, if
397- :attr:`device` is ``None`` (default).
401+ device (torch.device or int, optional): selected device. Returns
402+ statistic for the current device, given by :meth:`~torch.cuda.current_device`,
403+ if :attr:`device` is ``None`` (default).
398404
399405 .. note::
400406 See :ref:`cuda-memory-management` for more details about GPU memory
401407 management.
402408 """
403- if device is None :
404- device = current_device ()
409+ device = _get_device_index(device, optional=True)
405410 return torch ._C ._cuda_maxMemoryAllocated (device )
406411
407412
@@ -410,17 +415,15 @@ def memory_cached(device=None):
410415 for a given device.
411416
412417 Arguments:
413- device (int, optional): selected device. Returns statistic for the
414- current device, given by
415- :meth:`~torch.cuda.current_device`, if
416- :attr:`device` is ``None`` (default).
418+ device (torch.device or int, optional): selected device. Returns
419+ statistic for the current device, given by :meth:`~torch.cuda.current_device`,
420+ if :attr:`device` is ``None`` (default).
417421
418422 .. note::
419423 See :ref:`cuda-memory-management` for more details about GPU memory
420424 management.
421425 """
422- if device is None :
423- device = current_device ()
426+ device = _get_device_index(device, optional=True)
424427 return torch ._C ._cuda_memoryCached (device )
425428
426429
@@ -429,17 +432,15 @@ def max_memory_cached(device=None):
429432 for a given device.
430433
431434 Arguments:
432- device (int, optional): selected device. Returns statistic for the
433- current device, given by
434- :meth:`~torch.cuda.current_device`, if
435- :attr:`device` is ``None`` (default).
435+ device (torch.device or int, optional): selected device. Returns
436+ statistic for the current device, given by :meth:`~torch.cuda.current_device`,
437+ if :attr:`device` is ``None`` (default).
436438
437439 .. note::
438440 See :ref:`cuda-memory-management` for more details about GPU memory
439441 management.
440442 """
441- if device is None :
442- device = current_device ()
443+ device = _get_device_index(device, optional=True)
443444 return torch ._C ._cuda_maxMemoryCached (device )
444445
445446
0 commit comments