22from torch .library import Library , impl
33from torch .ao .quantization .utils import determine_qparams , validate_qmin_qmax
44from typing import Tuple
5+ from torch ._decomp import register_decomposition
6+
7+ def _quantize_per_tensor_impl (
8+ input : torch .Tensor ,
9+ scale : float ,
10+ zero_point : int ,
11+ quant_min : int ,
12+ quant_max : int ,
13+ dtype : torch .dtype ,
14+ ) -> torch .Tensor :
15+ inv_scale = 1.0 / scale
16+ return torch .clamp (
17+ torch .round (input * inv_scale ) + zero_point , quant_min , quant_max
18+ ).to (dtype )
19+
20+ def _dequantize_per_tensor_impl (
21+ input : torch .Tensor ,
22+ scale : float ,
23+ zero_point : int ,
24+ quant_min : int ,
25+ quant_max : int ,
26+ dtype : torch .dtype ,
27+ ) -> torch .Tensor :
28+ return (input .to (torch .float32 ) - zero_point ) * scale
29+
530
631
732# Note: decomposed means decomposed quantized tensor, using decomposed so that the
@@ -59,8 +84,18 @@ def quantize_per_tensor(
5984 assert input .dtype == torch .float32 , f"Expecting input to have dtype torch.float32, but got dtype: { input .dtype } "
6085 _quant_min_max_bounds_check (quant_min , quant_max , dtype )
6186
62- inv_scale = 1.0 / scale
63- return torch .clamp (torch .round (input * inv_scale ) + zero_point , quant_min , quant_max ).to (dtype )
87+ return _quantize_per_tensor_impl (input , scale , zero_point , quant_min , quant_max , dtype )
88+
@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor)
def quantize_per_tensor_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Decomposition rule for ``quantized_decomposed.quantize_per_tensor``.

    Thin registration wrapper: delegates straight to the shared
    ``_quantize_per_tensor_impl`` so the decomposition and the
    CompositeExplicitAutograd kernel stay numerically identical.
    """
    return _quantize_per_tensor_impl(
        input=input,
        scale=scale,
        zero_point=zero_point,
        quant_min=quant_min,
        quant_max=quant_max,
        dtype=dtype,
    )
6499
65100quantized_decomposed_lib .define (
66101 "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
@@ -82,15 +117,19 @@ def quantize_per_tensor_tensor(
82117 """
83118 assert zero_point .numel () == 1 , f"Exepecting zero_point tensor to be one element, but received : { zero_point .numel ()} "
84119 assert scale .numel () == 1 , f"Exepecting scale tensor to be one element, but received : { scale .numel ()} "
85- return quantize_per_tensor (input , scale .item (), zero_point .item (), quant_min , quant_max , dtype )
86-
87- @impl (quantized_decomposed_lib , "quantize_per_tensor.tensor" , "Meta" )
88- def quantize_per_tensor_tensor_meta (input , scale , zero_point , quant_min , quant_max , dtype ):
89- assert zero_point .numel () == 1 , f"Exepecting zero_point tensor to be one element, but received : { zero_point .numel ()} "
90- assert scale .numel () == 1 , f"Exepecting scale tensor to be one element, but received : { scale .numel ()} "
91- assert input .dtype == torch .float32 , f"Expecting input to have dtype torch.float32, but got dtype: { input .dtype } "
92- _quant_min_max_bounds_check (quant_min , quant_max , dtype )
93- return torch .empty_like (input , dtype = dtype )
120+ return _quantize_per_tensor_impl (
121+ input , scale .item (), zero_point .item (), quant_min , quant_max , dtype ) # type: ignore[arg-type]
122+
@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Decomposition rule for ``quantize_per_tensor.tensor``.

    Variant taking one-element ``scale``/``zero_point`` tensors: extracts
    their scalar values and delegates to the shared implementation.
    """
    scale_val = scale.item()
    zero_point_val = zero_point.item()
    return _quantize_per_tensor_impl(
        input, scale_val, zero_point_val, quant_min, quant_max, dtype
    )  # type: ignore[arg-type]
94133
95134# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
96135# the signature as metadata for the input Tensor, this might be useful for pattern
@@ -138,11 +177,22 @@ def dequantize_per_tensor(
138177 # TODO: investigate why
139178 # (input - zero_point).to(torch.float32) * scale
140179 # failed the test
141- return (input . to ( torch . float32 ) - zero_point ) * scale
180+ return _dequantize_per_tensor_impl (input , scale , zero_point , quant_min , quant_max , dtype )
142181 else :
143182 raise ValueError (f"Unsupported dtype in dequantize_per_tensor: { dtype } " )
144183
145184
@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor)
def dequantize_per_tensor_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Decomposition rule for ``quantized_decomposed.dequantize_per_tensor``.

    Thin registration wrapper: delegates to the shared
    ``_dequantize_per_tensor_impl`` (``quant_min``/``quant_max``/``dtype``
    are carried along as metadata only).
    """
    return _dequantize_per_tensor_impl(
        input=input,
        scale=scale,
        zero_point=zero_point,
        quant_min=quant_min,
        quant_max=quant_max,
        dtype=dtype,
    )
195+
146196quantized_decomposed_lib .define (
147197 "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
148198 "int quant_min, int quant_max, ScalarType dtype) -> Tensor" )
@@ -163,23 +213,26 @@ def dequantize_per_tensor_tensor(
163213 """
164214 assert zero_point .numel () == 1 , f"Exepecting zero_point tensor to be one element, but received : { zero_point .numel ()} "
165215 assert scale .numel () == 1 , f"Exepecting scale tensor to be one element, but received : { scale .numel ()} "
166- return dequantize_per_tensor (input , scale .item (), zero_point .item (), quant_min , quant_max , dtype )
167-
168- @impl (quantized_decomposed_lib , "dequantize_per_tensor.tensor" , "Meta" )
169- def dequantize_per_tensor_tensor_meta (input , scale , zero_point , quant_min , quant_max , dtype ):
170- assert zero_point .numel () == 1 , f"Exepecting zero_point tensor to be one element, but received : { zero_point .numel ()} "
171- assert scale .numel () == 1 , f"Exepecting scale tensor to be one element, but received : { scale .numel ()} "
172- assert input .dtype == dtype , f"Expecting input to have dtype: { dtype } "
173- if dtype in [torch .uint8 , torch .int8 , torch .int32 ]:
174- return torch .empty_like (input , dtype = torch .float32 )
175- else :
176- raise ValueError (f"Unsupported dtype in dequantize_per_tensor: { dtype } " )
177-
216+ return _dequantize_per_tensor_impl (
217+ input , scale .item (), zero_point .item (), quant_min , quant_max , dtype ) # type: ignore[arg-type]
178218
179219quantized_decomposed_lib .define (
180220 "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
181221 "ScalarType dtype) -> (Tensor, Tensor)" )
182222
223+
@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Decomposition rule for ``dequantize_per_tensor.tensor``.

    Variant taking one-element ``scale``/``zero_point`` tensors: extracts
    their scalar values and delegates to the shared implementation.
    """
    scale_val = scale.item()
    zero_point_val = zero_point.item()
    return _dequantize_per_tensor_impl(
        input, scale_val, zero_point_val, quant_min, quant_max, dtype
    )  # type: ignore[arg-type]
235+
183236@impl (quantized_decomposed_lib , "choose_qparams.tensor" , "CompositeExplicitAutograd" )
184237def choose_qparams_tensor (
185238 input : torch .Tensor ,
0 commit comments