
Commit f5df685

xwang233 authored and pytorchmergebot committed
Enable channels_last_3d on SyncBatchNorm (#88401)
This PR enables the use of fast channels_last kernels on SyncBatchNorm with the channels_last_3d memory format. With a small benchmark script (see #88021 (comment)), on V100 I got:

master:
```
DDP channels_last=False, run_forward_backward, time: 0.8945400714874268 sec
DDP channels_last=True, run_forward_backward, time: 1.4736433029174805 sec
```

This PR:
```
DDP channels_last=False, run_forward_backward, time: 0.8927242755889893 sec
DDP channels_last=True, run_forward_backward, time: 0.48697471618652344 sec
```

This PR is a follow-up of #46906.

Close #88021

Pull Request resolved: #88401
Approved by: https://github.com/ngimel
1 parent 8023c9d commit f5df685
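
For context, a minimal single-process sketch of the kind of forward/backward timing quoted above. The toy model, input shapes, port, and iteration count are assumptions for illustration only; this is not the benchmark script from #88021.

```python
import time
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def run_forward_backward(channels_last: bool, iters: int = 20) -> float:
    # channels_last_3d is the 5-D (NCDHW) analogue of channels_last; this PR
    # lets SyncBatchNorm keep that layout instead of falling back to NCDHW.
    memory_format = torch.channels_last_3d if channels_last else torch.contiguous_format
    model = torch.nn.Sequential(
        torch.nn.Conv3d(4, 8, kernel_size=3, padding=1),  # toy 3-D model (assumption)
        torch.nn.SyncBatchNorm(8),
    ).cuda().to(memory_format=memory_format)
    model = DDP(model, device_ids=[0])
    x = torch.randn(2, 4, 16, 32, 32, device="cuda").to(memory_format=memory_format)

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(x).sum().backward()
    torch.cuda.synchronize()
    return time.time() - start


if __name__ == "__main__":
    # Single-process process group so SyncBatchNorm/DDP can run on one GPU.
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    for cl in (False, True):
        t = run_forward_backward(cl)
        print(f"DDP channels_last={cl}, run_forward_backward, time: {t} sec")
    dist.destroy_process_group()
```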

File tree

4 files changed, +26 -12 lines


aten/src/ATen/native/cuda/Normalization.cu

Lines changed: 5 additions & 2 deletions
@@ -48,8 +48,11 @@ bool is_mixed_type(const Tensor& input, const Args&... parameters) {
 }
 
 inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
-  return (self.is_contiguous(at::MemoryFormat::ChannelsLast) ||
-          (self.is_contiguous() && self.strides()[1] == 1));
+  return (
+      self.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      self.is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
+      (self.is_contiguous() && self.strides()[1] == 1)
+  );
 }
 
 enum class Impl {
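
For illustration, a small Python sketch of the memory formats the updated kernel-selection check now recognizes; the shapes are arbitrary examples.

```python
import torch

# 4-D NCHW tensor in channels_last layout: already covered by the old check.
x4d = torch.randn(2, 8, 16, 16).to(memory_format=torch.channels_last)
print(x4d.is_contiguous(memory_format=torch.channels_last))     # True

# 5-D NCDHW tensor in channels_last_3d layout: the case this PR adds.
x5d = torch.randn(2, 8, 4, 16, 16).to(memory_format=torch.channels_last_3d)
print(x5d.is_contiguous(memory_format=torch.channels_last_3d))  # True
print(x5d.stride())  # (8192, 1, 2048, 128, 8): the channel stride is 1
```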

test/test_nn.py

Lines changed: 5 additions & 5 deletions
@@ -10283,16 +10283,16 @@ def test_sync_batchnorm_accuracy_cuda(self):
         # fwd: torch.batch_norm_stats, torch.batch_norm_gather_stats_with_counts, torch.batch_norm_elemt
         # bwd: torch.batch_norm_backward_reduce, torch.batch_norm_backward_elemt
 
-        def _batch_norm_stats(data):
+        def _batch_norm_stats(data, memory_format, mean_axes):
             mean1, _ = torch.batch_norm_stats(data, 1e-5)
-            mean2, _ = torch.batch_norm_stats(data.to(memory_format=torch.channels_last), 1e-5)
-            mean_ref = torch.mean(data, (0, 2, 3), keepdim=False)
+            mean2, _ = torch.batch_norm_stats(data.to(memory_format=memory_format), 1e-5)
+            mean_ref = torch.mean(data, mean_axes, keepdim=False)
 
             self.assertEqual(mean_ref, mean1)
             self.assertEqual(mean_ref, mean2)
 
-        data = torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda')
-        _batch_norm_stats(data)
+        _batch_norm_stats(torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last, (0, 2, 3))
+        _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last_3d, (0, 2, 3, 4))
 
     def test_flatten(self):
         tensor_input = torch.randn(2, 1, 2, 3)

torch/nn/modules/_functions.py

Lines changed: 8 additions & 2 deletions
@@ -7,7 +7,10 @@ class SyncBatchNorm(Function):
 
     @staticmethod
     def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
-        if not input.is_contiguous(memory_format=torch.channels_last):
+        if not (
+            input.is_contiguous(memory_format=torch.channels_last) or
+            input.is_contiguous(memory_format=torch.channels_last_3d)
+        ):
             input = input.contiguous()
         if weight is not None:
             weight = weight.contiguous()
@@ -104,7 +107,10 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum,
 
     @staticmethod
    def backward(self, grad_output):
-        if not grad_output.is_contiguous(memory_format=torch.channels_last):
+        if not (
+            grad_output.is_contiguous(memory_format=torch.channels_last) or
+            grad_output.is_contiguous(memory_format=torch.channels_last_3d)
+        ):
             grad_output = grad_output.contiguous()
         saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
         grad_input = grad_weight = grad_bias = None
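
For illustration, a short standalone sketch of why widening these checks matters: under the old condition a channels_last_3d tensor fails the channels_last test and is rewritten to the default NCDHW layout, dropping it off the fast channels_last kernel path, while the new condition leaves it untouched. The shape is an arbitrary example.

```python
import torch

x = torch.randn(2, 8, 4, 16, 16).to(memory_format=torch.channels_last_3d)

# Old check: only channels_last is recognized, so a channels_last_3d input
# is converted back to the default contiguous (NCDHW) layout.
y_old = x
if not y_old.is_contiguous(memory_format=torch.channels_last):
    y_old = y_old.contiguous()
print(y_old.is_contiguous(memory_format=torch.channels_last_3d))  # False

# New check: channels_last_3d inputs keep their layout.
y_new = x
if not (y_new.is_contiguous(memory_format=torch.channels_last) or
        y_new.is_contiguous(memory_format=torch.channels_last_3d)):
    y_new = y_new.contiguous()
print(y_new.is_contiguous(memory_format=torch.channels_last_3d))  # True
```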

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 8 additions & 3 deletions
@@ -5324,6 +5324,10 @@ def test_post_localSGD_optimizer_step_reload(self):
     )
     @skip_if_no_gpu
     def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
+        self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(torch.channels_last)
+        self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(torch.channels_last_3d)
+
+    def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(self, memory_format):
         group, group_id, rank = self._init_global_test()
         num_processes = dist.get_world_size()
         local_bs = 2
@@ -5336,14 +5340,15 @@ def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
             model_gpu, device_ids=[rank]
         )
 
-        memory_format = torch.channels_last
+        shapes = [global_bs, 2, 4, 4] + ([] if memory_format is torch.channels_last else [4])
+
         input_gpu = (
-            torch.randn(global_bs, 2, 4, 4, dtype=torch.float)
+            torch.randn(*shapes, dtype=torch.float)
             .cuda(rank)
             .to(memory_format=memory_format)
         )
         target_gpu = (
-            torch.randn(global_bs, 2, 4, 4, dtype=torch.float)
+            torch.randn(*shapes, dtype=torch.float)
             .cuda(rank)
             .to(memory_format=memory_format)
         )
