
Commit f5fdfa2

enable channels_last_3d on SyncBatchNorm
1 parent a8561c4 commit f5fdfa2

File tree

4 files changed: +60 −9 lines changed


aten/src/ATen/native/cuda/Normalization.cu

Lines changed: 5 additions & 2 deletions

@@ -48,8 +48,11 @@ bool is_mixed_type(const Tensor& input, const Args&... parameters) {
 }
 
 inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
-  return (self.is_contiguous(at::MemoryFormat::ChannelsLast) ||
-          (self.is_contiguous() && self.strides()[1] == 1));
+  return (
+      self.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      self.is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
+      (self.is_contiguous() && self.strides()[1] == 1)
+  );
 }
 
 enum class Impl {
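The dispatch predicate above now routes NDHWC (channels_last_3d) tensors to the channels-last batch-norm kernels, in addition to NHWC and the degenerate-contiguous case. A minimal sketch of what the new branch matches, using only public PyTorch APIs (the shape is arbitrary, chosen for illustration):

import torch

# A 5-D (N, C, D, H, W) tensor converted to channels_last_3d stores the
# channel dimension with stride 1, which is the layout the channels-last
# CUDA kernels expect.
x = torch.randn(2, 8, 4, 4, 4).to(memory_format=torch.channels_last_3d)
print(x.is_contiguous(memory_format=torch.channels_last_3d))  # True
print(x.stride())                                             # (512, 1, 128, 32, 8)

# A plain contiguous NCDHW tensor of the same shape does not match.
y = torch.randn(2, 8, 4, 4, 4)
print(y.is_contiguous(memory_format=torch.channels_last_3d))  # False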

test/test_nn.py

Lines changed: 5 additions & 5 deletions

@@ -10269,16 +10269,16 @@ def test_sync_batchnorm_accuracy_cuda(self):
         # fwd: torch.batch_norm_stats, torch.batch_norm_gather_stats_with_counts, torch.batch_norm_elemt
         # bwd: torch.batch_norm_backward_reduce, torch.batch_norm_backward_elemt
 
-        def _batch_norm_stats(data):
+        def _batch_norm_stats(data, memory_format, mean_axes):
             mean1, _ = torch.batch_norm_stats(data, 1e-5)
-            mean2, _ = torch.batch_norm_stats(data.to(memory_format=torch.channels_last), 1e-5)
-            mean_ref = torch.mean(data, (0, 2, 3), keepdim=False)
+            mean2, _ = torch.batch_norm_stats(data.to(memory_format=memory_format), 1e-5)
+            mean_ref = torch.mean(data, mean_axes, keepdim=False)
 
             self.assertEqual(mean_ref, mean1)
             self.assertEqual(mean_ref, mean2)
 
-        data = torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda')
-        _batch_norm_stats(data)
+        _batch_norm_stats(torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last, (0, 2, 3))
+        _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last_3d, (0, 2, 3, 4))
 
     def test_flatten(self):
         tensor_input = torch.randn(2, 1, 2, 3)
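The updated test folds the 2-D and new 3-D cases into one helper: per-channel statistics must agree regardless of memory format. Outside the test harness, roughly the same check can be reproduced directly; a small sketch, assuming a CUDA device is available (the shape here is arbitrary):

import torch

data = torch.randn(1, 16, 8, 8, 8, dtype=torch.float, device='cuda')  # NCDHW
mean1, invstd1 = torch.batch_norm_stats(data, 1e-5)
mean2, invstd2 = torch.batch_norm_stats(data.to(memory_format=torch.channels_last_3d), 1e-5)

# Per-channel means from both layouts should match a plain reduction over N, D, H, W.
mean_ref = torch.mean(data, (0, 2, 3, 4), keepdim=False)
print(torch.allclose(mean_ref, mean1, atol=1e-5), torch.allclose(mean_ref, mean2, atol=1e-5))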

torch/nn/modules/_functions.py

Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@ class SyncBatchNorm(Function):
 
     @staticmethod
     def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
-        if not input.is_contiguous(memory_format=torch.channels_last):
+        if not (input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous(memory_format=torch.channels_last_3d)):
             input = input.contiguous()
         if weight is not None:
             weight = weight.contiguous()
@@ -104,7 +104,7 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum,
 
     @staticmethod
     def backward(self, grad_output):
-        if not grad_output.is_contiguous(memory_format=torch.channels_last):
+        if not (grad_output.is_contiguous(memory_format=torch.channels_last) or grad_output.is_contiguous(memory_format=torch.channels_last_3d)):
             grad_output = grad_output.contiguous()
         saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
         grad_input = grad_weight = grad_bias = None
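torch.nn.modules._functions.SyncBatchNorm is the autograd Function behind nn.SyncBatchNorm, so the practical effect of this change is that 5-D channels_last_3d inputs (and gradients) are no longer forced back to a default-contiguous layout on entry to forward and backward. A rough end-to-end sketch of the user-facing path, assuming NCCL and one GPU; with a single-process group (world size 1) the layer falls back to local batch norm, and multiple ranks are needed to exercise the synchronized path:

import os
import torch
import torch.distributed as dist
import torch.nn as nn

# Minimal single-process group so SyncBatchNorm has a process group to query.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

# Convert the BatchNorm3d layer of a small 3-D conv model to SyncBatchNorm.
model = nn.Sequential(nn.Conv3d(3, 8, kernel_size=3, padding=1), nn.BatchNorm3d(8))
model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()

# NDHWC input; with this change the sync batch-norm path can keep this layout
# instead of calling .contiguous() on it.
x = torch.randn(2, 3, 8, 8, 8, device="cuda").to(memory_format=torch.channels_last_3d)
out = model(x)
out.sum().backward()

dist.destroy_process_group()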

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 48 additions & 0 deletions

@@ -5319,6 +5319,54 @@ def test_post_localSGD_optimizer_step_reload(self):
             tmp_file
         )
 
+    @sandcastle_skip_if(
+        BACKEND not in DistTestCases.backend_feature["ddp"],
+        f"The {BACKEND} backend does not support DistributedDataParallel"
+    )
+    @skip_if_no_gpu
+    def test_DistributedDataParallel_SyncBatchNorm_Channels_Last_3D(self):
+        group, group_id, rank = self._init_global_test()
+        num_processes = dist.get_world_size()
+        local_bs = 2
+        bs_offset = int(rank * 2)
+        global_bs = int(num_processes * 2)
+
+        model = ONLY_SBN_NET
+        model_gpu = copy.deepcopy(model).cuda(rank)
+        model_DDP = nn.parallel.DistributedDataParallel(
+            model_gpu, device_ids=[rank]
+        )
+
+        memory_format = torch.channels_last_3d
+        input_gpu = (
+            torch.randn(global_bs, 2, 4, 4, 4, dtype=torch.float)
+            .cuda(rank)
+            .to(memory_format=memory_format)
+        )
+        target_gpu = (
+            torch.randn(global_bs, 2, 4, 4, 4, dtype=torch.float)
+            .cuda(rank)
+            .to(memory_format=memory_format)
+        )
+        loss = nn.MSELoss()
+
+        # check two model parameters over 5 iterations
+        self._test_DDP_niter(
+            model_gpu,
+            model_DDP,
+            input_gpu,
+            target_gpu,
+            loss,
+            local_bs,
+            rank,
+            global_bs,
+            True,
+            bs_offset,
+            dist.get_world_size(),
+            memory_format=memory_format,
+        )
+        self._barrier()
+
     @sandcastle_skip_if(
         BACKEND not in DistTestCases.backend_feature["ddp"],
         f"The {BACKEND} backend does not support DistributedDataParallel"
