
Commit 17941f9

t-vi authored and facebook-github-bot committed
JIT: Eliminate SumToSize by using Optional Lists (#18697)
Summary: This PR eliminates unneeded grad_sum_to_size calls and in particular speeds up the LSTM backward by allowing better fusion. It consists of two parts:

- In AutoDiff, record broadcasting sizes only if the broadcast output size differs from the input size; otherwise record None.
- The specialization of Optional arguments (#18407) then allows us to eliminate `_grad_sum_to_size(t, None)` in the peephole optimization step.

Thus, in the LSTM case, no SumToSize remains in the crucial fusion group. The trick here is that we can specialize on runtime information from the forward. I'm testing that different broadcasting situations lead to different graphs. I didn't move all symbolic_script _grad_sum_to_size calls to the new logic, but it might be better to do this incrementally anyway.

Pull Request resolved: #18697
Differential Revision: D15482076
Pulled By: wanchaol
fbshipit-source-id: 7f89367e35b8729910077c95c02bccefc8678afb
1 parent 4704322 commit 17941f9
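
For context (not part of the diff), here is a small eager-mode PyTorch illustration of why `_grad_sum_to_size` shows up in the backward at all, and why it is often a no-op: when an input is broadcast in the forward, autograd must sum the incoming gradient back down to that input's shape; when the shapes already match, the sum-to-size is an identity, which is exactly the case this PR detects at runtime and eliminates.

```python
import torch

# Broadcasting case: `a` has shape (5, 1) and is broadcast against (5, 5).
a = torch.randn(5, 1, requires_grad=True)
b = torch.randn(5, 5, requires_grad=True)
(a + b).sum().backward()
# The gradient flowing into the add has shape (5, 5); autograd sums it
# back to a's shape, which is what _grad_sum_to_size does inside the JIT.
assert a.grad.shape == (5, 1)
assert b.grad.shape == (5, 5)

# Non-broadcasting case: shapes already match, so summing to size is an
# identity. With this PR, None is recorded instead of a size here, so the
# peephole pass can remove _grad_sum_to_size(t, None) and keep the
# backward fusible.
c = torch.randn(5, 5, requires_grad=True)
(c + b).sum().backward()
assert c.grad.shape == (5, 5)
```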

File tree: 10 files changed, +174 −52 lines

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ namespace c10 {
   _(prim, abs) \
   _(prim, rangelist) \
   _(aten, _grad_sum_to_size) \
+  _(aten, _size_if_not_equal) \
   _(aten, _ncf_unsqueeze) \
   _(aten, warn) \
   _(aten, floordiv) \

test/cpp/jit/test_autodiff.h

Lines changed: 4 additions & 2 deletions
@@ -182,7 +182,7 @@ void testDifferentiate() {
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  std::vector<size_t> expected_captured_outputs = {1, 2};
+  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
   std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
@@ -228,7 +228,9 @@ void testDifferentiateWithRequiresGrad() {
   std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
   ASSERT_EQ(grad_spec.f_real_outputs, 2);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
-  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
+  ASSERT_EQ(
+      grad_spec.df_input_captured_outputs,
+      std::vector<size_t>({2, 3, 4, 5, 6}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
   testing::FileCheck()

test/test_jit.py

Lines changed: 9 additions & 0 deletions
@@ -200,6 +200,15 @@ def get_grad_executor(plan_state, diff_graph_idx=None):
     return grad_executors[diff_graph_idx or 0]
 
 
+def all_backward_graphs(script_module, diff_graph_idx=None):
+    # Note: for Python 2 the order seems to be unstable
+    ge_state = script_module.get_debug_state()
+    fwd_plan = get_execution_plan(ge_state)
+    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
+    bwd_plans = list(grad_executor_state.execution_plans.values())
+    return [p.graph.copy() for p in bwd_plans]
+
+
 def backward_graph(script_module, diff_graph_idx=None):
     ge_state = script_module.get_debug_state()
     fwd_plan = get_execution_plan(ge_state)

test/test_jit_fuser.py

Lines changed: 53 additions & 11 deletions
@@ -15,7 +15,7 @@
 from itertools import product, permutations
 
 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
-    backward_graph, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
+    backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
 
 
 class TestFuser(JitTestCase):
@@ -275,7 +275,7 @@ def funcOptMax(a, b):
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
             inp1, inp2 = inputs
             s = self.checkScript(f, (inp1, inp2))
-            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size'})
+            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
 
             c = s(inp1, inp2)
             c.sum().backward()
@@ -350,7 +350,8 @@ def f(x, y):
         self.assertAllFused(ge.graph_for(x, y))
         x.requires_grad_(True)
         y.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                            "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -522,7 +523,8 @@ def fn_test_scalar_arg(x, p):
         self.assertAllFused(scripted.graph_for(x, p))
         x.requires_grad_(True)
         out = scripted(x, p)
-        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                  "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
@@ -535,7 +537,7 @@ def f(x, y):
         b = torch.randn(5, 5, requires_grad=True)
         a = torch.randn(5, 5, requires_grad=True)
         s = self.checkScript(f, (a, b))
-        self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
+        self.assertAllFused(s.graph_for(a, b), except_for={'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'})
 
         c = s(a, b)
         ga, gb = torch.autograd.grad(c.sum(), [a, b])
@@ -578,12 +580,12 @@ def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
 
         s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
         self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
-                            except_for={'aten::size', 'prim::BroadcastSizes'})
+                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
 
         c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
         torch.autograd.grad(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
         graph = backward_graph(s)
-        self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes'})
+        self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -670,8 +672,8 @@ def test_lstm_cuda(self):
         hy, cy = module(*inputs)
         (hy + cy).sum().backward()
         backward = backward_graph(module)
-        FileCheck().check("FusionGroup_0").check_next("FusionGroup_1") \
-            .check_not("FusionGroup_2").run(str(backward))
+        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
+                                                  "aten::_grad_sum_to_size"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -801,7 +803,8 @@ def fn_test_erf(x):
         ge = self.checkTrace(fn_test_erf, (x,))
         self.assertAllFused(ge.graph_for(x))
         x.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes",
+                                                         "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -818,7 +821,8 @@ def fn_test_rand(x, y):
         self.assertAllFused(script_f.graph_for(x, y))
         x.requires_grad_(True)
         out = script_f(x, y)
-        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                  "aten::_size_if_not_equal"))
         # test that broadcasting random produces correct results
         x = torch.ones(4, 4, dtype=torch.float, device='cuda')
         y = torch.ones(4, dtype=torch.float, device='cuda')
@@ -894,6 +898,44 @@ def f(x, y):
         self.assertEqual(result2, expected2)
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
 
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_grad_sum_to_size_elimination(self):
+
+        def my_broadcasted_cell(a, b, c):
+            return (a + b) + c
+
+        s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
+        s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
+
+        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1))
+        forward_graph = module.graph_for(s1, s1, s1)
+        self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
+                                                       "aten::_size_if_not_equal"))
+
+        old_plans = set()
+        for i in range(3):
+            # if we have s2, then the s1 are _grad_sum_to_size'd
+            args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
+            args = [a.detach_().requires_grad_() for a in args]
+            res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
+            grads = torch.autograd.grad(res.sum(), args)
+            for inp, gr in zip(args, grads):
+                self.assertEqual(inp.shape, gr.shape)
+            backward = None
+            # this is a workaround for the backward graphs not being
+            # in order for Python 2
+            for g in all_backward_graphs(module):
+                if str(g) not in old_plans:
+                    assert backward is None
+                    backward = g
+                    old_plans.add(str(backward))
+            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "aten::_grad_sum_to_size"]), i)
+            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "prim::Param"]), 3 - i)
+
+
     @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_windows_cuda(self):

torch/csrc/jit/autodiff.cpp

Lines changed: 20 additions & 8 deletions
@@ -213,11 +213,17 @@ class GradientHelper {
  private:
   Node* node;
 
-  SymbolicVariable gradSumToSizeOf(SymbolicVariable v, Symbol input_name) {
+  SymbolicVariable gradSumToSizeOf(
+      SymbolicVariable v,
+      Symbol input_name,
+      SymbolicVariable fw_output) {
     Value* size;
     {
-      WithInsertPoint insert_guard{node};
-      size = SymbolicVariable(node->namedInput(input_name)).size();
+      // We insert after the current node because we want to use
+      // its output.
+      WithInsertPoint insert_guard{node->next()};
+      size = SymbolicVariable(node->namedInput(input_name))
+                 .size_if_not_equal(fw_output);
     }
     return v.gradSumToSize(size);
   };
@@ -237,9 +243,11 @@ class GradientHelper {
 
     if (node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {gradSumToSizeOf(grads.at(0), attr::self),
+      return {gradSumToSizeOf(grads.at(0), attr::self, outputs.at(0)),
              gradSumToSizeOf(
-                  grads.at(0) * node->namedInput(attr::alpha), attr::other),
+                  grads.at(0) * node->namedInput(attr::alpha),
+                  attr::other,
+                  outputs.at(0)),
              nullptr};
 
     } else if (
@@ -254,9 +262,11 @@ class GradientHelper {
     } else if (
         node->matches(
            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {gradSumToSizeOf(grads.at(0), attr::self),
+      return {gradSumToSizeOf(grads.at(0), attr::self, outputs.at(0)),
              gradSumToSizeOf(
-                  -grads.at(0) * node->namedInput(attr::alpha), attr::other),
+                  -grads.at(0) * node->namedInput(attr::alpha),
+                  attr::other,
+                  outputs.at(0)),
              nullptr};
 
     } else if (
@@ -337,7 +347,9 @@ class GradientHelper {
        node->matches(
            "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
      return {gradSumToSizeOf(
-                  grads.at(0) * node->namedInput(attr::beta), attr::self),
+                  grads.at(0) * node->namedInput(attr::beta),
+                  attr::self,
+                  outputs.at(0)),
              grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
              inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
              nullptr,
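
Read as pseudo-Python, the new AutoDiff recipe for `aten::add` amounts to the sketch below. This is illustrative only: `_size_if_not_equal` and `_grad_sum_to_size` here are hand-written stand-ins for the JIT operators registered in register_prim_ops.cpp, not the real implementations, and `add_forward`/`add_backward` are hypothetical helpers, not PyTorch API.

```python
import torch


def _size_if_not_equal(self_size, other_size):
    # Stand-in for aten::_size_if_not_equal: return None when no broadcast happened.
    return None if list(self_size) == list(other_size) else list(self_size)


def _grad_sum_to_size(grad, size):
    # Stand-in for aten::_grad_sum_to_size: a no-op when size is None,
    # otherwise sum the gradient back down to `size`.
    if size is None:
        return grad
    while grad.dim() > len(size):
        grad = grad.sum(0)
    for dim, s in enumerate(size):
        if s == 1 and grad.size(dim) != 1:
            grad = grad.sum(dim, keepdim=True)
    return grad


def add_forward(self, other, alpha=1):
    out = self + alpha * other
    # Recorded during the forward, right after the node (hence node->next()):
    self_size = _size_if_not_equal(self.size(), out.size())
    other_size = _size_if_not_equal(other.size(), out.size())
    return out, self_size, other_size


def add_backward(grad_out, self_size, other_size, alpha=1):
    # When a recorded size is None, _grad_sum_to_size(t, None) is removed
    # by the peephole pass, so nothing blocks fusion.
    return (_grad_sum_to_size(grad_out, self_size),
            _grad_sum_to_size(grad_out * alpha, other_size))


# Example: a (5, 1) input broadcast against a (5, 5) input.
a, b = torch.randn(5, 1), torch.randn(5, 5)
out, sa, sb = add_forward(a, b)      # sa == [5, 1], sb is None
ga, gb = add_backward(torch.ones_like(out), sa, sb)
assert ga.shape == a.shape and gb.shape == b.shape
```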

torch/csrc/jit/passes/graph_fuser.cpp

Lines changed: 4 additions & 0 deletions
@@ -813,6 +813,10 @@ struct GraphFuser {
     // The output of producer_for_chunk_node could have been used in some
     // aten::size operators, so we need to clean those up as well (we simply
     // broadcast all its tensor inputs).
+    // We need to insert these early in the graph, i.e. immediately after
+    // the producer_for_chunk_node, since the _size_if_not_equal checks
+    // may come before the bchunk.
+    WithInsertPoint guard2(producer_for_chunk_node);
     auto size_calc_uses = producer_for_chunk_node->output()->uses();
     if (!size_calc_uses.empty()) {
       auto tensor_inputs = filter(

torch/csrc/jit/passes/peephole.cpp

Lines changed: 11 additions & 6 deletions
@@ -157,12 +157,17 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
       }
     } else if (
        node->matches(
-            "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
-      auto uses = node->output()->uses();
-      for (Use u : uses) {
-        if (u.user->matches(
-                "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
-          u.user->replaceInput(0, node->inputs().at(0));
+            "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) {
+      if (node->input(1)->mustBeNone()) {
+        node->output()->replaceAllUsesWith(node->input(0));
+      } else {
+        auto uses = node->output()->uses();
+        for (Use u : uses) {
+          if (u.user->matches(
+                  "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
+              u.user->input(1)->type()->isSubtypeOf(ListType::ofInts())) {
+            u.user->replaceInput(0, node->inputs().at(0));
+          }
        }
      }
    } else if (node->kind() == prim::If) {

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 23 additions & 5 deletions
@@ -655,12 +655,30 @@ RegisterOperators reg(
          };
        }),
    Operator(
-        "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)",
+        "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
        [](Stack& stack) {
-          at::Tensor self;
-          Shared<IntList> desired_sizes;
-          pop(stack, self, desired_sizes);
-          push(stack, at::sum_to(std::move(self), desired_sizes->elements()));
+          IValue self, size;
+          pop(stack, self, size);
+          if (size.isNone()) {
+            push(stack, self);
+          } else {
+            push(
+                stack,
+                at::sum_to(self.toTensor(), size.toIntList()->elements()));
+          }
+          return 0;
+        }),
+    Operator(
+        "aten::_size_if_not_equal(int[] self_size, int[] other_size) -> int[]?",
+        [](Stack& stack) {
+          IValue self_size, other_size;
+          pop(stack, self_size, other_size);
+          const auto s = self_size.toIntList()->elements();
+          if (s == other_size.toIntList()->elements()) {
+            push(stack, IValue());
+          } else {
+            push(stack, s);
+          }
          return 0;
        }),
    Operator(
