
Commit 6a1eb3f

Authored and committed by root
Update on "[JIT] add support for overloading functions"
This is part of a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266.

This adds support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading). The usage is:

```
@torch.jit.overload
def add(x: int, y: int) -> int: ...

@torch.jit.overload
def add(x: float, y: float) -> float: ...

def add(x, y):
    return x + y
```

Follow-up PRs:
- Add the same API for methods
- A couple of cleanups for functions:
  - don't require default params to be specified on the overload as well
  - potentially error if an invocation could match multiple overloads; right now the first match is chosen, which is what mypy currently does

Differential Revision: [D16694863](https://our.internmc.facebook.com/intern/diff/D16694863)
2 parents 6f67501 + 8fc496f commit 6a1eb3f
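To make the new API above concrete, here is a minimal sketch of how the overloads might be consumed from TorchScript. The `use_add` wrapper and the literal inputs are illustrative, not taken from the PR's tests:

```python
import torch

# Overload declarations: bodies are intentionally empty ("...").
@torch.jit.overload
def add(x: int, y: int) -> int: ...

@torch.jit.overload
def add(x: float, y: float) -> float: ...

# A single implementation follows the declarations.
def add(x, y):
    return x + y

# From TorchScript, calls to `add` are matched against the declared
# overloads (currently the first matching one, as noted above).
@torch.jit.script
def use_add(a: int, b: float) -> float:
    i = add(a, a)   # matches the (int, int) -> int overload
    f = add(b, b)   # matches the (float, float) -> float overload
    return f + float(i)

print(use_add(1, 2.0))  # 6.0
```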


70 files changed: +3169 additions, -1020 deletions

.circleci/scripts/binary_linux_test.sh

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ pkg="/final_pkgs/\$(ls /final_pkgs)"
 if [[ "$PACKAGE_TYPE" == conda ]]; then
   conda install -y "\$pkg" --offline
   if [[ "$DESIRED_CUDA" == 'cpu' ]]; then
-    conda install -y cpu-only -c pytorch
+    conda install -y cpuonly -c pytorch
   fi
   retry conda install -yq future numpy protobuf six
   if [[ "$DESIRED_CUDA" != 'cpu' ]]; then

.circleci/scripts/binary_populate_env.sh

Lines changed: 3 additions & 3 deletions
@@ -53,9 +53,9 @@ fi
 # We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it
 export DATE="$(date -u +%Y%m%d)"
 if [[ "$(uname)" == 'Darwin' ]] || [[ "$DESIRED_CUDA" == "cu100" ]]; then
-  export PYTORCH_BUILD_VERSION="1.2.0.dev$DATE"
+  export PYTORCH_BUILD_VERSION="1.3.0.dev$DATE"
 else
-  export PYTORCH_BUILD_VERSION="1.2.0.dev$DATE+$DESIRED_CUDA"
+  export PYTORCH_BUILD_VERSION="1.3.0.dev$DATE+$DESIRED_CUDA"
 fi
 export PYTORCH_BUILD_NUMBER=1

@@ -72,7 +72,7 @@ export BUILD_PYTHONLESS="${BUILD_PYTHONLESS:-}"
 export DESIRED_DEVTOOLSET="$DESIRED_DEVTOOLSET"

 export DATE="$DATE"
-export NIGHTLIES_DATE_PREAMBLE=1.2.0.dev
+export NIGHTLIES_DATE_PREAMBLE=1.3.0.dev
 export PYTORCH_BUILD_VERSION="$PYTORCH_BUILD_VERSION"
 export PYTORCH_BUILD_NUMBER="$PYTORCH_BUILD_NUMBER"
 export OVERRIDE_PACKAGE_VERSION="$PYTORCH_BUILD_VERSION"

aten/src/ATen/core/jit_type.h

Lines changed: 10 additions & 0 deletions
@@ -1627,6 +1627,16 @@ struct CAFFE2_API ClassType : public NamedType {
   // These variants are not registered in the global class table.
   ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;

+  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
+    auto ptr = ClassType::create(name_, compilation_unit_);
+    AT_ASSERT(numAttributes() == contained_types.size());
+    for(size_t i = 0; i < attributeNames_.size(); ++i) {
+      AT_ASSERT(attributeTypes_[i]->isSubtypeOf(contained_types[i]));
+      ptr->addAttribute(attributeNames_[i], contained_types[i]);
+    }
+    return ptr;
+  }
+
   bool is_module() const {
     return bool(parameterSlots_);
   }

aten/src/ATen/native/QuantizedLinear.cpp

Lines changed: 117 additions & 31 deletions
@@ -25,14 +25,14 @@ namespace caffe2 {
 CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
 CAFFE_KNOWN_TYPE(fbgemm::PackedGemmMatrixFP16);
 #endif // USE_FBGEMM
-}
+} // namespace caffe2

 namespace at {
 namespace native {

 #ifdef USE_FBGEMM

-Tensor fbgemm_linear_int8_weight(
+Tensor fbgemm_linear_int8_weight_fp32_activation(
     const Tensor& input,
     const Tensor& weight,
     const Tensor& packed,
@@ -70,13 +70,14 @@ Tensor fbgemm_linear_int8_weight(
   // Input tensor is quantized as 8-bit unsigned values
   static constexpr int precision = 8;
   static constexpr bool is_signed = false;
+  static constexpr int bound = (1 << (precision - 1));

   // Calculate scale and zero point for quantization of input tensor
   auto q_params = fbgemm::ChooseQuantizationParams(
       /*min=*/x_min,
       /*max=*/x_max,
-      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
-      /*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
+      /*qmin=*/is_signed ? -bound : 0,
+      /*qmax=*/is_signed ? (bound - 1) : (1 << precision) - 1,
       /*preserve_sparsity=*/false);

   q_params.precision = precision;
@@ -119,7 +120,7 @@ Tensor fbgemm_linear_int8_weight(
   // 1) Add in row and column offsets to the rows and columns, respectively
   // 2) Dequantize the results into floating point
   // 3) Add in the bias term
-  fbgemm::ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
+  fbgemm::ReQuantizeForFloat</*FUSE_RELU*/false> outputProcObj(
       /*nextop=*/doNothingObj,
       /*Aq_scale=*/q_params.scale,
       /*Bq_scale=*/&weight_scale_float,
@@ -128,10 +129,11 @@ Tensor fbgemm_linear_int8_weight(
       /*row_offsets=*/packA.getRowOffsetBuffer(),
       /*col_offsets=*/col_offsets.data<int32_t>(),
       /*bias=*/bias_contig.data<float>(),
-      /*ncol=*/N);
+      /*nCol=*/N);

   // Allocate output Tensor and a buffer for fbgemmPacked to use
-  auto output = at::zeros({M, N}, bias.options().dtype(at::kFloat));
+  auto output = at::zeros(
+      {M, N}, bias.options().dtype(at::kFloat));
   auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));

   // Pull out the PackBMatrix instance from the owning tensor
@@ -155,11 +157,33 @@ Tensor fbgemm_linear_int8_weight(
   return output.view(out_sizes);
 }

+Tensor fbgemm_linear_int8_weight(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& packed,
+    const Tensor& col_offsets,
+    Scalar weight_scale,
+    Scalar weight_zero_point,
+    const Tensor& bias) {
+  TORCH_WARN(
+      "fbgemm_linear_int8_weight will be deprecated soon."
+      "Please use fbgemm_linear_int8_weight_fp32_activation instead.");
+
+  return at::native::fbgemm_linear_int8_weight_fp32_activation(
+      input,
+      weight,
+      packed,
+      col_offsets,
+      weight_scale,
+      weight_zero_point,
+      bias);
+}
+
 namespace {
 // Calculate the column offsets
 // Note this includes the sum of the columns as well as the scalar term
-// B_zero_point * K, whereas the row_offsets created by PackAWithQuantRowOffset
-// is only the sum of the A rows.
+// B_zero_point * K, whereas the row_offsets created by
+// PackAWithQuantRowOffset is only the sum of the A rows.
 void calc_col_offsets_transpose(
     int K,
     int N,
@@ -195,11 +219,12 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
   // Choose parameters for quantizing the weight as 8-bit signed integer
   static constexpr bool is_signed = true;
   static constexpr int precision = 8;
+  static constexpr int bound = (1 << (precision - 1));
   auto q_params = fbgemm::ChooseQuantizationParams(
       /*min=*/w_min,
       /*max=*/w_max,
-      /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
-      /*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
+      /*qmin=*/is_signed ? -bound : 0,
+      /*qmax=*/is_signed ? (bound - 1) : (1 << precision) - 1,
       /*preserve_sparsity=*/false);

   q_params.precision = precision;
@@ -230,14 +255,13 @@ bool fbgemm_is_cpu_supported() {
   return fbgemm::fbgemmSupportedCPU();
 }

-Tensor fbgemm_pack_quantized_matrix(
-    const Tensor& weight,
-    int64_t K,
-    int64_t N) {
+Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) {
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
   TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
+  int64_t K = weight.size(1);
+  int64_t N = weight.size(0);
   auto weight_contig = weight.contiguous();
   auto contiguous_ptr = weight_contig.data<int8_t>();
   auto ptr = guts::make_unique<fbgemm::PackBMatrix<int8_t>>(
@@ -251,8 +275,18 @@ Tensor fbgemm_pack_quantized_matrix(
   return cpp_custom_type_hack::create(std::move(ptr), weight.options());
 }

-float raw_uint16_to_fp16(unsigned short value)
-{
+Tensor fbgemm_pack_quantized_matrix(
+    const Tensor& weight,
+    int64_t K,
+    int64_t N) {
+  TORCH_WARN(
+      "fbgemm_pack_quantized_matrix(weight, K, N) will be deprecated soon."
+      "Please use fbgemm_pack_quantized_matrix(weight) instead.");
+
+  return at::native::fbgemm_pack_quantized_matrix(weight);
+}
+
+float raw_uint16_to_fp16(unsigned short value) {
   // Convert raw 16 bits half precision floating point number
   // to single precision floating point number.
   unsigned short sign_bits = value >> 15;
@@ -284,7 +318,7 @@ bool check_and_saturate(T* element, T MAX) {
 // number will be saturated to max or min representable values by FP16.
 void handle_weights_saturation(float* weight, int64_t length) {
   float FP16_MAX = raw_uint16_to_fp16(0x7BFF);
-  bool found_out_of_range = false;
+  bool found_out_of_range = false;

   for (int i = 0; i < length; ++i) {
     if (check_and_saturate<float>(&weight[i], FP16_MAX)) {
@@ -297,8 +331,7 @@ void handle_weights_saturation(float* weight, int64_t length) {
   }
 }

-Tensor fbgemm_pack_gemm_matrix_fp16(
-    const Tensor& weight ) {
+Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -309,7 +342,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(
   Tensor weight_contig = weight.contiguous();
   auto weight_contig_ptr = weight_contig.data<float>();

-  handle_weights_saturation(weight_contig_ptr, K*N);
+  handle_weights_saturation(weight_contig_ptr, K * N);

   // TODO(mingzhe09088):
   // Consider using a functor here in PackedGemmMatrixFP16
@@ -319,15 +352,11 @@ Tensor fbgemm_pack_gemm_matrix_fp16(
   // within this translation unit. It might be very problematic if that tensor
   // flows across dll boundaries.
   auto ptr = guts::make_unique<fbgemm::PackedGemmMatrixFP16>(
-      fbgemm::matrix_op_t::Transpose,
-      K,
-      N,
-      1,
-      weight_contig_ptr);
+      fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr);
   return cpp_custom_type_hack::create(std::move(ptr), weight.options());
 }

-Tensor fbgemm_linear_fp16_weight(
+Tensor fbgemm_linear_fp16_weight_fp32_activation(
     const Tensor& input,
     const Tensor& packed_weight,
     const Tensor& bias) {
@@ -358,7 +387,7 @@ Tensor fbgemm_linear_fp16_weight(
       M,
       input_ptr,
       packed_weight_fp16,
-      0.f,
+      0.0f,
       output.data<float>());

   // Add bias term
@@ -369,8 +398,35 @@ Tensor fbgemm_linear_fp16_weight(
   return output.view(out_sizes);
 }

+Tensor fbgemm_linear_fp16_weight(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const Tensor& bias) {
+  TORCH_WARN(
+      "fbgemm_linear_fp16_weight will be deprecated soon."
+      "Please use fbgemm_linear_fp16_weight_fp32_activation instead.");
+
+  return at::native::fbgemm_linear_fp16_weight_fp32_activation(
+      input, packed_weight, bias);
+}
+
 #else // USE_FBGEMM

+Tensor fbgemm_linear_int8_weight_fp32_activation(
+    const Tensor& /*input*/,
+    const Tensor& /*weight*/,
+    const Tensor& /*packed*/,
+    const Tensor& /*col_offsets*/,
+    Scalar /*weight_scale*/,
+    Scalar /*weight_zero_point*/,
+    const Tensor& /*bias*/) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
 Tensor fbgemm_linear_int8_weight(
     const Tensor& /*input*/,
     const Tensor& /*weight*/,
@@ -379,6 +435,10 @@ Tensor fbgemm_linear_int8_weight(
     Scalar /*weight_scale*/,
     Scalar /*weight_zero_point*/,
     const Tensor& /*bias*/) {
+  TORCH_WARN(
+      "fbgemm_linear_int8_weight will be deprecated soon."
+      "Please use fbgemm_linear_int8_weight_fp32_activation instead.");
+
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -395,19 +455,41 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
       false, "This PyTorch installation was not built with FBGEMM operators");
 }

+Tensor fbgemm_pack_quantized_matrix(const Tensor& /*input*/) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
 Tensor fbgemm_pack_quantized_matrix(
     const Tensor& /*input*/,
     int64_t /*K*/,
     int64_t /*N*/) {
+  TORCH_WARN(
+      "fbgemm_pack_quantized_matrix(weight, K, N) will be deprecated soon."
+      "Please use fbgemm_pack_quantized_matrix(weight) instead.");
+
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
   TORCH_CHECK(
       false, "This PyTorch installation was not built with FBGEMM operators");
 }

-Tensor fbgemm_pack_gemm_matrix_fp16(
-    const Tensor& weight) {
+Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}
+
+Tensor fbgemm_linear_fp16_weight_fp32_activation(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const Tensor& bias) {
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -419,6 +501,10 @@ Tensor fbgemm_linear_fp16_weight(
     const Tensor& input,
     const Tensor& packed_weight,
     const Tensor& bias) {
+  TORCH_WARN(
+      "fbgemm_linear_fp16_weight will be deprecated soon."
+      "Please use fbgemm_linear_fp16_weight_fp32_activation instead.");
+
   // We make a strong guarantee that models using these operators will have the
   // same numerics across different machines. Therefore, we do not provide a
   // fallback path and rather fail loudly if we cannot run FBGEMM.
@@ -431,5 +517,5 @@ bool fbgemm_is_cpu_supported() {
 }

 #endif // USE_FBGEMM
-}
+} // namespace native
 } // namespace at
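For readers following the deprecation warnings in this file, here is a rough sketch of the intended call-site migration. It assumes these ATen ops are exposed under the `torch.*` namespace, as they are in this era of PyTorch; the tensor shapes are purely illustrative:

```python
import torch

if torch.fbgemm_is_cpu_supported():
    # int8 weight with N = 4 output rows and K = 8 input columns.
    weight = torch.randint(-128, 127, (4, 8), dtype=torch.int8)

    # New single-argument form: K and N are inferred from the weight shape
    # (K = weight.size(1), N = weight.size(0)); the old (weight, K, N)
    # overload now only warns and forwards here.
    packed = torch.fbgemm_pack_quantized_matrix(weight)

    # Deprecated: torch.fbgemm_linear_int8_weight(input, weight, packed,
    #                 col_offsets, weight_scale, weight_zero_point, bias)
    # Preferred:  torch.fbgemm_linear_int8_weight_fp32_activation(...)
    #             with the same argument list.
```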

aten/src/ATen/native/RNN.cpp

Lines changed: 2 additions & 2 deletions
@@ -129,11 +129,11 @@ struct QuantizedCellParams {
     TORCH_CHECK(false, "matmul is not supported with quantized cell params");
   }
   Tensor linear_ih(Tensor input) const {
-    return at::fbgemm_linear_int8_weight(
+    return at::fbgemm_linear_int8_weight_fp32_activation(
         input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih);
   }
   Tensor linear_hh(Tensor h) const {
-    return at::fbgemm_linear_int8_weight(
+    return at::fbgemm_linear_int8_weight_fp32_activation(
         h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh);
   }
 };

aten/src/ATen/native/TensorShape.cpp

Lines changed: 10 additions & 0 deletions
@@ -403,13 +403,23 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) {
   std::vector<int64_t> padded_size(num_new_dimensions, 1);
   padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end());
   std::vector<int64_t> target_size(repeats.size());
+  bool zero_tensor = false;
   for(size_t idx = 0; idx < repeats.size(); ++idx) {
+    if (repeats[idx] == 0) {
+      zero_tensor = true;
+    }
     target_size[idx] = padded_size[idx] * repeats[idx];
   }

   Tensor xtensor = self.expand(padded_size);

   Tensor result = at::empty(target_size, self.options());
+
+  // return an empty tensor if one of the repeat dimensions is zero
+  if (zero_tensor) {
+    return result;
+  }
+
   Tensor urtensor = at::alias(result);
   for (int64_t i = 0; i < xtensor.dim(); ++i) {
     // can't unfold with step 0, so make sure step is at least 1
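To illustrate the `repeat` change above at the Python level (the tensors here are only an example):

```python
import torch

x = torch.arange(3.0)          # shape (3,)
print(x.repeat(2).shape)       # torch.Size([6])

# A zero anywhere in `repeats` now takes the early-return path above and
# yields an empty tensor with the matching target shape.
print(x.repeat(2, 0).shape)    # torch.Size([2, 0])
```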
