Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions test/quantization/test_workflow_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,21 @@ def forward(self, x):
isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
"Expected BN to be converted to SyncBN")

def test_syncbn_preserves_qconfig(self):
"""
Makes sure that if a BatchNorm is not fused and a qconfig exists,
converting the module to SyncBatchNorm preserves the qconfig.
"""
m = nn.Sequential(
nn.Conv2d(1, 1, 1),
nn.BatchNorm2d(1),
)
m[1].qconfig = torch.quantization.default_qconfig
m = torch.nn.SyncBatchNorm.convert_sync_batchnorm(m)
self.assertTrue(
hasattr(m[1], "qconfig"),
"missing qconfig after SyncBatchNorm conversion")

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@override_qengines
Expand Down
14 changes: 8 additions & 6 deletions torch/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def forward(self, input: Tensor) -> Tensor:
else: # use exponential moving average
exponential_average_factor = self.momentum

r"""
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
Expand Down Expand Up @@ -185,7 +185,7 @@ class BatchNorm1d(_BatchNorm):
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``

Expand Down Expand Up @@ -258,7 +258,7 @@ class BatchNorm2d(_BatchNorm):
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``

Expand Down Expand Up @@ -332,7 +332,7 @@ class BatchNorm3d(_BatchNorm):
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``

Expand Down Expand Up @@ -414,7 +414,7 @@ class SyncBatchNorm(_BatchNorm):
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
process_group: synchronization of stats happen within each process group
Expand Down Expand Up @@ -493,7 +493,7 @@ def forward(self, input: Tensor) -> Tensor:
else: # use exponential moving average
exponential_average_factor = self.momentum

r"""
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
Expand Down Expand Up @@ -576,6 +576,8 @@ def convert_sync_batchnorm(cls, module, process_group=None):
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
if hasattr(module, "qconfig"):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
del module
Expand Down