
Commit ecbb12a

Author: mikeiovine
[SR] Avoid allocating rstd/mean in layer_norm
Pull Request resolved: #73606

The single-output overload of `layer_norm` internally allocates two tensors (the per-row mean and rstd). As an optimization, we previously added `static_runtime::layer_norm`, a variant of layer norm with two extra outputs that made the memory planner aware of these extra tensors. But those outputs were unused; it is actually better to avoid the allocations and the associated computation entirely.

ghstack-source-id: 151281719
Differential Revision: [D34562131](https://our.internmc.facebook.com/intern/diff/D34562131/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D34562131/)!
1 parent 0b1f3bd commit ecbb12a
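To illustrate the idea, here is a minimal standalone sketch (a hypothetical helper `layer_norm_rows`, not the actual ATen kernel, with weight/bias omitted for brevity) of the null-pointer pattern the kernel now follows: per-row mean and rstd are still computed because the normalization needs them, but they are written out only when the caller passes real buffers, so the single-output path never has to allocate the two {M}-sized stat tensors.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical sketch: stats are stored only if the caller asked for them.
void layer_norm_rows(
    const float* X, float* Y, int64_t M, int64_t N, float eps,
    float* mean_out /* may be nullptr */, float* rstd_out /* may be nullptr */) {
  for (int64_t i = 0; i < M; ++i) {
    const float* x = X + i * N;
    float* y = Y + i * N;
    // Per-row statistics are always needed to normalize the row.
    float mean = 0.0f;
    for (int64_t j = 0; j < N; ++j) mean += x[j];
    mean /= static_cast<float>(N);
    float var = 0.0f;
    for (int64_t j = 0; j < N; ++j) var += (x[j] - mean) * (x[j] - mean);
    const float rstd = 1.0f / std::sqrt(var / static_cast<float>(N) + eps);
    for (int64_t j = 0; j < N; ++j) y[j] = (x[j] - mean) * rstd;
    // Stores are skipped in the single-output path, so no {M}-sized
    // mean/rstd buffers ever need to exist.
    if (mean_out != nullptr) mean_out[i] = mean;
    if (rstd_out != nullptr) rstd_out[i] = rstd;
  }
}
```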

8 files changed: +36 / -82 lines changed


aten/src/ATen/native/cpu/layer_norm_kernel.cpp

Lines changed: 11 additions & 4 deletions

@@ -42,10 +42,13 @@ void LayerNormKernelImplInternal(
   const T* gamma_data = gamma.defined() ? gamma.data_ptr<T>() : nullptr;
   const T* beta_data = beta.defined() ? beta.data_ptr<T>() : nullptr;
   T* Y_data = Y->data_ptr<T>();
-  T* mean_data = mean->data_ptr<T>();
-  T* rstd_data = rstd->data_ptr<T>();
+  T* mean_data = mean ? mean->data_ptr<T>() : nullptr;
+  T* rstd_data = rstd ? rstd->data_ptr<T>() : nullptr;
+
   const bool gamma_null = gamma_data == nullptr;
   const bool beta_null = beta_data == nullptr;
+  const bool mean_null = mean_data == nullptr;
+  const bool rstd_null = rstd_data == nullptr;
   at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
     for (const auto i : c10::irange(start, end)) {
       const T* X_ptr = X_data + i * N;
@@ -73,8 +76,12 @@ void LayerNormKernelImplInternal(
             beta_data,
             N);
       }
-      mean_data[i] = mean_val;
-      rstd_data[i] = rstd_val;
+      if (!mean_null) {
+        mean_data[i] = mean_val;
+      }
+      if (!rstd_null) {
+        rstd_data[i] = rstd_val;
+      }
     }
   });
 }
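The null mean/rstd pointers reach this kernel from the new single-output overload of `layer_norm_cpu_out` (see the next file), whose dispatch in this commit is simply:

```cpp
// From aten/src/ATen/native/layer_norm.cpp in this commit: the single-output
// overload forwards null stat pointers, so the kernel skips the stores above.
LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, /*mean=*/nullptr, /*rstd=*/nullptr);
```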

aten/src/ATen/native/layer_norm.cpp

Lines changed: 16 additions & 2 deletions

@@ -18,7 +18,7 @@
 namespace at {
 namespace native {
 
-void layer_norm_cpu_out(
+void layer_norm_with_mean_rstd_out(
     at::Tensor& out,
     at::Tensor& mean,
     at::Tensor& rstd,
@@ -50,6 +50,20 @@ void layer_norm_cpu_out(
   rstd = rstd.view(stat_shape);
 }
 
+void layer_norm_cpu_out(
+    at::Tensor& out,
+    const at::Tensor& input,
+    const Tensor& gamma,
+    const Tensor& beta,
+    double eps,
+    int64_t M,
+    int64_t N) {
+  if (M <= 0) {
+    return;
+  }
+  LayerNormKernel(kCPU, input, gamma, beta, M, N, eps, &out, /*mean=*/nullptr, /*rstd=*/nullptr);
+}
+
 std::tuple<Tensor, Tensor, Tensor> layer_norm_cpu(
     const Tensor& input,
     IntArrayRef normalized_shape, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
@@ -78,7 +92,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_cpu(
   Tensor mean = at::empty({M}, X->options());
   Tensor rstd = at::empty({M}, X->options());
 
-  layer_norm_cpu_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N);
+  layer_norm_with_mean_rstd_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N);
   return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd));
 }
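For context, the caller of this new single-output overload is the Static Runtime layer_norm operator (see torch/csrc/jit/runtime/static/ops.cpp below); after this change its body reduces to roughly the following excerpt, where `p_node`, `input`, `gamma`, `beta`, `eps`, `M`, and `N` are set up earlier in that operator:

```cpp
// Excerpt of the Static Runtime call site after this change: only the primary
// output tensor is needed; no {M}-sized mean/rstd tensors are created.
at::Tensor& output = p_node->Output(0).toTensor();
at::native::layer_norm_cpu_out(output, input, *gamma, *beta, eps, M, N);
```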

aten/src/ATen/native/layer_norm.h

Lines changed: 0 additions & 3 deletions

@@ -65,10 +65,7 @@ C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
 
 void layer_norm_cpu_out(
     at::Tensor& out,
-    at::Tensor& mean,
-    at::Tensor& rstd,
     const at::Tensor& input,
-    IntArrayRef normalized_shape,
     const Tensor& gamma,
     const Tensor& beta,
     double eps,

benchmarks/static_runtime/test_static_runtime.cc

Lines changed: 0 additions & 7 deletions

@@ -406,13 +406,6 @@ TEST(StaticRuntime, LayerNorm) {
        return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone()
   )JIT";
 
-#ifdef FBCODE_CAFFE2
-  script::Module module("module");
-  module.define(layer_norm_with_weights);
-  torch::jit::StaticModule smodule(module);
-  ASSERT_EQ(getNodeWithKind(smodule, "aten::layer_norm"), nullptr);
-  ASSERT_NE(getNodeWithKind(smodule, "static_runtime::layer_norm"), nullptr);
-#endif
   const auto a = torch::rand({1, 2, 2, 2});
   const auto b = torch::rand({3, 2, 2, 2});
   for (int normalized_size : {2, 3}) {

torch/csrc/jit/runtime/static/impl.cpp

Lines changed: 0 additions & 1 deletion

@@ -164,7 +164,6 @@ void OptimizeGraph(
     ReplaceWithMaybeCopy(graph);
   }
   FuseListUnpack(graph);
-  EnableStaticRuntimeLayerNorm(graph);
 #endif
 }

torch/csrc/jit/runtime/static/ops.cpp

Lines changed: 9 additions & 31 deletions

@@ -1901,22 +1901,22 @@ REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator
   };
 });
 
-static c10::MaybeOwned<at::Tensor> borrow_from_optional_tensor_ivalue(
+namespace {
+
+c10::MaybeOwned<at::Tensor> borrow_from_optional_tensor_ivalue(
     const IValue& iv) {
   if (iv.isNone()) {
     return c10::MaybeOwned<at::Tensor>::owned(c10::in_place);
   }
   return c10::MaybeOwned<at::Tensor>::borrowed(iv.toTensor());
 }
+
+} // namespace
+
 REGISTER_OPERATOR_FUNCTOR(
-    static_runtime::layer_norm,
+    aten::layer_norm,
     aten_layer_norm,
-    [](Node* n) -> SROperator {
-      if (!n->matches(torch::schema(
-              "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor,Tensor,Tensor)"))) {
-        LogAndDumpSchema(n);
-        return nullptr;
-      }
+    [](Node*) -> SROperator {
       return [](ProcessedNode* p_node) {
         // ignore Input(5): `bool cudnn_enable=True`
         const auto& input = p_node->Input(0).toTensor();
@@ -1950,30 +1950,8 @@ REGISTER_OPERATOR_FUNCTOR(
          at::native::resize_(
              p_node->Output(0).toTensor(), X->sizes(), c10::nullopt);
        }
-        if (p_node->Output(1).isNone()) {
-          p_node->Output(1) = create_empty_from({M}, *X);
-        } else {
-          at::native::resize_(p_node->Output(1).toTensor(), {M}, c10::nullopt);
-        }
-        if (p_node->Output(2).isNone()) {
-          p_node->Output(2) = create_empty_from({M}, *X);
-        } else {
-          at::native::resize_(p_node->Output(2).toTensor(), {M}, c10::nullopt);
-        }
        at::Tensor& output = p_node->Output(0).toTensor();
-        at::Tensor& mean = p_node->Output(1).toTensor();
-        at::Tensor& rstd = p_node->Output(2).toTensor();
-        at::native::layer_norm_cpu_out(
-            output,
-            mean,
-            rstd,
-            input,
-            normalized_shape,
-            *gamma,
-            *beta,
-            eps,
-            M,
-            N);
+        at::native::layer_norm_cpu_out(output, input, *gamma, *beta, eps, M, N);
       };
     });

torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 0 additions & 31 deletions

@@ -832,37 +832,6 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
   }
 } // namespace jit
 
-void EnableStaticRuntimeLayerNorm(std::shared_ptr<torch::jit::Graph>& graph) {
-  const c10::Symbol static_runtime_layer_norm_symbol =
-      fromQualString("static_runtime::layer_norm");
-  auto nodes = graph->nodes();
-  std::vector<std::pair<Node*, Node*>> replacement;
-  DepthFirstGraphNodeIterator graph_it(graph);
-  for (auto old_node = graph_it.next(); old_node != nullptr;
-       old_node = graph_it.next()) {
-    if (!old_node->matches(torch::schema(
-            "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) {
-      continue;
-    }
-    TORCH_CHECK(old_node->outputs().size() == 1);
-    auto* new_node = graph->create(
-        static_runtime_layer_norm_symbol,
-        /*layer_norm*/ 1 + /*mean*/ 1 + /*rst=*/1);
-    for (auto* input : old_node->inputs()) {
-      new_node->addInput(input);
-    }
-    replacement.emplace_back(old_node, new_node);
-  }
-  for (const auto& p : replacement) {
-    auto* old_node = p.first;
-    auto* new_node = p.second;
-    new_node->insertBefore(old_node);
-    new_node->output(0)->copyMetadata(old_node->output(0));
-    old_node->output(0)->replaceAllUsesWith(new_node->output(0));
-    old_node->destroy();
-  }
-}
-
 void RemoveImmutableInputDictLookups(
     std::shared_ptr<torch::jit::Graph>& graph) {
   auto nodes = graph->nodes();

torch/csrc/jit/runtime/static/passes.h

Lines changed: 0 additions & 3 deletions

@@ -21,9 +21,6 @@ void ReplaceWithMaybeCopy(
     std::shared_ptr<torch::jit::Graph>& graph,
     bool outputs_are_immutable = true);
 
-TORCH_API void EnableStaticRuntimeLayerNorm(
-    std::shared_ptr<torch::jit::Graph>& graph);
-
 TORCH_API void RemoveImmutableInputDictLookups(
     std::shared_ptr<torch::jit::Graph>& graph);
