4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3615,6 +3615,10 @@
use_c10_dispatcher: full
variants: function

- func: _saturate_weight_to_fp16(Tensor weight) -> Tensor
use_c10_dispatcher: full
variants: function

# to(Device) must not exist because all constructors of Device also works for
# TensorOptions. Otherwise, an ambiguity error is thrown.
# See NOTE [ TensorOptions Constructors ].
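A quick way to exercise the new entry from Python (a hedged sketch, not part of this diff; it assumes the operator is exposed on the torch module the way other variants: function natives are):

import torch

# Values beyond the largest finite fp16 magnitude (65504) should come back clamped;
# in-range values should pass through untouched.
w = torch.tensor([[70000.0, -70000.0, 1.0]])
w_sat = torch._saturate_weight_to_fp16(w)
print(w_sat)  # expected: tensor([[ 65504., -65504., 1.]]), plus a warning about out-of-range weights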
58 changes: 11 additions & 47 deletions aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -4,6 +4,7 @@
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/packed_params.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/quant_utils.h>
#include <ATen/quantized/Quantizer.h>
#include <torch/custom_class.h>
#include <torch/library.h>
@@ -163,60 +164,16 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsQnnp::prepack(
#endif // USE_PYTORCH_QNNPACK

#ifdef USE_FBGEMM
namespace {
float RawUint16ToFp16(unsigned short value) {
// Convert raw 16 bits half precision floating point number
// to single precision floating point number.
const unsigned short sign_bits = value >> 15;
const unsigned short exponent_bits = value >> 10 & 0x1f;
const unsigned short significand_bits = value & 0x3ff;

const float sign = sign_bits ? -1 : 1;
const float significand =
1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10;
const float exponent = exponent_bits - 0xf;

return sign * std::ldexp(significand, exponent);
}

template <typename T>
bool CheckAndSaturate(T max_val, T* element) {
if (*element > max_val) {
*element = max_val;
return true;
}
if (*element < -max_val) {
*element = -max_val;
return true;
}
return false;
}

// The range for using FP16 quantization of weights requires that the elements
// should be in the range of [5.96e-8, 65504]. If it is out of range, then the
// number will be saturated to max or min representable values by FP16.
void HandleWeightsSaturation(int64_t N, float* weight) {
const float kFp16Max = RawUint16ToFp16(0x7BFF);
bool found_out_of_range = false;
for (int64_t i = 0; i < N; ++i) {
if (CheckAndSaturate<float>(kFp16Max, weight + i)) {
found_out_of_range = true;
}
}
if (found_out_of_range) {
TORCH_WARN("FOUND weight out of range ");
}
}
} // namespace

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightFp16::prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias) {

weight = at::_saturate_weight_to_fp16(weight);

const int64_t K = weight.size(1);
const int64_t N = weight.size(0);
at::Tensor weight_contig = weight.contiguous();
float* weight_contig_ptr = weight_contig.data_ptr<float>();
HandleWeightsSaturation(K * N, weight_contig_ptr);

// TODO(mingzhe09088):
// Consider using a functor here in PackedGemmMatrixFP16
@@ -235,6 +192,13 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightFp16::prepack(

namespace at {
namespace native {

at::Tensor _saturate_weight_to_fp16(const Tensor& weight) {
float* weight_contig_ptr = weight.contiguous().data_ptr<float>();
quant_utils::HandleWeightsSaturation(weight.size(0) * weight.size(1), weight_contig_ptr);
return weight;
}
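For readers who do not want to trace the C++ helpers, the sketch below is a rough Python equivalent of the new at::_saturate_weight_to_fp16 above (illustrative names only, not the real bindings): clamp every element to the largest finite fp16 magnitude and warn if anything had to be clipped. It assumes a contiguous float tensor, which the C++ path obtains via weight.contiguous().

import warnings
import torch

FP16_MAX = 65504.0  # largest finite fp16 value; kFp16Max = RawUint16ToFp16(0x7BFF) in the C++ helper

def saturate_weight_to_fp16(weight: torch.Tensor) -> torch.Tensor:
    # Mirrors quant_utils::HandleWeightsSaturation: warn once if any element is
    # out of range, then clip in place to [-FP16_MAX, FP16_MAX].
    if (weight.abs() > FP16_MAX).any():
        warnings.warn("found weight out of fp16 range")
    weight.clamp_(-FP16_MAX, FP16_MAX)
    return weight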

namespace {

class QLinearPackWeightInt8 final {
47 changes: 47 additions & 0 deletions aten/src/ATen/native/quantized/cpu/quant_utils.h
@@ -5,6 +5,36 @@
#include <cmath>

namespace quant_utils {
namespace {
float RawUint16ToFp16(unsigned short value) {
// Convert raw 16 bits half precision floating point number
// to single precision floating point number.
const unsigned short sign_bits = value >> 15;
const unsigned short exponent_bits = value >> 10 & 0x1f;
const unsigned short significand_bits = value & 0x3ff;

const float sign = sign_bits ? -1 : 1;
const float significand =
1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10;
const float exponent = exponent_bits - 0xf;

return sign * std::ldexp(significand, exponent);
}

template <typename T>
bool CheckAndSaturate(T max_val, T* element) {
if (*element > max_val) {
*element = max_val;
return true;
}
if (*element < -max_val) {
*element = -max_val;
return true;
}
return false;
}
}

using namespace std;
// A structure to hold quantization parameters 'scale' and 'zero_point'.
// The meaning of these values is as the constants in the quantization equation
@@ -136,4 +166,21 @@ static torch::List<int64_t> MakeArgForConv1d(const torch::List<int64_t>& arg,
return result;
}


// The range for using FP16 quantization of weights requires that the elements
// should be in the range of [5.96e-8, 65504]. If it is out of range, then the
// number will be saturated to max or min representable values by FP16.
inline void HandleWeightsSaturation(int64_t N, float* weight) {
const float kFp16Max = RawUint16ToFp16(0x7BFF);
bool found_out_of_range = false;
for (int64_t i = 0; i < N; ++i) {
if (CheckAndSaturate<float>(kFp16Max, weight + i)) {
found_out_of_range = true;
}
}
if (found_out_of_range) {
TORCH_WARN("FOUND weight out of range ");
}
}
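As a cross-check of the constant the saturation path relies on, the sketch below re-implements RawUint16ToFp16 in Python and confirms that 0x7BFF decodes to 65504, the largest finite half-precision value; struct's half-precision format serves as an independent reference.

import math
import struct

def raw_uint16_to_fp16(value: int) -> float:
    # Same bit manipulation as the C++ helper above (normal numbers only).
    sign = -1.0 if (value >> 15) else 1.0
    exponent = ((value >> 10) & 0x1F) - 0xF
    significand = 1.0 + (value & 0x3FF) * 2.0 ** -10
    return sign * math.ldexp(significand, exponent)

assert raw_uint16_to_fp16(0x7BFF) == 65504.0
assert struct.unpack('<e', struct.pack('<H', 0x7BFF))[0] == 65504.0  # stdlib fp16 decode agrees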

} // namespace quant_utils
20 changes: 11 additions & 9 deletions test/quantization/test_quantize_jit.py
@@ -2757,8 +2757,7 @@ def forward(self, x):

m = torch.jit.script(M())
m = quantize_dynamic_jit(m, {'': float16_dynamic_qconfig}, debug=True)
FileCheck().check("aten::to") \
.check_next("aten::to") \
FileCheck().check("aten::_saturate_weight_to_fp16") \
.check("aten::linear") \
.check_not("aten::dequantize") \
.check_not("aten::quantize") \
@@ -3038,13 +3037,16 @@ def test_single_linear_dynamic(self):
@skipIfNoFBGEMM
def test_linear_dynamic_fp16(self):
linear_model = SingleLayerLinearModel().eval()
# Create weight tensor values that are beyond fp16 max
x = torch.ones(5, 5) * 65532
linear_model.fc1.weight = torch.nn.Parameter(x)

model_eager = quantize_dynamic(linear_model, dtype=torch.float16)
result_eager = model_eager(self.calib_data[0][0])
model_script = torch.jit.script(linear_model)
model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
qconfig_dict = {'' : float16_dynamic_qconfig}

for model in [model_traced, model_script]:
model_quantized = quantize_dynamic_jit(model, qconfig_dict, debug=False)
# TODO check model with debug=True matches quantized model result
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
for trace in [True, False]:
quantized_model = self.checkGraphModeOp(linear_model, self.calib_data[0][0],
"quantized::linear_dynamic_fp16", tracing=trace,
dynamic=True, qconfig=float16_dynamic_qconfig)
# compare result with eager mode
self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager)
25 changes: 0 additions & 25 deletions torch/csrc/jit/passes/quantization/helper.cpp
@@ -694,30 +694,5 @@ bool is_batchnorm3d_module(
"__torch__.torch.nn.modules.batchnorm.BatchNorm3d");
}

bool is_half_dtype(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto fp16_type = toIValue(match_vmap.at(vmap.at("dtype_fp16")));
return (fp16_type->toScalarType() == c10::kHalf);
}

bool is_float_dtype(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto fp16_type = toIValue(match_vmap.at(vmap.at("dtype_fp32")));
return (fp16_type->toScalarType() == c10::kFloat);
}

bool is_false_value(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;

auto default_param = toIValue(match_vmap.at(vmap.at("false")));
return default_param->toBool() == false;
}

} // namespace jit
} // namespace torch
11 changes: 0 additions & 11 deletions torch/csrc/jit/passes/quantization/helper.h
@@ -174,16 +174,5 @@ bool is_batchnorm3d_module(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap);

bool is_half_dtype(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap);

bool is_float_dtype(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap);

bool is_false_value(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap);
} // namespace jit
} // namespace torch
40 changes: 17 additions & 23 deletions torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -191,28 +191,14 @@ DynamicQuantOps insertChooseQParamQuantDequant(
}

Node* insertFP16CastOps(Graph* graph, Value* observer_out) {
auto default_false = graph->insertConstant(false);
Value* none = graph->insertConstant(IValue());
Value* fp16_dtype = graph->insertConstant(IValue(c10::kHalf));
Value* float_dtype = graph->insertConstant(IValue(c10::kFloat));

std::vector<Value*> input_to_fp16 = {observer_out,
fp16_dtype,
/* non_blocking */ default_false,
/* copy */ default_false};
Node* cast_to_fp16 = graph->create(Symbol::aten("to"), input_to_fp16);
graph->insertNode(cast_to_fp16);

auto fp16_out = cast_to_fp16->output();
std::vector<Value*> input_to_fp32 = {fp16_out,
float_dtype,
/* non_blocking */ default_false,
/* copy */ default_false};
Node* cast_to_fp32 = graph->create(Symbol::aten("to"), input_to_fp32);
graph->insertNode(cast_to_fp32);
// If the weight value is outside of the range for FP16 range, i.e. [5.96e-8,
// 65504], we saturate the values to the min/max of this range.
Node* saturated_weight =
graph->create(Symbol::aten("_saturate_weight_to_fp16"), {observer_out});
graph->insertNode(saturated_weight);
graph->lint();

return cast_to_fp32;
return saturated_weight;
}

bool isNoopObserver(Value* observer) {
@@ -246,8 +232,15 @@ void insertQuantizationOps(
}
Value* original_val = observer->input(1);
Node *quant, *choose_qparams, *dequant;
bool has_dequant = true;
if (quant_type == QuantType::DYNAMIC && isNoopObserver(observer->input(0))) {
dequant = insertFP16CastOps(g, observer_out);
// We don't need to insert cast operators for activation tensors for fp16
// quant.
Comment on lines +237 to +238

Contributor: Can we filter this in a different place? E.g. we don't insert observer for activation tensors?

Contributor Author: Is it possible for user to not specify activation observer in the qconfig? Does the prepare_jit pass ensure observers aren't inserted in that case for activation tensors?

Contributor Author: That may have issues since for FP16 quant we don't specify dtype anywhere in the qconfig. We set quant type to dynamic so there is no way to distinguish int8 dynamic quant from fp16 dynamic quant. Hence was wondering if not specifying any activation observer (since we don't want it observed) would work here.

Contributor: Looks like right now we are checking noop observer to do fp16 quantization, this sounds like a hack, can we expose fp16 as an argument to the API?

Contributor: I feel it makes more sense to expose this in the top level API, why don't we do that?
if (isWeight(module, observer_out)) {
dequant = insertFP16CastOps(g, observer_out);
} else {
has_dequant = false;
}
} else if (
quant_type == QuantType::DYNAMIC && !isWeight(module, observer_out)) {
Value* dtype = g->insertGetAttr(self, qparam_names.back());
@@ -269,8 +262,9 @@ void insertQuantizationOps(
dequant = insertDeQuant(g, quant->output(), original_val);
}
observer_out->replaceAllUsesWith(original_val);

original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
if (has_dequant) {
original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
}
}
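To see the effect of this change from the Python side, here is a minimal sketch (assuming import paths that mirror what the updated test file uses; treat it as illustrative rather than a supported recipe). With the fp16 dynamic qconfig, the debug graph should now contain aten::_saturate_weight_to_fp16 on the weight instead of the old aten::to fp16/fp32 cast pair, and activation tensors are left uncast.

import torch
from torch.quantization import float16_dynamic_qconfig
from torch.quantization.quantize_jit import quantize_dynamic_jit

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)

m = torch.jit.script(M().eval())
quantized = quantize_dynamic_jit(m, {'': float16_dynamic_qconfig}, debug=True)
print(quantized.graph)  # expect aten::_saturate_weight_to_fp16 feeding aten::linear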

// find the observer for Value `v` and return the name of the observer
23 changes: 9 additions & 14 deletions torch/csrc/jit/passes/quantization/quantization_patterns.h
@@ -1006,24 +1006,21 @@ graph(%packed_params, %a, %reduce_range, %a_dtype):
return (%r) )";

std::string linear_dynamic_fp16 = R"(
graph(%packed_params, %a, %dtype_fp16, %dtype_fp32, %false):
%fp16_tensor = aten::to(%a, %dtype_fp16, %false, %false)
%fp32_tensor = aten::to(%fp16_tensor, %dtype_fp32, %false, %false)
graph(%packed_params, %a):
%w_unpacked : Tensor, %b : Tensor? = quantized::linear_unpack_fp16(%packed_params)
%r = aten::linear(%fp32_tensor, %w_unpacked, %b)
%r = aten::linear(%a, %w_unpacked, %b)
return (%r) )";

std::string quantized_linear_dynamic_fp16 = R"(
graph(%packed_params, %a, %dtype_fp16, %dtype_fp32, %false):
graph(%packed_params, %a):
%r = quantized::linear_dynamic_fp16(%a, %packed_params)
return (%r) )";

return {
{"quantized::linear_dynamic", linear_dynamic, quantized_linear_dynamic},
{"quantized::linear_dynamic_fp16",
linear_dynamic_fp16,
quantized_linear_dynamic_fp16,
{is_half_dtype, is_float_dtype, is_false_value}},
quantized_linear_dynamic_fp16},
};
}

@@ -1042,13 +1039,12 @@ graph(%a_dequant, %w_quant, %b):
%r = aten::linear(%a_dequant, %w_dequant, %b_unpacked)
return (%r) )";
std::string linear_fp16_with_cast = R"(
graph(%w, %a_dq, %b, %dtype_fp16, %dtype_fp32, %false):
%fp16_tensor = aten::to(%w, %dtype_fp16, %false, %false)
%fp32_tensor = aten::to(%fp16_tensor, %dtype_fp32, %false, %false)
%r = aten::linear(%a_dq, %fp32_tensor, %b)
graph(%w, %a_dq, %b):
%fp16_tensor = aten::_saturate_weight_to_fp16(%w)
%r = aten::linear(%a_dq, %fp16_tensor, %b)
return (%r) )";
std::string linear_fp16_with_prepack = R"(
graph(%w, %a_dq, %b, %dtype_fp16, %dtype_fp32, %false):
graph(%w, %a_dq, %b):
%packed_params = quantized::linear_prepack_fp16(%w, %b)
%w_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack_fp16(%packed_params)
%r = aten::linear(%a_dq, %w_unpacked, %b_unpacked)
@@ -1058,8 +1054,7 @@ graph(%w, %a_dq, %b, %dtype_fp16, %dtype_fp32, %false):
{"linear_prepack_unpack", linear_with_quant, linear_with_quant_prepack},
{"linear_fp16_prepack_unpack",
linear_fp16_with_cast,
linear_fp16_with_prepack,
{is_half_dtype, is_float_dtype, is_false_value}},
linear_fp16_with_prepack},
};
}

5 changes: 3 additions & 2 deletions torch/testing/_internal/common_quantization.py
@@ -294,15 +294,16 @@ def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data):
self.assertEqual(scripted_output, ref_output)


def checkGraphModeOp(self, module, data, quantized_op, tracing=False, debug=False, check=True, eval_mode=True, dynamic=False):
def checkGraphModeOp(self, module, data, quantized_op, tracing=False, debug=False,
check=True, eval_mode=True, dynamic=False, qconfig=None):
if debug:
print('Testing:', str(module))
qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

if eval_mode:
module = module.eval()
if dynamic:
qconfig_dict = {'': default_dynamic_qconfig}
qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig}
inputs = data
else:
*inputs, target = data[0]