Closed

Changes from all commits (28 commits)
034c25e  [Quant] Add fused LinearTanh module for onednn backend (Xia-Weiwen, Nov 12, 2022)
0aa39b7  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 17, 2022)
c27a73f  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 17, 2022)
fcb0ec8  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 17, 2022)
1111c05  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 20, 2022)
e975032  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 21, 2022)
795fac0  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 22, 2022)
ba8b917  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 24, 2022)
c7442cb  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 29, 2022)
e95cf9a  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 30, 2022)
a9c57d8  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Nov 30, 2022)
fffce85  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 5, 2022)
b0322bf  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 5, 2022)
6092a74  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 12, 2022)
ab204be  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
9f87f1f  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
ef185d4  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
a263cb9  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
2395d06  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
90cde3d  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
c0143c7  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
f6aaf07  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
6b22bfa  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 15, 2022)
d46f4cb  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 16, 2022)
bbb9bbb  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 16, 2022)
0ce88c7  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 16, 2022)
1aa42eb  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 16, 2022)
ab6b724  Update on "[Quant] Add fused LinearTanh module for onednn backend" (Xia-Weiwen, Dec 18, 2022)
18 changes: 18 additions & 0 deletions test/quantization/core/test_quantized_module.py
@@ -1087,6 +1087,24 @@ def test_linear_leaky_relu(self):
                    batch_size, in_features, out_features, use_bias,
                    per_channel, negative_slope=neg_slope)

    @skipIfNoONEDNN
    def test_linear_tanh(self):
        """test API functionality for nn.intrinsic.quantized.linear_tanh"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                [1, 5],         # batch size
                [16, 32],       # in_features
                [4, 8],         # out_features
                [True, False],  # use_bias
                [True, False])  # per_channel
            for (batch_size, in_features, out_features, use_bias,
                    per_channel) in options:
                self._test_linear_api_impl(
                    nniq.LinearTanh, 'QuantizedLinearTanh',
                    torch.ops.quantized.linear_tanh,
                    batch_size, in_features, out_features, use_bias,
                    per_channel)

class TestDynamicQuantizedModule(QuantizationTestCase):
    def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias):
        in_channels = 3
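For orientation, here is a minimal usage sketch (not part of this diff) of what the new test exercises: constructing the quantized fused module directly and running a quantized tensor through it. The engine selection and quantization parameters below are illustrative assumptions.

# Hypothetical sketch; assumes a PyTorch build with ONEDNN support.
import torch
import torch.ao.nn.intrinsic.quantized as nniq

torch.backends.quantized.engine = 'onednn'

m = nniq.LinearTanh(16, 4)  # in_features/out_features match one test configuration
x = torch.randn(5, 16)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
out = m(qx)                 # dispatches to torch.ops.quantized.linear_tanh
print(out.size())           # torch.Size([5, 4])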
1 change: 1 addition & 0 deletions torch/ao/nn/intrinsic/__init__.py
@@ -20,6 +20,7 @@
    'BNReLU3d',
    'LinearBn1d',
    'LinearLeakyReLU',
    'LinearTanh',
]

# We are exposing all subpackages to the end-user.
2 changes: 2 additions & 0 deletions torch/ao/nn/intrinsic/modules/__init__.py
@@ -13,6 +13,7 @@
from .fused import BNReLU3d
from .fused import LinearBn1d
from .fused import LinearLeakyReLU
from .fused import LinearTanh


__all__ = [
@@ -30,4 +31,5 @@
    'BNReLU3d',
    'LinearBn1d',
    'LinearLeakyReLU',
    'LinearTanh',
]
11 changes: 10 additions & 1 deletion torch/ao/nn/intrinsic/modules/fused.py
@@ -4,7 +4,7 @@

__all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d',
           'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d',
           'LinearBn1d', 'LinearLeakyReLU']
           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh']
# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
    pass
@@ -135,3 +135,12 @@ def __init__(self, linear, leaky_relu):
            'Incorrect types for input modules{}{}'.format(
                type(linear), type(leaky_relu))
        super().__init__(linear, leaky_relu)

class LinearTanh(_FusedModule):
    r"""This is a sequential container which calls the Linear and Tanh modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, tanh):
        assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, \
            'Incorrect types for input modules{}{}'.format(
                type(linear), type(tanh))
        super().__init__(linear, tanh)
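To make the container semantics concrete, a minimal sketch (not part of this diff): the float-side nni.LinearTanh simply chains the two wrapped modules, and inheriting from _FusedModule marks it for replacement during quantization.

# Hypothetical illustration of the fused container's behavior.
import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni

float_mod = nni.LinearTanh(nn.Linear(16, 4), nn.Tanh())
x = torch.randn(2, 16)
y = float_mod(x)  # identical to applying Linear then Tanh in sequence
assert torch.equal(y, torch.tanh(float_mod[0](x)))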
1 change: 1 addition & 0 deletions torch/ao/nn/intrinsic/quantized/__init__.py
@@ -8,4 +8,5 @@
    'ConvReLU3d',
    'LinearReLU',
    'LinearLeakyReLU',
    'LinearTanh',
]
3 changes: 2 additions & 1 deletion torch/ao/nn/intrinsic/quantized/modules/__init__.py
@@ -1,4 +1,4 @@
from .linear_relu import LinearReLU, LinearLeakyReLU
from .linear_relu import LinearReLU, LinearLeakyReLU, LinearTanh
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
from .bn_relu import BNReLU2d, BNReLU3d

@@ -10,4 +10,5 @@
    'BNReLU2d',
    'BNReLU3d',
    'LinearLeakyReLU',
    'LinearTanh',
]
65 changes: 64 additions & 1 deletion torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
@@ -6,6 +6,7 @@
__all__ = [
"LinearReLU",
"LinearLeakyReLU",
"LinearTanh",
]

class LinearReLU(nnq.Linear):
@@ -107,7 +108,69 @@ def from_reference(cls, ref_mod, output_scale, output_zero_point):
            leaky_relu.negative_slope)
        qweight = linear.get_quantized_weight()
        qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)

        qlinear_leaky_relu.scale = float(output_scale)
        qlinear_leaky_relu.zero_point = int(output_zero_point)
        return qlinear_leaky_relu

class LinearTanh(nnq.Linear):
    r"""
    A LinearTanh module fused from Linear and Tanh modules

    We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.quantized.LinearTanh(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    _FLOAT_MODULE = nni.LinearTanh

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.linear_tanh(
            x, self._packed_params._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedLinearTanh'

    @classmethod
    def from_float(cls, mod):
        assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh'
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        activation_post_process = mod.activation_post_process
        mod = mod[0]
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        dtype = weight_post_process.dtype
        act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[union-attr,operator]
        assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qlinear_tanh = cls(
            mod.in_features,
            mod.out_features,
            dtype=dtype)
        qlinear_tanh.set_weight_bias(qweight, mod.bias)
        qlinear_tanh.scale = float(act_scale)
        qlinear_tanh.zero_point = int(act_zp)
        return qlinear_tanh

    @classmethod
    def from_reference(cls, ref_mod, output_scale, output_zero_point):
        linear = ref_mod[0]
        qlinear_tanh = cls(
            linear.in_features,
            linear.out_features)
        qweight = linear.get_quantized_weight()
        qlinear_tanh.set_weight_bias(qweight, linear.bias)
        qlinear_tanh.scale = float(output_scale)
        qlinear_tanh.zero_point = int(output_zero_point)
        return qlinear_tanh
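Taken together, a hedged end-to-end sketch of the eager-mode flow that from_float enables. It assumes an ONEDNN-enabled build and that the nni.LinearTanh -> nniq.LinearTanh entry is registered in the default static quantization mappings (handled outside this diff); the qconfig choice is likewise an assumption.

# Hedged sketch: prepare/convert a model containing the fused module.
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.quantization as tq

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fused = nni.LinearTanh(torch.nn.Linear(16, 4), torch.nn.Tanh())
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fused(self.quant(x)))

torch.backends.quantized.engine = 'onednn'
model = M().eval()
model.qconfig = tq.get_default_qconfig('onednn')
tq.prepare(model, inplace=True)    # attaches observers to the fused module
model(torch.randn(8, 16))          # one calibration pass for activation stats
tq.convert(model, inplace=True)    # swaps in nniq.LinearTanh via from_float
print(model.fused)                 # expected: QuantizedLinearTanh(...)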