 from .scatter_gather import scatter_kwargs, gather
 from .replicate import replicate
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index


 def _check_balance(device_ids):
@@ -13,7 +14,7 @@ def _check_balance(device_ids):
     has less than 75% of the memory or cores of GPU {}. You can do so by setting
     the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
     environment variable."""
-
+    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
     dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]

     def warn_imbalance(get_prop):
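For context: `_get_device_index(x, True)` is what lets a single entry be either a plain int or a `torch.device`. A minimal sketch of that normalization, assuming the private helper keeps the behavior this hunk relies on:

import torch
from torch.cuda._utils import _get_device_index

# Mixed int / torch.device entries, as callers may now pass them.
device_ids = [0, torch.device('cuda:1')]

# Reduce everything to a plain integer index, mirroring the
# list(map(lambda x: _get_device_index(x, True), device_ids)) line above.
# (_get_device_index is a private helper; its behavior may change between versions.)
normalized = [_get_device_index(d, True) for d in device_ids]
print(normalized)  # expected: [0, 1]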
@@ -77,9 +78,9 @@ class DataParallel(Module):


     Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])

     Attributes:
         module (Module): the module to be parallelized
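A hedged usage sketch of what the updated Args section describes, namely passing `torch.device` objects where bare ints were required before (assumes a machine with at least two CUDA GPUs; the toy Linear module and tensor sizes are illustrative only):

import torch
import torch.nn as nn

# The module must live on the first device in device_ids.
model = nn.Linear(10, 5).to('cuda:0')

# device_ids and output_device may now be torch.device objects or ints.
dp = nn.DataParallel(model,
                     device_ids=[torch.device('cuda:0'), torch.device('cuda:1')],
                     output_device=torch.device('cuda:0'))

out = dp(torch.randn(8, 10, device='cuda:0'))
print(out.shape)  # torch.Size([8, 5])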
@@ -104,10 +105,11 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
             device_ids = list(range(torch.cuda.device_count()))
         if output_device is None:
             output_device = device_ids[0]
+
         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)

         _check_balance(self.device_ids)

@@ -143,10 +145,10 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
     This is the functional version of the DataParallel module.

     Args:
-        module: the module to evaluate in parallel
-        inputs: inputs to the module
-        device_ids: GPU ids on which to replicate module
-        output_device: GPU location of the output. Use -1 to indicate the CPU.
+        module (Module): the module to evaluate in parallel
+        inputs (tensor): inputs to the module
+        device_ids (list of int or torch.device): GPU ids on which to replicate module
+        output_device (list of int or torch.device): GPU location of the output. Use -1 to indicate the CPU.
             (default: device_ids[0])
     Returns:
         a Tensor containing the result of module(input) located on
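The functional form documented above can be exercised the same way; a sketch under the same two-GPU assumption:

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

module = nn.Linear(10, 5).to('cuda:0')
inputs = torch.randn(8, 10, device='cuda:0')

# Replicate across two devices and gather the result back on cuda:0.
out = data_parallel(module, inputs,
                    device_ids=[torch.device('cuda:0'), torch.device('cuda:1')],
                    output_device=torch.device('cuda:0'))
print(out.shape)  # torch.Size([8, 5])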