2 changes: 1 addition & 1 deletion test/test_distributed.py
@@ -1479,7 +1479,7 @@ def _test_DistributedDataParallel_SyncBatchNorm(self, gpu_subset, rank, output_d
        model_gpu.cuda(gpu_subset[0])

        # DDP training setup
-        model_DDP = nn.utils.convert_sync_batchnorm(copy.deepcopy(model))
+        model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
        model_DDP.cuda(gpu_subset[0])
        model_DDP = nn.parallel.DistributedDataParallel(
            model_DDP, device_ids=gpu_subset
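For context, a minimal sketch of the calling pattern this test now exercises: convert a copy of the model with the new classmethod, move it to the process's GPU, then wrap it in DistributedDataParallel. The helper name below is illustrative, not part of the PR, and it assumes torch.distributed has already been initialized.

import copy

import torch.nn as nn

def wrap_model_for_ddp(model, device_id, process_group=None):
    # Replace every BatchNorm*d layer in a copy of the model with SyncBatchNorm
    # (using the new classmethod location), move it to this process's GPU,
    # then wrap it in DistributedDataParallel.
    model_ddp = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model), process_group)
    model_ddp = model_ddp.cuda(device_id)
    return nn.parallel.DistributedDataParallel(model_ddp, device_ids=[device_id])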
46 changes: 45 additions & 1 deletion torch/nn/modules/batchnorm.py
@@ -356,7 +356,7 @@ class SyncBatchNorm(_BatchNorm):
    or Spatio-temporal Batch Normalization.

    Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use
-    torch.nn.utils.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
+    torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping
    Network with DDP.

    Args:
@@ -458,3 +458,47 @@ def forward(self, input):
        return sync_batch_norm.apply(
            input, self.weight, self.bias, self.running_mean, self.running_var,
            self.eps, exponential_average_factor, process_group, world_size)

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert `torch.nn.BatchNormND` layers in the model to
        `torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): containing module
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original module with the converted `torch.nn.SyncBatchNorm` layers

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> module = torch.nn.Sequential(
            >>>     torch.nn.Linear(20, 100),
            >>>     torch.nn.BatchNorm1d(100)
            >>> ).cuda()
            >>> # creating process group (optional)
            >>> # process_ids is a list of int identifying rank ids.
            >>> process_group = torch.distributed.new_group(process_ids)
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(module.num_features,
                                                   module.eps, module.momentum,
                                                   module.affine,
                                                   module.track_running_stats,
                                                   process_group)
            if module.affine:
                module_output.weight.data = module.weight.data.clone().detach()
                module_output.bias.data = module.bias.data.clone().detach()
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
        for name, child in module.named_children():
            # Recurse into children, forwarding the process group so nested
            # BatchNorm layers join the same synchronization group.
            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
        del module
        return module_output
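As a sanity check on the recursive conversion above, here is a short sketch (not part of the PR) showing that nested BatchNorm layers are replaced and their affine parameters carried over; the conversion itself runs on CPU and needs no process group. The toy model is an assumption for illustration.

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4)),
)
original_weight = model[1].weight.detach().clone()

converted = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Both BatchNorm layers, including the one nested in the inner Sequential,
# are now SyncBatchNorm, and the affine parameters were copied over.
assert isinstance(converted[1], nn.SyncBatchNorm)
assert isinstance(converted[2][1], nn.SyncBatchNorm)
assert torch.equal(converted[1].weight.detach(), original_weight)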
1 change: 0 additions & 1 deletion torch/nn/utils/__init__.py
@@ -3,4 +3,3 @@
from .weight_norm import weight_norm, remove_weight_norm # noqa: F401
from .convert_parameters import parameters_to_vector, vector_to_parameters # noqa: F401
from .spectral_norm import spectral_norm, remove_spectral_norm # noqa: F401
-from .sync_batch_norm import convert_sync_batchnorm # noqa: F401
45 changes: 0 additions & 45 deletions torch/nn/utils/sync_batch_norm.py

This file was deleted.
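Because the free function under torch.nn.utils is removed along with this file, existing call sites have to migrate to the classmethod. A hedged sketch of the before/after (the toy model is illustrative):

import torch.nn as nn

model = nn.Sequential(nn.Linear(20, 100), nn.BatchNorm1d(100))

# Before this change:
#   from torch.nn.utils import convert_sync_batchnorm
#   model = convert_sync_batchnorm(model)
# After this change, the helper lives on the SyncBatchNorm class:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)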