|
1 | 1 | from __future__ import absolute_import, division, print_function, unicode_literals |
2 | | -import functools |
3 | 2 |
|
4 | 3 | import torch |
5 | 4 |
|
6 | 5 |
|
7 | | -def to_mkldnn(module): |
8 | | - def t_fn(t): |
9 | | - if t.is_floating_point(): |
10 | | - return t.to_mkldnn() |
| 6 | +class MkldnnLinear(torch.jit.ScriptModule): |
| 7 | + def __init__(self, dense_module): |
| 8 | + super(MkldnnLinear, self).__init__() |
| 9 | + self.register_buffer('weight', dense_module.weight.to_mkldnn()) |
| 10 | + if dense_module.bias is not None: |
| 11 | + self.register_buffer('bias', dense_module.bias.to_mkldnn()) |
| 12 | + else: |
| 13 | + # TODO: Remove this once ScriptModule supports registering None buffer |
| 14 | + self.register_buffer( |
| 15 | + 'bias', |
| 16 | + torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn()) |
| 17 | + |
| 18 | + @torch.jit.script_method |
| 19 | + def __getstate__(self): |
| 20 | + return (self.weight.to_dense(), self.bias.to_dense()) |
| 21 | + |
| 22 | + @torch.jit.script_method |
| 23 | + def __setstate__(self, state): |
| 24 | + # type: (Tuple[Tensor, Tensor]) -> None |
| 25 | + self.weight = state[0].to_mkldnn() |
| 26 | + self.bias = state[1].to_mkldnn() |
| 27 | + |
| 28 | + @torch.jit.script_method |
| 29 | + def forward(self, x): |
| 30 | + return torch._C._nn.mkldnn_linear(x, self.weight, self.bias) |
| 31 | + |
| 32 | + |
| 33 | +class MkldnnConv2d(torch.jit.ScriptModule): |
| 34 | + __constants__ = ['stride', 'padding', 'dilation', 'groups'] |
| 35 | + |
| 36 | + def __init__(self, dense_module): |
| 37 | + super(MkldnnConv2d, self).__init__() |
| 38 | + |
| 39 | + self.stride = dense_module.stride |
| 40 | + self.padding = dense_module.padding |
| 41 | + self.dilation = dense_module.dilation |
| 42 | + self.groups = dense_module.groups |
| 43 | + |
| 44 | + self.register_buffer('weight', dense_module.weight.to_mkldnn()) |
| 45 | + if dense_module.bias is not None: |
| 46 | + self.register_buffer('bias', dense_module.bias.to_mkldnn()) |
| 47 | + else: |
| 48 | + # TODO: Remove this once ScriptModule supports registering None buffer |
| 49 | + self.register_buffer( |
| 50 | + 'bias', |
| 51 | + torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn()) |
| 52 | + |
| 53 | + @torch.jit.script_method |
| 54 | + def __getstate__(self): |
| 55 | + return (self.weight.to_dense(), self.bias.to_dense()) |
| 56 | + |
| 57 | + @torch.jit.script_method |
| 58 | + def __setstate__(self, state): |
| 59 | + # type: (Tuple[Tensor, Tensor]) -> None |
| 60 | + self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight( |
| 61 | + state[0].to_mkldnn(), |
| 62 | + self.padding, |
| 63 | + self.stride, |
| 64 | + self.dilation, |
| 65 | + self.groups) |
| 66 | + self.bias = state[1].to_mkldnn() |
| 67 | + |
| 68 | + @torch.jit.script_method |
| 69 | + def forward(self, x): |
| 70 | + return torch.conv2d( |
| 71 | + x, |
| 72 | + self.weight, |
| 73 | + self.bias, |
| 74 | + self.stride, |
| 75 | + self.padding, |
| 76 | + self.dilation, |
| 77 | + self.groups) |
| 78 | + |
11 | 79 |
|
| 80 | +class MkldnnBatchNorm2d(torch.jit.ScriptModule): |
| 81 | + __constants__ = ['exponential_average_factor', 'eps'] |
| 82 | + |
| 83 | + def __init__(self, dense_module): |
| 84 | + super(MkldnnBatchNorm2d, self).__init__() |
| 85 | + |
| 86 | + assert(not dense_module.training) |
| 87 | + assert(dense_module.track_running_stats) |
| 88 | + assert(dense_module.affine) |
| 89 | + |
| 90 | + if dense_module.momentum is None: |
| 91 | + self.exponential_average_factor = 0.0 |
| 92 | + else: |
| 93 | + self.exponential_average_factor = dense_module.momentum |
| 94 | + self.eps = dense_module.eps |
| 95 | + |
| 96 | + self.register_buffer('weight', dense_module.weight.to_mkldnn()) |
| 97 | + self.register_buffer('bias', dense_module.bias.to_mkldnn()) |
| 98 | + self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn()) |
| 99 | + self.register_buffer('running_var', dense_module.running_var.to_mkldnn()) |
| 100 | + |
| 101 | + @torch.jit.script_method |
| 102 | + def __getstate__(self): |
| 103 | + weight = self.weight.to_dense() |
| 104 | + bias = self.bias.to_dense() |
| 105 | + running_mean = self.running_mean.to_dense() |
| 106 | + running_var = self.running_var.to_dense() |
| 107 | + return (weight, bias, running_mean, running_var) |
| 108 | + |
| 109 | + @torch.jit.script_method |
| 110 | + def __setstate__(self, state): |
| 111 | + # type: (Tuple[Tensor, Tensor, Tensor, Tensor]) -> None |
| 112 | + self.weight = state[0].to_mkldnn() |
| 113 | + self.bias = state[1].to_mkldnn() |
| 114 | + self.running_mean = state[2].to_mkldnn() |
| 115 | + self.running_var = state[3].to_mkldnn() |
| 116 | + |
| 117 | + @torch.jit.script_method |
| 118 | + def forward(self, x): |
| 119 | + return torch.batch_norm( |
| 120 | + x, |
| 121 | + self.weight, |
| 122 | + self.bias, |
| 123 | + self.running_mean, |
| 124 | + self.running_var, |
| 125 | + False, # training |
| 126 | + self.exponential_average_factor, |
| 127 | + self.eps, |
| 128 | + False, # cuda_enabled |
| 129 | + ) |
| 130 | + |
| 131 | + |
| 132 | +def to_mkldnn(module): |
12 | 133 | def m_fn(m): |
13 | | - # TODO: This is a temporary hack to work around the fact that |
14 | | - # nn.Linear is decomposed into addmm/matmul. Later we will |
15 | | - # change nn.Linear to directly call aten linear and we can |
16 | | - # remove this patch |
17 | 134 | if isinstance(m, torch.nn.Linear): |
18 | | - m.forward = functools.partial( |
19 | | - torch._C._nn.linear, |
20 | | - weight=m.weight, |
21 | | - bias=m.bias) |
22 | | - |
23 | | - for param in m._parameters.values(): |
24 | | - if param is not None: |
25 | | - # Tensors stored in modules are graph leaves, and we don't |
26 | | - # want to create copy nodes, so we have to unpack the data. |
27 | | - param.data = t_fn(param.data) |
28 | | - if param._grad is not None: |
29 | | - param._grad.data = t_fn(param._grad.data) |
30 | | - |
31 | | - for key, buf in m._buffers.items(): |
32 | | - if buf is not None: |
33 | | - m._buffers[key] = t_fn(buf) |
34 | | - |
35 | | - if isinstance(m, torch.nn.Conv2d): |
36 | | - m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight( |
37 | | - m.weight.data, |
38 | | - m.padding, |
39 | | - m.stride, |
40 | | - m.dilation, |
41 | | - m.groups) |
42 | | - |
43 | | - return module.apply(m_fn) |
| 135 | + return MkldnnLinear(m) |
| 136 | + elif isinstance(m, torch.nn.Conv2d): |
| 137 | + return MkldnnConv2d(m) |
| 138 | + elif isinstance(m, torch.nn.BatchNorm2d): |
| 139 | + return MkldnnBatchNorm2d(m) |
| 140 | + else: |
| 141 | + return m |
| 142 | + |
| 143 | + def m_fn_rec(m): |
| 144 | + new_m = m_fn(m) |
| 145 | + for name, sub_m in m.named_children(): |
| 146 | + setattr(new_m, name, m_fn_rec(sub_m)) |
| 147 | + return new_m |
| 148 | + |
| 149 | + return m_fn_rec(module) |
0 commit comments