Skip to content

Commit 4470b57

Browse files
committed
Update on "Delete torch::deploy from pytorch core"
As we have migrated torch::deploy over to https://github.com/pytorch/multipy, we can now delete it from pytorch core, as ongoing development will happen there. This PR was created due to syncing issues with #85443, which is where the review history can be found. [ghstack-poisoned]
2 parents 4784a3b + 7fe0c7c commit 4470b57

40 files changed

+1586
-98
lines changed

.github/auto_request_review.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ reviewers:
99
- albanD
1010
- Krovatkin
1111
- miladm
12+
- bdhirsh
1213

1314
per_author:
1415
symbolic-shapes:

.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
f2b36df6a1a80137eff7644e6d0f4eeb7ff429d6
1+
d6b1971d4f40364cd6f2d23d047818d295971f7a

aten/src/ATen/ThreadLocalState.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ void ThreadLocalState::set_grad_mode(bool enabled) {
2727
autograd_tls_.set_grad_mode(enabled);
2828
}
2929

30+
void ThreadLocalState::set_multithreading_enabled(bool enabled) {
31+
autograd_tls_.set_multithreading_enabled(enabled);
32+
}
33+
3034
/* static */
3135
void ThreadLocalState::setThreadLocalState(
3236
const ThreadLocalState& state) {

aten/src/ATen/ThreadLocalState.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ class TORCH_API ThreadLocalState {
3030
// autograd engine.
3131
void set_grad_mode(bool enabled);
3232

33+
// set_multithreading_enabled - force the value of the multithreading
34+
// TLS in
35+
// the current state object. This is used for example in the
36+
// autograd engine.
37+
void set_multithreading_enabled(bool enabled);
38+
3339
// Sets thread local variables in the current thread,
3440
// according to the thread boundary specified
3541
static void setThreadLocalState(const ThreadLocalState& state);

aten/src/ATen/native/IndexingUtils.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,18 @@ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTen
4848
return result;
4949
}
5050

51-
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices) {
51+
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
5252
for (const auto& tensor : indices) {
5353
if (tensor.has_value() && tensor->defined()) {
5454
auto scalarType = tensor->scalar_type();
55-
if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
56-
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
55+
if (allow_int) {
56+
if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
57+
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
58+
}
59+
} else {
60+
if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
61+
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
62+
}
5763
}
5864
}
5965
}

aten/src/ATen/native/MetaTensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Tensor empty_strided_meta(
3838
c10::optional<Device> device_opt,
3939
c10::optional<bool> pin_memory_opt
4040
) {
41-
return empty_strided_meta_symint(c10::fromIntArrayRef(size), c10::fromIntArrayRef(stride), dtype_opt, layout_opt, device_opt, pin_memory_opt);
41+
return empty_strided_meta_symint(c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype_opt, layout_opt, device_opt, pin_memory_opt);
4242
}
4343

4444
Tensor empty_strided_meta_symint(

aten/src/ATen/native/TensorAdvancedIndexingUtils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ const Tensor& value){
5757
}
5858

5959
static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
60-
checkIndexTensorTypes(orig);
60+
checkIndexTensorTypes(orig, /*allow_int*/ true);
6161
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
6262
auto indices = expandTensors(self, orig);
6363
// next broadcast all index tensors together
@@ -82,6 +82,12 @@ static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
8282
indice = indice.to(self.device());
8383
}
8484
}
85+
for (auto & indice : indices) {
86+
if (indice.defined() && indice.dtype() == at::kInt) {
87+
indice = indice.to(at::kLong);
88+
}
89+
}
90+
8591
return AdvancedIndex(self, indices);
8692
}
8793

aten/src/ATen/native/cpu/HistogramKernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges,
148148
for (const auto dim : c10::irange(D)) {
149149
const input_t elt = accessor_in[i][dim];
150150

151-
// Skips elements which fall outside the specified bins
152-
if (elt < leftmost_edge[dim] || rightmost_edge[dim] < elt) {
151+
// Skips elements which fall outside the specified bins and NaN elements
152+
if (!(elt >= leftmost_edge[dim] && elt <= rightmost_edge[dim])) {
153153
skip_elt = true;
154154
break;
155155
}

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,14 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
291291

292292

293293
static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
294-
checkIndexTensorTypes(orig);
294+
checkIndexTensorTypes(orig, /*allow_int*/true);
295295
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
296296
auto indices = expandTensors(self, orig);
297+
for (auto & i : indices) {
298+
if (i.defined() && i.dtype() == at::kInt) {
299+
i = i.to(at::kLong);
300+
}
301+
}
297302
// next broadcast all index tensors together
298303
indices = expand_outplace(indices);
299304
// add missing null Tensors so that it matches self.dim()

aten/src/ATen/native/metal/ops/MetalReshape.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) {
6464

6565
Tensor reshape(const Tensor& input, IntArrayRef shape) {
6666
TORCH_CHECK(input.is_metal());
67-
return view(input, c10::fromIntArrayRef(shape));
67+
return view(input, c10::fromIntArrayRefSlow(shape));
6868
}
6969

7070
Tensor flatten_using_ints(

0 commit comments

Comments (0)