
Commit 48ad454

ssnl authored and soumith committed
Move LayerNorm to ATen; remove tracking_running_stats functionality (#5983)
* Move LayerNorm to ATen; remove tracking_stats functionality
* Address comments about error message and respect cudnn flag for LayerNorm and GroupNorm
1 parent bc1b4c8 commit 48ad454

File tree: 5 files changed, +106 −142 lines

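To make the scope of the change concrete, here is a rough sketch of how `nn.LayerNorm` is constructed after this commit. The positional arguments follow the updated test cases further down (`normalized_shape`, `eps`, `elementwise_affine`), and the running-stats arguments are gone; treat this as illustrative, not as the module's documented signature.

```python
import torch
import torch.nn as nn

# Positional args per the updated tests: (normalized_shape, eps, elementwise_affine).
# momentum / track_running_stats no longer exist.
ln = nn.LayerNorm([5], 1e-3)               # with learnable elementwise weight/bias
ln_plain = nn.LayerNorm([5], 1e-3, False)  # no elementwise affine parameters

x = torch.randn(4, 5, 5)
y = ln(x)  # normalizes over the trailing dimension of size 5
```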

aten/src/ATen/native/Normalization.cpp

Lines changed: 77 additions & 8 deletions
@@ -70,9 +70,77 @@ Tensor batch_norm(
       running_mean, running_var, training, momentum, eps);
 }

+Tensor layer_norm(const Tensor& input, IntList normalized_shape,
+    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
+    double eps, bool cudnn_enabled) {
+
+  int64_t normalized_ndim = normalized_shape.size();
+
+  if (normalized_ndim < 1) {
+    std::stringstream ss;
+    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
+       << "containing at least one element, but got normalized_shape="
+       << normalized_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  if (weight.defined() && !weight.sizes().equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Expected weight to be of same shape as normalized_shape, but got "
+       << "weight of shape " << weight.sizes() << " and normalized_shape="
+       << normalized_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  if (bias.defined() && !bias.sizes().equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Expected bias to be of same shape as normalized_shape, but got "
+       << "bias of shape " << bias.sizes() << " and normalized_shape="
+       << normalized_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  auto input_shape = input.sizes();
+  auto input_ndim = input.dim();
+
+  if (input_ndim < normalized_ndim ||
+      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
+    std::stringstream ss;
+    ss << "Given normalized_shape=" << normalized_shape
+       << ", expected input with shape [*";
+    for (auto size : normalized_shape) {
+      ss << ", " << size;
+    }
+    ss << "], but got input of size" << input_shape;
+    throw std::runtime_error(ss.str());
+  }
+
+  int64_t n = 1;
+  for (int64_t i = 0; i < input_ndim - normalized_ndim; i++) {
+    n *= input_shape[i];
+  }
+
+  // Apply layer norm
+  auto input_reshaped = input.contiguous().view({1, n, -1});
+
+  auto out = at::batch_norm(input_reshaped, {}, {}, {}, {}, true, 0, eps,
+                            cudnn_enabled);
+  out = out.view(input_shape);
+
+  if (weight.defined() && bias.defined()) {
+    return bias.addcmul(out, weight, 1);
+  } else if (weight.defined()) {
+    return out.mul(weight);
+  } else if (bias.defined()) {
+    return out.add(bias);
+  } else {
+    return out;
+  }
+}
+
 Tensor group_norm(const Tensor& input, int64_t num_groups,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */,
-    double eps) {
+    double eps, bool cudnn_enabled) {

   auto input_shape = input.sizes();
   int64_t b = input.size(0);
@@ -81,31 +149,32 @@ Tensor group_norm(const Tensor& input, int64_t num_groups,
   if (c % num_groups != 0) {
     std::stringstream ss;
     ss << "Expected number of channels in input to be divisible by "
-       << "num_groups, but got " << input.sizes() << " input and num_groups="
-       << num_groups;
+       << "num_groups, but got input of shape " << input.sizes() << " and "
+       << "num_groups=" << num_groups;
     throw std::runtime_error(ss.str());
   }

   if (weight.defined() && (weight.dim() != 1 || weight.numel() != c)) {
     std::stringstream ss;
     ss << "Expected weight to be a vector of size equal to the number of "
-       << "channels in input, but got " << weight.sizes() << " weight and "
-       << input.sizes() << " input";
+       << "channels in input, but got weight of shape " << weight.sizes()
+       << " and input of shape " << input.sizes();
     throw std::runtime_error(ss.str());
   }

   if (bias.defined() && (bias.dim() != 1 || bias.numel() != c)) {
     std::stringstream ss;
     ss << "Expected bias to be a vector of size equal to the number of "
-       << "channels in input, but got " << bias.sizes() << " bias and "
-       << input.sizes() << " input";
+       << "channels in input, but got bias of shape " << weight.sizes()
+       << " and input of shape " << input.sizes();
     throw std::runtime_error(ss.str());
   }

   // Apply group norm
   auto input_reshaped = input.contiguous().view({1, b * num_groups, -1});

-  auto out = at::batch_norm(input_reshaped, {}, {}, {}, {}, true, 0, eps, true);
+  auto out = at::batch_norm(input_reshaped, {}, {}, {}, {}, true, 0, eps,
+                            cudnn_enabled);
   out = out.view(input_shape);

   if (!weight.defined() && !bias.defined()) {
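The new ATen kernel implements layer norm by reusing batch norm: the input is flattened to shape {1, n, -1}, so every leading index becomes a batch-norm "channel" that is normalized over its trailing `normalized_shape` elements (training mode, no running stats), and the elementwise affine transform is applied afterwards. The sketch below is not the ATen code, just a hypothetical Python re-derivation of that equivalence; `layer_norm_via_batch_norm` is an illustrative name.

```python
import torch
import torch.nn.functional as F

def layer_norm_via_batch_norm(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Illustrative only: mimic the reshape-to-batch_norm trick used above."""
    normalized_ndim = len(normalized_shape)
    n = 1
    for s in x.shape[:x.dim() - normalized_ndim]:
        n *= s
    # Each of the n leading indices becomes one "channel" of a 3-d tensor.
    reshaped = x.contiguous().view(1, n, -1)
    out = F.batch_norm(reshaped, None, None, training=True, momentum=0.0, eps=eps)
    out = out.view(x.shape)
    if weight is not None:
        out = out * weight   # weight/bias broadcast over the leading dims
    if bias is not None:
        out = out + bias
    return out

x = torch.randn(4, 2, 2, 5)
diff = (layer_norm_via_batch_norm(x, [2, 2, 5]) - F.layer_norm(x, [2, 2, 5])).abs().max()
print(diff)  # expected to be tiny (~1e-7): the two formulations agree numerically
```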

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 1 deletion
@@ -353,7 +353,7 @@
 - func: ger_out(Tensor result, Tensor self, Tensor vec2) -> Tensor
   variants: function

-- func: group_norm(Tensor input, int64_t num_groups, Tensor? weight={}, Tensor? bias={}, double eps=1e-5) -> Tensor
+- func: group_norm(Tensor input, int64_t num_groups, Tensor? weight={}, Tensor? bias={}, double eps=1e-5, bool cudnn_enabled=True) -> Tensor
   variants: function

 # FFT
@@ -393,6 +393,9 @@

 - func: is_sparse(Tensor self) -> bool

+- func: layer_norm(Tensor input, IntList normalized_shape, Tensor? weight={}, Tensor? bias={}, double eps=1e-5, bool cudnn_enable=True) -> Tensor
+  variants: function
+
 - func: linspace(Type dtype, Scalar start, Scalar end, int64_t steps=100) -> Tensor
   variants: function
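This declaration is what exposes the ATen kernel as `torch.layer_norm` with the positional signature above; the functional wrapper later in this commit calls it exactly this way. A minimal, illustrative call:

```python
import torch
import torch.backends.cudnn

x = torch.randn(4, 5, 5)
weight = torch.ones(5)
bias = torch.zeros(5)

# layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enable)
y = torch.layer_norm(x, [5], weight, bias, 1e-5, torch.backends.cudnn.enabled)
```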

test/test_nn.py

Lines changed: 15 additions & 38 deletions
@@ -1759,24 +1759,17 @@ def _test_LayerNorm_general(self, type):
         self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5)
         self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5)

-        # test that LN with track_running_stats=True
-        ln = nn.LayerNorm(normalized_shape, momentum=1, eps=0,
-                          elementwise_affine=False, track_running_stats=True).type(type)
-        output_ref = ln(x).data.clone()
-        input_reshaped = x.view(*(unnormalized_shape + [-1]))
-        # make sure that running mean and var update correctly when training
-        mean = input_reshaped.mean(-1).mean()
-        var = input_reshaped.var(-1, unbiased=True).mean()
-        self.assertAlmostEqual(torch.abs(mean.data - ln.running_mean).mean(), 0, delta=1e-5)
-        self.assertAlmostEqual(torch.abs(var.data - ln.running_var).mean(), 0, delta=1e-5)
-        ln.eval()
-        old_running_mean = ln.running_mean.clone()
-        old_running_var = ln.running_var.clone()
-        output_new = ln(x + ln.running_var.sqrt()[0] * scale).data
-        self.assertAlmostEqual((output_new - output_ref).mean(), scale, delta=1e-5)
-        # make sure that running mean and var don't change in eval
-        self.assertEqual(old_running_mean, ln.running_mean)
-        self.assertEqual(old_running_var, ln.running_var)
+        bad_norm_shape_input_shape = {
+            (): (),
+            (2, 3): (3,),
+            (2,): (1, 2, 3),
+            (10,): (2, 3),
+            10: (2, 3),
+        }
+        for norm_shape, input_shape in bad_norm_shape_input_shape.items():
+            ln = nn.LayerNorm(norm_shape)
+            input = type(*input_shape).uniform_(0, 10)
+            self.assertRaises(RuntimeError, lambda: ln(input))

     def _test_LayerNorm_cuda_half(self):
         input = torch.zeros(2, 3, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
@@ -5963,52 +5956,36 @@ def multimarginloss_weights_no_reduce_test():
     ),
     dict(
         module_name='LayerNorm',
-        constructor_args=([5], 1e-3, 0.3),
+        constructor_args=([5], 1e-3),
         input_size=(4, 5, 5),
         cudnn=True,
         check_eval=True,
         desc='1d_elementwise_affine',
     ),
     dict(
         module_name='LayerNorm',
-        constructor_args=([5], 1e-3, 0.3, False),
+        constructor_args=([5], 1e-3, False),
         input_size=(4, 5, 5),
         cudnn=True,
         check_eval=True,
         desc='1d_no_elementwise_affine',
     ),
     dict(
         module_name='LayerNorm',
-        constructor_args=([5], 1e-3, 0.3, True, True),
-        input_size=(4, 5, 5),
-        cudnn=True,
-        check_eval=True,
-        desc='1d_elementwise_affine_tracking_stats',
-    ),
-    dict(
-        module_name='LayerNorm',
-        constructor_args=([2, 2, 5], 1e-3, 0.3),
+        constructor_args=([2, 2, 5], 1e-3),
         input_size=(4, 2, 2, 5),
         cudnn=True,
         check_eval=True,
         desc='3d_elementwise_affine',
     ),
     dict(
         module_name='LayerNorm',
-        constructor_args=([2, 2, 5], 1e-3, 0.3, False),
+        constructor_args=([2, 2, 5], 1e-3, False),
         input_size=(4, 2, 2, 5),
         cudnn=True,
         check_eval=True,
         desc='3d_no_elementwise_affine',
     ),
-    dict(
-        module_name='LayerNorm',
-        constructor_args=([2, 2, 5], 1e-3, 0.3, True, True),
-        input_size=(4, 2, 2, 5),
-        cudnn=True,
-        check_eval=True,
-        desc='3d_elementwise_affine_tracking_stats',
-    ),
     dict(
         module_name='GroupNorm',
         constructor_args=(3, 6, 1e-3),
torch/nn/functional.py

Lines changed: 5 additions & 57 deletions
@@ -1257,74 +1257,22 @@ def _instance_norm(input, running_mean=None, running_var=None, weight=None,
         eps=eps)


-def layer_norm(input, normalized_shape, running_mean=None, running_var=None,
-               weight=None, bias=None, use_input_stats=True,
-               momentum=0.1, eps=1e-5):
+def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
     r"""Applies Layer Normalization for last certain number of dimensions.

     See :class:`~torch.nn.LayerNorm` for details.
     """
-    if not use_input_stats and (running_mean is None or running_var is None):
-        raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False')
-
-    if weight is not None and weight.size() != normalized_shape:
-        raise ValueError('Expected weight to be of same shape as '
-                         'normalized_shape, but got {} weight and '
-                         'normalized_shape={}'.format(weight.size(), normalized_shape))
-
-    if bias is not None and bias.size() != normalized_shape:
-        raise ValueError('Expected bias to be of same shape as '
-                         'normalized_shape, but got {} bias and '
-                         'normalized_shape={}'.format(bias.size(), normalized_shape))
-
-    normalized_ndim = len(normalized_shape)
-    input_shape = input.size()
-
-    if input_shape[-normalized_ndim:] != torch.Size(normalized_shape):
-        raise ValueError('Expected input with shape [*, {}], but got {} input'
-                         .format(', '.join(normalized_shape), list(input_shape)))
-
-    n = reduce(mul, input_shape[:-normalized_ndim], 1)
-
-    # Repeat stored stats if necessary
-    if running_mean is not None:
-        running_mean_orig = running_mean
-        running_mean = running_mean_orig.repeat(n)
-    if running_var is not None:
-        running_var_orig = running_var
-        running_var = running_var_orig.repeat(n)
-
-    # Apply layer norm
-    input_reshaped = input.contiguous().view(1, n, -1)
-
-    out = batch_norm(
-        input_reshaped, running_mean, running_var, None, None,
-        use_input_stats, momentum, eps)
-
-    # Copy back
-    if running_mean is not None:
-        running_mean_orig.fill_(running_mean.mean())
-    if running_var is not None:
-        running_var_orig.fill_(running_var.mean())
-
-    out = out.view(*input_shape)
-
-    if weight is not None and bias is not None:
-        return torch.addcmul(bias, 1, out, weight)
-    elif weight is not None:
-        return torch.mul(out, weight)
-    elif bias is not None:
-        return torch.add(out, bias)
-    else:
-        return out
+    return torch.layer_norm(input, normalized_shape, weight, bias, eps,
+                            torch.backends.cudnn.enabled)


 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
     r"""Applies Group Normalization for last certain number of dimensions.

     See :class:`~torch.nn.GroupNorm` for details.
     """
-    return torch.group_norm(input, num_groups, weight, bias, eps)
+    return torch.group_norm(input, num_groups, weight, bias, eps,
+                            torch.backends.cudnn.enabled)


 def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1):
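After this change the Python-side implementation is gone and `F.layer_norm` is a thin call into the ATen kernel. A brief usage sketch of the simplified functional signature (illustrative values):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 2, 2, 5)

# New signature: layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5)
y = F.layer_norm(x, [2, 2, 5])

# Every [2, 2, 5] slice is normalized to roughly zero mean and unit variance.
print(y.view(4, -1).mean(dim=1))                 # ~0
print(y.view(4, -1).var(dim=1, unbiased=False))  # ~1
```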
