Skip to content

Commit c7454d5

Browse files
author
root
committed
Update on "extend torch.jit._overload to module methods"
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is:

```
@torch.jit.overload
def add(self, y: int) -> int: ...

@torch.jit.overload
def add(self, y: float) -> float: ...

def add():
    ...
```

Differential Revision: [D16921304](https://our.internmc.facebook.com/intern/diff/D16921304)
2 parents fe5b6b0 + aed306d commit c7454d5

File tree

361 files changed

+8234
-3763
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

361 files changed

+8234
-3763
lines changed

.circleci/scripts/binary_populate_env.sh

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@ if [[ "$PACKAGE_TYPE" == 'libtorch' ]]; then
2929
fi
3030

3131
# Pick docker image
32-
if [[ "$PACKAGE_TYPE" == conda ]]; then
33-
export DOCKER_IMAGE="soumith/conda-cuda"
34-
elif [[ "$DESIRED_CUDA" == cpu ]]; then
35-
export DOCKER_IMAGE="soumith/manylinux-cuda100"
36-
else
37-
export DOCKER_IMAGE="soumith/manylinux-cuda${DESIRED_CUDA:2}"
32+
export DOCKER_IMAGE=${DOCKER_IMAGE:-}
33+
if [[ -z "$DOCKER_IMAGE" ]]; then
34+
if [[ "$PACKAGE_TYPE" == conda ]]; then
35+
export DOCKER_IMAGE="soumith/conda-cuda"
36+
elif [[ "$DESIRED_CUDA" == cpu ]]; then
37+
export DOCKER_IMAGE="soumith/manylinux-cuda100"
38+
else
39+
export DOCKER_IMAGE="soumith/manylinux-cuda${DESIRED_CUDA:2}"
40+
fi
3841
fi
3942

4043
# Upload to parallel folder for gcc abis

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ max-line-length = 120
55
# E501 is not flexible enough, we're using B950 instead
66
ignore =
77
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
8+
# EXE001 is skipped for now because some files use shebang to determine Python version.
9+
EXE001,
810
# these ignores are from flake8-bugbear; please fix!
911
B007,B008,
1012
# these ignores are from flake8-comprehensions; please fix!

.github/pytorch-probot.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tracking_issue: 24422

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
## PyTorch
1010

11+
.coverage
12+
.hypothesis
1113
.mypy_cache
1214
*/*.pyc
1315
*/*.so*
@@ -27,6 +29,7 @@ dist/
2729
docs/src/**/*
2830
docs/cpp/build
2931
docs/cpp/source/api
32+
log
3033
test/.coverage
3134
test/.hypothesis/
3235
test/cpp/api/mnist

.jenkins/pytorch/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ test_python_nn() {
104104
}
105105

106106
test_python_all_except_nn() {
107-
time python test/run_test.py --exclude nn --verbose
107+
time python test/run_test.py --exclude nn --verbose --bring-to-front quantization quantized_conv quantized quantized_tensor quantized_nn_mods quantizer
108108
assert_git_not_dirty
109109
}
110110

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ if (MSVC)
188188
string(APPEND CMAKE_CXX_FLAGS " /EHa")
189189
if(MSVC_Z7_OVERRIDE)
190190
foreach(flag_var
191+
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
192+
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO
191193
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
192194
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
193195
if(${flag_var} MATCHES "/Z[iI]")

aten/src/ATen/ATen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
#include <ATen/Tensor.h>
1818
#include <ATen/TensorGeometry.h>
1919
#include <ATen/TensorOperators.h>
20-
#include <ATen/TensorOptions.h>
2120
#include <ATen/Version.h>
2221
#include <ATen/core/ATenGeneral.h>
2322
#include <ATen/core/Generator.h>
2423
#include <c10/core/Layout.h>
2524
#include <ATen/core/Scalar.h>
2625
#include <c10/core/Storage.h>
26+
#include <c10/core/TensorOptions.h>
2727
#include <ATen/core/Reduction.h>
2828
#include <c10/util/Exception.h>
2929
#include <ATen/core/ATenDispatch.h>

aten/src/ATen/Context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include <ATen/Context.h>
44

5-
#include <ATen/core/TensorOptions.h>
5+
#include <c10/core/TensorOptions.h>
66

77
#include <thread>
88
#include <mutex>

aten/src/ATen/DLConvertor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
using namespace std;
88
namespace at {
99

10-
static DLDataType getDLDataType(const Tensor& t) {
10+
DLDataType getDLDataType(const Tensor& t) {
1111
DLDataType dtype;
1212
dtype.lanes = 1;
1313
dtype.bits = t.element_size() * 8;
@@ -65,7 +65,7 @@ static DLDataType getDLDataType(const Tensor& t) {
6565
return dtype;
6666
}
6767

68-
static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
68+
DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
6969
DLContext ctx;
7070
ctx.device_id = device_id;
7171
if (tensor.is_cuda()) {

aten/src/ATen/DLConvertor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,7 @@ namespace at {
1313
CAFFE2_API ScalarType toScalarType(const DLDataType& dtype);
1414
CAFFE2_API DLManagedTensor* toDLPack(const Tensor& src);
1515
CAFFE2_API Tensor fromDLPack(const DLManagedTensor* src);
16+
CAFFE2_API DLDataType getDLDataType(const Tensor& t);
17+
CAFFE2_API DLContext getDLContext(const Tensor& tensor, const int64_t& device_id);
1618

1719
} //namespace at

0 commit comments

Comments (0)