Skip to content

Commit 66fe3b5

Browse files
Runtian Zhou authored and facebook-github-bot committed
Add peephole optimization for type_as operators. (#9316)
Summary: If the type_as operator takes in two values with the same type, remove that operator. Pull Request resolved: #9316 Reviewed By: zdevito Differential Revision: D8808355 fbshipit-source-id: 2d5710a6380b22f4568fc38a439061b5340c4eb1
1 parent 52abcdd commit 66fe3b5

File tree

5 files changed

+70
-0
lines changed

5 files changed

+70
-0
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
graph(%0 : Double(1)
2+
%1 : Double(1)) {
3+
return (%0);
4+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
graph(%0 : Double(1)
2+
%1 : Double(1)) {
3+
%2 : Double(1) = aten::type_as(%0, %1)
4+
return (%2);
5+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
graph(%0 : Double(1)
2+
%1 : Double(1)) {
3+
return (%0);
4+
}

test/test_jit.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,48 @@ def f(x, y):
256256
self.assertExpectedGraph(trace)
257257
self.assertExportImport(trace, (x, y))
258258

259+
def test_peephole(self):
260+
a = torch.tensor([0.4], requires_grad=True)
261+
b = torch.tensor([0.7], requires_grad=True)
262+
c = torch.tensor([0], dtype=torch.int32)
263+
264+
def f(x, y):
265+
return x.type_as(y)
266+
267+
trace, z = torch.jit.get_trace_graph(f, (a, b))
268+
self.run_pass('peephole', trace)
269+
self.assertExpectedGraph(trace)
270+
trace, z = torch.jit.get_trace_graph(f, (a, c))
271+
s = str(trace)
272+
self.run_pass('peephole', trace)
273+
self.assertEqual(s, str(trace))
274+
275+
def test_peephole_dynamic(self):
276+
def f(x, y):
277+
return x.type_as(y)
278+
279+
fn = torch.jit.script(f)
280+
s = str(fn.graph)
281+
torch._C._jit_pass_peephole(fn.graph)
282+
self.assertEqual(s, str(fn.graph))
283+
284+
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
285+
def test_peephole_cuda(self):
286+
a = torch.tensor([0.4], requires_grad=True, device='cpu')
287+
b = torch.tensor([0.7], requires_grad=True, device='cuda')
288+
c = torch.tensor([0.7], requires_grad=True, device='cuda')
289+
290+
def f(x, y):
291+
return x.type_as(y)
292+
293+
trace, z = torch.jit.get_trace_graph(f, (a, c))
294+
s = str(trace)
295+
self.run_pass('peephole', trace)
296+
self.assertEqual(s, str(trace))
297+
trace, z = torch.jit.get_trace_graph(f, (b, c))
298+
self.run_pass('peephole', trace)
299+
self.assertExpectedGraph(trace, subname="same_device")
300+
259301
def test_index(self):
260302
x = torch.tensor([0.4], requires_grad=True)
261303
y = torch.tensor([0], dtype=torch.int64)

torch/csrc/jit/passes/peephole.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ void PeepholeOptimize(Block * block) {
4141
// Let DCE clean up any unused nodes at this point
4242
}
4343
} break;
44+
case aten::type_as: {
45+
JIT_ASSERT(n->inputs().size() == 2);
46+
Value *lhs = n->input(0);
47+
Value *rhs = n->input(1);
48+
// If LHS and RHS have the same static type, remove the type_as operator.
49+
if (lhs->type()->kind() == TypeKind::TensorType &&
50+
rhs->type()->kind() == TypeKind::TensorType) {
51+
auto ltype = (*lhs->type()).cast<TensorType>();
52+
auto rtype = (*rhs->type()).cast<TensorType>();
53+
if(ltype->device() == rtype->device() &&
54+
ltype->scalarType() == rtype->scalarType()) {
55+
n->output()->replaceAllUsesWith(lhs);
56+
}
57+
}
58+
} break;
4459
// Fuse mm + add into addmm
4560
case aten::add: {
4661
// Must have two inputs

0 commit comments

Comments
 (0)