Add support for save and load mkldnn modules #20799
Changes from all commits: 6074daa, d2c049b, e7b7418, 0c069ac, d53f348, 573d837, 481098d
```diff
@@ -1,43 +1,149 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import functools
 
 import torch
 
 
-def to_mkldnn(module):
-    def t_fn(t):
-        if t.is_floating_point():
-            return t.to_mkldnn()
+class MkldnnLinear(torch.jit.ScriptModule):
+    def __init__(self, dense_module):
+        super(MkldnnLinear, self).__init__()
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
```
Review thread on the `register_buffer('weight', ...)` line above:

Collaborator

@bddppq, why do we register the weight as a buffer rather than as a parameter? A mkldnn module cannot be trained if its weight is registered as a buffer. Can you tell me when the JIT save path will be used?

Collaborator

@bddppq, I have tried registering the weight as a parameter so that backward can run. The backward pass works, but JIT save and load then hit problems. Can you give me some advice? Thanks!
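For context on the buffer-versus-parameter question above, here is a minimal sketch in plain PyTorch (not part of this diff) of the practical difference: a buffer travels with the module's state_dict, but it is not exposed through .parameters(), so an optimizer built the usual way never updates it.

```python
import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super(Demo, self).__init__()
        # A parameter is trainable state; a buffer is persistent but non-trainable state.
        self.register_parameter('w_param', torch.nn.Parameter(torch.randn(3)))
        self.register_buffer('w_buf', torch.randn(3))

m = Demo()
print([name for name, _ in m.named_parameters()])    # ['w_param']: buffers are excluded
print(sorted(m.state_dict().keys()))                 # ['w_buf', 'w_param']: both are serialized
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)  # will only ever update w_param
```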
The diff continues:

```diff
+        if dense_module.bias is not None:
+            self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        else:
+            # TODO: Remove this once ScriptModule supports registering None buffer
+            self.register_buffer(
+                'bias',
+                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        return (self.weight.to_dense(), self.bias.to_dense())
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor]) -> None
+        self.weight = state[0].to_mkldnn()
+        self.bias = state[1].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch._C._nn.mkldnn_linear(x, self.weight, self.bias)
+
+
+class MkldnnConv2d(torch.jit.ScriptModule):
+    __constants__ = ['stride', 'padding', 'dilation', 'groups']
+
+    def __init__(self, dense_module):
+        super(MkldnnConv2d, self).__init__()
+
+        self.stride = dense_module.stride
+        self.padding = dense_module.padding
+        self.dilation = dense_module.dilation
+        self.groups = dense_module.groups
+
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        if dense_module.bias is not None:
+            self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        else:
+            # TODO: Remove this once ScriptModule supports registering None buffer
+            self.register_buffer(
+                'bias',
+                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        return (self.weight.to_dense(), self.bias.to_dense())
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor]) -> None
+        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
+            state[0].to_mkldnn(),
+            self.padding,
+            self.stride,
+            self.dilation,
+            self.groups)
+        self.bias = state[1].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups)
+
+
+class MkldnnBatchNorm2d(torch.jit.ScriptModule):
+    __constants__ = ['exponential_average_factor', 'eps']
+
+    def __init__(self, dense_module):
+        super(MkldnnBatchNorm2d, self).__init__()
+
+        assert(not dense_module.training)
+        assert(dense_module.track_running_stats)
+        assert(dense_module.affine)
+
+        if dense_module.momentum is None:
+            self.exponential_average_factor = 0.0
+        else:
+            self.exponential_average_factor = dense_module.momentum
+        self.eps = dense_module.eps
+
+        self.register_buffer('weight', dense_module.weight.to_mkldnn())
+        self.register_buffer('bias', dense_module.bias.to_mkldnn())
+        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
+        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())
+
+    @torch.jit.script_method
+    def __getstate__(self):
+        weight = self.weight.to_dense()
+        bias = self.bias.to_dense()
+        running_mean = self.running_mean.to_dense()
+        running_var = self.running_var.to_dense()
+        return (weight, bias, running_mean, running_var)
+
+    @torch.jit.script_method
+    def __setstate__(self, state):
+        # type: (Tuple[Tensor, Tensor, Tensor, Tensor]) -> None
+        self.weight = state[0].to_mkldnn()
+        self.bias = state[1].to_mkldnn()
+        self.running_mean = state[2].to_mkldnn()
+        self.running_var = state[3].to_mkldnn()
+
+    @torch.jit.script_method
+    def forward(self, x):
+        return torch.batch_norm(
+            x,
+            self.weight,
+            self.bias,
+            self.running_mean,
+            self.running_var,
+            False,  # training
+            self.exponential_average_factor,
+            self.eps,
+            False,  # cuda_enabled
+        )
+
+
+def to_mkldnn(module):
     def m_fn(m):
-        # TODO: This is a temporary hack to work around the fact that
-        # nn.Linear is decomposed into addmm/matmul. Later we will
-        # change nn.Linear to directly call aten linear and we can
-        # remove this patch
         if isinstance(m, torch.nn.Linear):
-            m.forward = functools.partial(
-                torch._C._nn.linear,
-                weight=m.weight,
-                bias=m.bias)
-
-        for param in m._parameters.values():
-            if param is not None:
-                # Tensors stored in modules are graph leaves, and we don't
-                # want to create copy nodes, so we have to unpack the data.
-                param.data = t_fn(param.data)
-                if param._grad is not None:
-                    param._grad.data = t_fn(param._grad.data)
-
-        for key, buf in m._buffers.items():
-            if buf is not None:
-                m._buffers[key] = t_fn(buf)
-
-        if isinstance(m, torch.nn.Conv2d):
-            m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
-                m.weight.data,
-                m.padding,
-                m.stride,
-                m.dilation,
-                m.groups)
-
-    return module.apply(m_fn)
+            return MkldnnLinear(m)
+        elif isinstance(m, torch.nn.Conv2d):
+            return MkldnnConv2d(m)
+        elif isinstance(m, torch.nn.BatchNorm2d):
+            return MkldnnBatchNorm2d(m)
+        else:
+            return m
+
+    def m_fn_rec(m):
+        new_m = m_fn(m)
+        for name, sub_m in m.named_children():
+            setattr(new_m, name, m_fn_rec(sub_m))
+        return new_m
+
+    return m_fn_rec(module)
```
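To show how the new save/load support is meant to be exercised end to end, here is a rough usage sketch. It assumes the file above is importable as torch.utils.mkldnn and that the build has MKL-DNN enabled; neither assumption is stated in the diff itself.

```python
import torch
from torch.utils import mkldnn as mkldnn_utils  # assumed import path; not shown in this diff

with torch.no_grad():
    dense = torch.nn.Linear(20, 30).eval()
    script = mkldnn_utils.to_mkldnn(dense)       # nn.Linear -> MkldnnLinear ScriptModule

    x = torch.randn(5, 20)
    y = script(x.to_mkldnn()).to_dense()         # inputs/outputs use the opaque MKL-DNN layout
    print(torch.allclose(dense(x), y, atol=1e-5))

    # __getstate__ converts the MKL-DNN weights back to dense tensors, so the
    # module can round-trip through TorchScript serialization; __setstate__
    # converts them to the MKL-DNN layout again on load.
    script.save("mkldnn_linear.pt")
    loaded = torch.jit.load("mkldnn_linear.pt")
    print(torch.allclose(loaded(x.to_mkldnn()).to_dense(), y, atol=1e-5))
```

The save/load asymmetry in MkldnnConv2d is worth noting: __getstate__ pickles the weight in the portable dense layout, and __setstate__ runs mkldnn_reorder_conv2d_weight on load so that the in-memory copy is back in the blocked layout that the MKL-DNN convolution expects.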