Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Tensor& arange_out(Tensor& result, Scalar end) {
return at::_arange_out(result, end);
}

Tensor _dim_arange(const Tensor& like, int64_t dim) {
return like.type().toScalarType(at::kLong)._arange(like.size(dim));
}

Tensor empty(const Type& dtype, IntList size) {
return dtype.tensor(size);
}
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@
- func: arange_out(Tensor result, Scalar end) -> Tensor
variants: function

# This function is a temporary hack to allow tracing of arange like constructs with dynamic
# bounds on arange. Normal arange is not traceable because it does not take any tensor inputs;
# if the range you need is based on another tensor, calling this function directly will
# preserve tracing. Get rid of this when arange can directly take tensors for bounds
# (so that it can be traced directly).
- func: _dim_arange(Tensor like, int64_t dim) -> Tensor
variants: function

# `argmin` and `argmax` are exposed in C++ but not in Python, where we only
# expose `_argmin` and `_argmax` (which call the first versions). In Python, we
# then define our own `argmax` and `argmin` that handle passing `dim=None`,
Expand Down
4 changes: 4 additions & 0 deletions torch/onnx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,3 +998,7 @@ def retrieve_state(x, start, end):
return prev_output, h_outs, c_outs

return symbolic


def _dim_arange(g, like, dim):
return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange')