Skip to content

Commit f7a7868

Browse files
zhangliliangfacebook-github-bot
authored and committed
add process_group in convert_sync_batchnorm (#19240)
Summary: `convert_sync_batchnorm` calls itself recursively (line 508) to convert nested BatchNorm layers to SyncBatchNorm, so `process_group` must also be passed along in that recursive call; otherwise child modules silently fall back to the default process group. Pull Request resolved: #19240 Differential Revision: D15240318 Pulled By: ezyang fbshipit-source-id: 0fc9e856392824814991e5e9e8f9513d57f311af
1 parent 2356fac commit f7a7868

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

test/test_distributed.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@
2020
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
2121
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
2222

23+
try:
24+
import torchvision
25+
HAS_TORCHVISION = True
26+
except ImportError:
27+
HAS_TORCHVISION = False
28+
29+
30+
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
31+
2332
BACKEND = os.environ["BACKEND"]
2433
TEMP_DIR = os.environ["TEMP_DIR"]
2534
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
@@ -1528,6 +1537,19 @@ def test_DistributedDataParallel_SyncBatchNorm(self):
15281537
gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
15291538
self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
15301539

1540+
@skipIfNoTorchVision
1541+
def test_SyncBatchNorm_process_group(self):
1542+
# When using `convert_sync_batchnorm` to convert an `nn.Module`,
1543+
# it needs to recursively pass the `process_group` down through the module when the `SyncBatchNorm`
1544+
# is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
1545+
1546+
process_ids = 0
1547+
process_group = torch.distributed.new_group([process_ids])
1548+
res50_model = torchvision.models.resnet50()
1549+
res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(res50_model), process_group)
1550+
process_group_sync = res50_model_sync.layer1[0].bn1.process_group
1551+
self.assertEqual(process_group_sync, process_group)
1552+
15311553
if BACKEND == "gloo" or BACKEND == "nccl":
15321554
WORLD_SIZE = os.environ["WORLD_SIZE"]
15331555

torch/nn/modules/batchnorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,6 @@ def convert_sync_batchnorm(cls, module, process_group=None):
505505
module_output.running_var = module.running_var
506506
module_output.num_batches_tracked = module.num_batches_tracked
507507
for name, child in module.named_children():
508-
module_output.add_module(name, cls.convert_sync_batchnorm(child))
508+
module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
509509
del module
510510
return module_output

0 commit comments

Comments
 (0)