20 changes: 20 additions & 0 deletions test/common_utils.py
@@ -20,10 +20,12 @@
import socket
import time
from collections import OrderedDict
from contextlib import contextmanager
from functools import wraps
from itertools import product
from copy import deepcopy
from numbers import Number
import tempfile

import __main__
import errno
@@ -66,6 +68,24 @@ def run_tests(argv=UNITTEST_ARGS):
# Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`.
IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI', 0))

if IS_WINDOWS:
@contextmanager
def TemporaryFileName():
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
# opens the file, and it cannot be opened multiple times on Windows. To support Windows,
# close the file after creation and try to remove it manually.
f = tempfile.NamedTemporaryFile(delete=False)
try:
f.close()
yield f.name
finally:
os.unlink(f.name)
else:
@contextmanager # noqa: T484
def TemporaryFileName():
with tempfile.NamedTemporaryFile() as f:
yield f.name


def _check_module_exists(name):
r"""Returns if a top-level module with :attr:`name` exists *without**
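For reference, a minimal usage sketch of the TemporaryFileName helper moved into common_utils above (hypothetical test code, not part of this diff): the context manager yields a path that can be reopened by name on both Windows and POSIX, and the file is removed when the block exits.

from common_utils import TemporaryFileName

def example_round_trip():
    with TemporaryFileName() as fname:
        # the file can be opened by name inside the block, even on Windows
        with open(fname, 'w') as f:
            f.write('hello')
        with open(fname, 'r') as f:
            assert f.read() == 'hello'
    # on exit the temporary file has been deleted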
21 changes: 1 addition & 20 deletions test/test_jit.py
@@ -17,7 +17,7 @@
from torch._six import inf, PY2, builtins, StringIO
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state, set_rng_seed, slowTest
freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName
from common_nn import module_tests, new_module_tests, criterion_tests
from textwrap import dedent
from functools import wraps, reduce
@@ -84,25 +84,6 @@
WINDOWS = sys.platform == 'win32'


if WINDOWS:
@contextmanager
def TemporaryFileName():
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
# close the file after creation and try to remove it manually
f = tempfile.NamedTemporaryFile(delete=False)
try:
f.close()
yield f.name
finally:
os.unlink(f.name)
else:
@contextmanager # noqa: T484
def TemporaryFileName():
with tempfile.NamedTemporaryFile() as f:
yield f.name


def LSTMCellF(input, hx, cx, *params):
return LSTMCell(input, (hx, cx), *params)

22 changes: 19 additions & 3 deletions test/test_mkldnn.py
@@ -3,8 +3,10 @@
import unittest

import torch
import torch.jit
from torch.utils import mkldnn as mkldnn_utils
from common_utils import TestCase, run_tests
from common_utils import TestCase, run_tests, TemporaryFileName

from torch.autograd.gradcheck import gradgradcheck, gradcheck


@@ -88,7 +90,7 @@ def test_detach(self):

def test_repr(self):
self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))
dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))

def test_conv2d(self):
for groups in [1, 4]:
@@ -109,6 +111,8 @@ def test_conv2d(self):
conv2d(x),
mkldnn_conv2d(x.to_mkldnn()).to_dense())

self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))

def test_relu(self):
x = torch.randn((4, 5), dtype=torch.float32) * 10
self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())
@@ -172,6 +176,8 @@ def test_batch_norm2d(self):
bn(x),
mkldnn_bn(x.to_mkldnn()).to_dense())

self._test_serialization(mkldnn_bn, (x.to_mkldnn(),))

def test_add(self):
N = torch.randint(3, 10, (1,)).item()
C = torch.randint(3, 100, (1,)).item()
@@ -231,12 +237,22 @@ def test_linear(self):
x = torch.randn(3, in_features, dtype=torch.float32) * 10

for bias in [True, False]:
linear = torch.nn.Linear(in_features, out_features).float()
linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
self.assertEqual(
linear(x),
mkldnn_linear(x.to_mkldnn()).to_dense())

self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))

def _test_serialization(self, module, inputs):
with TemporaryFileName() as fname:
torch.jit.save(module, fname)
loaded = torch.jit.load(fname)
self.assertEqual(
module(*inputs).to_dense(),
loaded(*inputs).to_dense())


if __name__ == '__main__':
run_tests()
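A minimal end-to-end sketch of the behaviour these tests exercise (assuming an MKLDNN-enabled build of PyTorch): to_mkldnn now returns ScriptModules, so a converted module can be round-tripped through torch.jit.save and torch.jit.load, with __getstate__ converting the MKLDNN buffers to dense tensors on save and __setstate__ converting them back on load.

import torch
from torch.utils import mkldnn as mkldnn_utils
from common_utils import TemporaryFileName

linear = torch.nn.Linear(20, 30).float()
mkldnn_linear = mkldnn_utils.to_mkldnn(linear)  # returns an MkldnnLinear ScriptModule

x = torch.randn(3, 20, dtype=torch.float32)
y_ref = mkldnn_linear(x.to_mkldnn()).to_dense()

with TemporaryFileName() as fname:
    torch.jit.save(mkldnn_linear, fname)   # buffers are serialized as dense tensors
    loaded = torch.jit.load(fname)         # and converted back to mkldnn layout on load
    y_loaded = loaded(x.to_mkldnn()).to_dense()

print(torch.allclose(y_ref, y_loaded))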
176 changes: 141 additions & 35 deletions torch/utils/mkldnn.py
@@ -1,43 +1,149 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import functools

import torch


def to_mkldnn(module):
def t_fn(t):
if t.is_floating_point():
return t.to_mkldnn()
class MkldnnLinear(torch.jit.ScriptModule):
def __init__(self, dense_module):
super(MkldnnLinear, self).__init__()
self.register_buffer('weight', dense_module.weight.to_mkldnn())
Collaborator:
@bddppq, I wonder why we register the weight as a buffer rather than as a parameter; it is not suitable for training an MKLDNN module if the weight is registered as a buffer. Can you tell me when the JIT save path will be used?

Collaborator:
@bddppq, I have tried registering the weight as a parameter in order to run backward. The backward pass runs, but JIT save and load then have problems. Can you give me some advice? Thanks!
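For context on the question above, a small sketch (plain eager-mode PyTorch, not part of this PR) of the practical difference: buffers are serialized with the module but are not returned by .parameters() and receive no gradient updates, whereas nn.Parameter tensors are trainable and visible to optimizers.

import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super(Demo, self).__init__()
        # buffer: saved/loaded with the module, invisible to optimizers
        self.register_buffer('w_buf', torch.randn(3))
        # parameter: saved/loaded and also trainable via .parameters()
        self.w_param = torch.nn.Parameter(torch.randn(3))

m = Demo()
print([name for name, _ in m.named_parameters()])  # ['w_param'] only
print(list(m.state_dict().keys()))                 # includes both 'w_param' and 'w_buf'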

if dense_module.bias is not None:
self.register_buffer('bias', dense_module.bias.to_mkldnn())
else:
# TODO: Remove this once ScriptModule supports registering None buffer
self.register_buffer(
'bias',
torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

@torch.jit.script_method
def __getstate__(self):
return (self.weight.to_dense(), self.bias.to_dense())

@torch.jit.script_method
def __setstate__(self, state):
# type: (Tuple[Tensor, Tensor]) -> None
self.weight = state[0].to_mkldnn()
self.bias = state[1].to_mkldnn()

@torch.jit.script_method
def forward(self, x):
return torch._C._nn.mkldnn_linear(x, self.weight, self.bias)


class MkldnnConv2d(torch.jit.ScriptModule):
__constants__ = ['stride', 'padding', 'dilation', 'groups']

def __init__(self, dense_module):
super(MkldnnConv2d, self).__init__()

self.stride = dense_module.stride
self.padding = dense_module.padding
self.dilation = dense_module.dilation
self.groups = dense_module.groups

self.register_buffer('weight', dense_module.weight.to_mkldnn())
if dense_module.bias is not None:
self.register_buffer('bias', dense_module.bias.to_mkldnn())
else:
# TODO: Remove this once ScriptModule supports registering None buffer
self.register_buffer(
'bias',
torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

@torch.jit.script_method
def __getstate__(self):
return (self.weight.to_dense(), self.bias.to_dense())

@torch.jit.script_method
def __setstate__(self, state):
# type: (Tuple[Tensor, Tensor]) -> None
self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
state[0].to_mkldnn(),
self.padding,
self.stride,
self.dilation,
self.groups)
self.bias = state[1].to_mkldnn()

@torch.jit.script_method
def forward(self, x):
return torch.conv2d(
x,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups)


class MkldnnBatchNorm2d(torch.jit.ScriptModule):
__constants__ = ['exponential_average_factor', 'eps']

def __init__(self, dense_module):
super(MkldnnBatchNorm2d, self).__init__()

assert(not dense_module.training)
assert(dense_module.track_running_stats)
assert(dense_module.affine)

if dense_module.momentum is None:
self.exponential_average_factor = 0.0
else:
self.exponential_average_factor = dense_module.momentum
self.eps = dense_module.eps

self.register_buffer('weight', dense_module.weight.to_mkldnn())
self.register_buffer('bias', dense_module.bias.to_mkldnn())
self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

@torch.jit.script_method
def __getstate__(self):
weight = self.weight.to_dense()
bias = self.bias.to_dense()
running_mean = self.running_mean.to_dense()
running_var = self.running_var.to_dense()
return (weight, bias, running_mean, running_var)

@torch.jit.script_method
def __setstate__(self, state):
# type: (Tuple[Tensor, Tensor, Tensor, Tensor]) -> None
self.weight = state[0].to_mkldnn()
self.bias = state[1].to_mkldnn()
self.running_mean = state[2].to_mkldnn()
self.running_var = state[3].to_mkldnn()

@torch.jit.script_method
def forward(self, x):
return torch.batch_norm(
x,
self.weight,
self.bias,
self.running_mean,
self.running_var,
False, # training
self.exponential_average_factor,
self.eps,
False, # cuda_enabled
)


def to_mkldnn(module):
def m_fn(m):
# TODO: This is a temporary hack to work around the fact that
# nn.Linear is decomposed into addmm/matmul. Later we will
# change nn.Linear to directly call aten linear and we can
# remove this patch
if isinstance(m, torch.nn.Linear):
m.forward = functools.partial(
torch._C._nn.linear,
weight=m.weight,
bias=m.bias)

for param in m._parameters.values():
if param is not None:
# Tensors stored in modules are graph leaves, and we don't
# want to create copy nodes, so we have to unpack the data.
param.data = t_fn(param.data)
if param._grad is not None:
param._grad.data = t_fn(param._grad.data)

for key, buf in m._buffers.items():
if buf is not None:
m._buffers[key] = t_fn(buf)

if isinstance(m, torch.nn.Conv2d):
m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
m.weight.data,
m.padding,
m.stride,
m.dilation,
m.groups)

return module.apply(m_fn)
return MkldnnLinear(m)
elif isinstance(m, torch.nn.Conv2d):
return MkldnnConv2d(m)
elif isinstance(m, torch.nn.BatchNorm2d):
return MkldnnBatchNorm2d(m)
else:
return m

def m_fn_rec(m):
new_m = m_fn(m)
for name, sub_m in m.named_children():
setattr(new_m, name, m_fn_rec(sub_m))
return new_m

return m_fn_rec(module)
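As a usage sketch of the new m_fn_rec recursion (hypothetical model, assuming an MKLDNN-enabled build): supported child layers are replaced with the ScriptModule wrappers defined above, while unsupported ones pass through unchanged. Containers are converted in place, which is why the tests deepcopy the dense module first.

import copy
import torch
from torch.utils import mkldnn as mkldnn_utils

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
)
model.eval()  # MkldnnBatchNorm2d asserts the module is not in training mode

# deepcopy first: the container's children are replaced in place during conversion
mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))

print(type(mkldnn_model[0]).__name__)  # MkldnnConv2d
print(type(mkldnn_model[1]).__name__)  # MkldnnBatchNorm2d
print(type(mkldnn_model[2]).__name__)  # ReLU (unsupported types are left as-is)

x = torch.randn(1, 3, 16, 16, dtype=torch.float32)
y = mkldnn_model(x.to_mkldnn()).to_dense()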