Commit 63585c3

bddppq authored and facebook-github-bot committed
Add support for save and load mkldnn modules
Summary:
Pull Request resolved: #20799

Reviewed By: wanchaol

Differential Revision: D15447891

fbshipit-source-id: e34de946c79282fb934a5c52ff1def41c7993c75

1 parent 5f83c5d commit 63585c3
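
For context, a minimal usage sketch of the workflow this change enables, pieced together from the tests below; the module sizes and the file name 'mkldnn_linear.pt' are illustrative only, and an MKL-DNN-enabled PyTorch build is assumed:

import torch
from torch.utils import mkldnn as mkldnn_utils

# Convert a dense float module into its MKL-DNN ScriptModule wrapper.
model = torch.nn.Linear(20, 30).float()
mkldnn_model = mkldnn_utils.to_mkldnn(model)

x = torch.randn(3, 20, dtype=torch.float32)
y = mkldnn_model(x.to_mkldnn()).to_dense()

# Saving and loading now round-trips, because the wrappers implement
# __getstate__/__setstate__ that convert MKL-DNN buffers to and from dense.
torch.jit.save(mkldnn_model, 'mkldnn_linear.pt')
loaded = torch.jit.load('mkldnn_linear.pt')
assert torch.allclose(y, loaded(x.to_mkldnn()).to_dense())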

File tree

4 files changed, +181 -58 lines changed

test/common_utils.py

Lines changed: 20 additions & 0 deletions
@@ -20,10 +20,12 @@
 import socket
 import time
 from collections import OrderedDict
+from contextlib import contextmanager
 from functools import wraps
 from itertools import product
 from copy import deepcopy
 from numbers import Number
+import tempfile
 
 import __main__
 import errno
@@ -66,6 +68,24 @@ def run_tests(argv=UNITTEST_ARGS):
 # Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`.
 IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI', 0))
 
+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryFileName():
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        f = tempfile.NamedTemporaryFile(delete=False)
+        try:
+            f.close()
+            yield f.name
+        finally:
+            os.unlink(f.name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryFileName():
+        with tempfile.NamedTemporaryFile() as f:
+            yield f.name
+
 
 def _check_module_exists(name):
     r"""Returns if a top-level module with :attr:`name` exists *without**

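As a rough illustration of how the relocated TemporaryFileName helper is meant to be consumed, mirroring its use in test_mkldnn.py below; the file contents here are made up:

from common_utils import TemporaryFileName

# Cross-platform scratch file: on Windows the underlying NamedTemporaryFile is
# closed right after creation and unlinked manually; elsewhere the context
# manager lets NamedTemporaryFile handle cleanup itself.
with TemporaryFileName() as fname:
    with open(fname, 'w') as f:
        f.write('scratch data')
    # the file named fname is gone once the outer with-block exits
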
test/test_jit.py

Lines changed: 1 addition & 20 deletions
@@ -17,7 +17,7 @@
 from torch._six import inf, PY2, builtins, StringIO
 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
-    freeze_rng_state, set_rng_seed, slowTest
+    freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName
 from common_nn import module_tests, new_module_tests, criterion_tests
 from textwrap import dedent
 from functools import wraps, reduce
@@ -84,25 +84,6 @@
 WINDOWS = sys.platform == 'win32'
 
 
-if WINDOWS:
-    @contextmanager
-    def TemporaryFileName():
-        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
-        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
-        # close the file after creation and try to remove it manually
-        f = tempfile.NamedTemporaryFile(delete=False)
-        try:
-            f.close()
-            yield f.name
-        finally:
-            os.unlink(f.name)
-else:
-    @contextmanager  # noqa: T484
-    def TemporaryFileName():
-        with tempfile.NamedTemporaryFile() as f:
-            yield f.name
-
-
 def LSTMCellF(input, hx, cx, *params):
     return LSTMCell(input, (hx, cx), *params)
 
test/test_mkldnn.py

Lines changed: 19 additions & 3 deletions
@@ -3,8 +3,10 @@
 import unittest
 
 import torch
+import torch.jit
 from torch.utils import mkldnn as mkldnn_utils
-from common_utils import TestCase, run_tests
+from common_utils import TestCase, run_tests, TemporaryFileName
+
 from torch.autograd.gradcheck import gradgradcheck, gradcheck
 
 
@@ -88,7 +90,7 @@ def test_detach(self):
 
     def test_repr(self):
         self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
-                        dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))
+                                                                  dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))
 
     def test_conv2d(self):
         for groups in [1, 4]:
@@ -109,6 +111,8 @@ def test_conv2d(self):
                 conv2d(x),
                 mkldnn_conv2d(x.to_mkldnn()).to_dense())
 
+            self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))
+
     def test_relu(self):
         x = torch.randn((4, 5), dtype=torch.float32) * 10
         self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
@@ -172,6 +176,8 @@ def test_batch_norm2d(self):
                 bn(x),
                 mkldnn_bn(x.to_mkldnn()).to_dense())
 
+            self._test_serialization(mkldnn_bn, (x.to_mkldnn(),))
+
     def test_add(self):
         N = torch.randint(3, 10, (1,)).item()
         C = torch.randint(3, 100, (1,)).item()
@@ -231,12 +237,22 @@ def test_linear(self):
         x = torch.randn(3, in_features, dtype=torch.float32) * 10
 
         for bias in [True, False]:
-            linear = torch.nn.Linear(in_features, out_features).float()
+            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
             mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
             self.assertEqual(
                 linear(x),
                 mkldnn_linear(x.to_mkldnn()).to_dense())
 
+            self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
+
+    def _test_serialization(self, module, inputs):
+        with TemporaryFileName() as fname:
+            torch.jit.save(module, fname)
+            loaded = torch.jit.load(fname)
+            self.assertEqual(
+                module(*inputs).to_dense(),
+                loaded(*inputs).to_dense())
+
 
 if __name__ == '__main__':
     run_tests()

torch/utils/mkldnn.py

Lines changed: 141 additions & 35 deletions
@@ -1,43 +1,149 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
-import functools
 
 import torch
 
 
-def to_mkldnn(module):
-    def t_fn(t):
-        if t.is_floating_point():
-            return t.to_mkldnn()
+class MkldnnLinear(torch.jit.ScriptModule):
+    def __init__(self, dense_module):
+        super(MkldnnLinear, self).__init__()
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        if dense_module.bias is not None:
+            self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        else:
+            # TODO: Remove this once ScriptModule supports registering None buffer
+            self.register_buffer(
+                'bias',
+                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        return (self.weight.to_dense(), self.bias.to_dense())
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor]) -> None
+        self.weight = state[0].to_mkldnn()
+        self.bias = state[1].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch._C._nn.mkldnn_linear(x, self.weight, self.bias)
+
+
+class MkldnnConv2d(torch.jit.ScriptModule):
+    __constants__ = ['stride', 'padding', 'dilation', 'groups']
+
+    def __init__(self, dense_module):
+        super(MkldnnConv2d, self).__init__()
+
+        self.stride = dense_module.stride
+        self.padding = dense_module.padding
+        self.dilation = dense_module.dilation
+        self.groups = dense_module.groups
+
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        if dense_module.bias is not None:
+            self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        else:
+            # TODO: Remove this once ScriptModule supports registering None buffer
+            self.register_buffer(
+                'bias',
+                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        return (self.weight.to_dense(), self.bias.to_dense())
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor]) -> None
+        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
+            state[0].to_mkldnn(),
+            self.padding,
+            self.stride,
+            self.dilation,
+            self.groups)
+        self.bias = state[1].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups)
+
 
+class MkldnnBatchNorm2d(torch.jit.ScriptModule):
+    __constants__ = ['exponential_average_factor', 'eps']
+
+    def __init__(self, dense_module):
+        super(MkldnnBatchNorm2d, self).__init__()
+
+        assert(not dense_module.training)
+        assert(dense_module.track_running_stats)
+        assert(dense_module.affine)
+
+        if dense_module.momentum is None:
+            self.exponential_average_factor = 0.0
+        else:
+            self.exponential_average_factor = dense_module.momentum
+        self.eps = dense_module.eps
+
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
+        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        weight = self.weight.to_dense()
+        bias = self.bias.to_dense()
+        running_mean = self.running_mean.to_dense()
+        running_var = self.running_var.to_dense()
+        return (weight, bias, running_mean, running_var)
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor, Tensor, Tensor]) -> None
+        self.weight = state[0].to_mkldnn()
+        self.bias = state[1].to_mkldnn()
+        self.running_mean = state[2].to_mkldnn()
+        self.running_var = state[3].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch.batch_norm(
+            x,
+            self.weight,
+            self.bias,
+            self.running_mean,
+            self.running_var,
+            False,  # training
+            self.exponential_average_factor,
+            self.eps,
+            False,  # cuda_enabled
+        )
+
+
+def to_mkldnn(module):
     def m_fn(m):
-        # TODO: This is a temporary hack to work around the fact that
-        # nn.Linear is decomposed into addmm/matmul. Later we will
-        # change nn.Linear to directly call aten linear and we can
-        # remove this patch
         if isinstance(m, torch.nn.Linear):
-            m.forward = functools.partial(
-                torch._C._nn.linear,
-                weight=m.weight,
-                bias=m.bias)
-
-        for param in m._parameters.values():
-            if param is not None:
-                # Tensors stored in modules are graph leaves, and we don't
-                # want to create copy nodes, so we have to unpack the data.
-                param.data = t_fn(param.data)
-                if param._grad is not None:
-                    param._grad.data = t_fn(param._grad.data)
-
-        for key, buf in m._buffers.items():
-            if buf is not None:
-                m._buffers[key] = t_fn(buf)
-
-        if isinstance(m, torch.nn.Conv2d):
-            m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
-                m.weight.data,
-                m.padding,
-                m.stride,
-                m.dilation,
-                m.groups)
-
-    return module.apply(m_fn)
+            return MkldnnLinear(m)
+        elif isinstance(m, torch.nn.Conv2d):
+            return MkldnnConv2d(m)
+        elif isinstance(m, torch.nn.BatchNorm2d):
+            return MkldnnBatchNorm2d(m)
+        else:
+            return m
+
+    def m_fn_rec(m):
+        new_m = m_fn(m)
+        for name, sub_m in m.named_children():
+            setattr(new_m, name, m_fn_rec(sub_m))
+        return new_m
+
+    return m_fn_rec(module)

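A rough sketch of how the new recursive conversion behaves on a nested module, under the constraints asserted above (BatchNorm2d converted in eval mode, with affine parameters and running stats); the layer shapes are arbitrary:

import torch
from torch.utils import mkldnn as mkldnn_utils

# m_fn_rec walks named_children(), so submodules of a container are replaced
# in place: Conv2d -> MkldnnConv2d, BatchNorm2d -> MkldnnBatchNorm2d, while
# the container itself is returned unchanged.
dense = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(8),
).eval()

mkldnn_model = mkldnn_utils.to_mkldnn(dense)

x = torch.randn(1, 3, 16, 16, dtype=torch.float32)
out = mkldnn_model(x.to_mkldnn()).to_dense()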