8 changes: 4 additions & 4 deletions aten/src/ATen/core/OpsAlreadyMovedToC10.cpp
@@ -119,7 +119,7 @@ const std::unordered_set<c10::OperatorName>& aten_ops_already_moved_to_c10() {
{"aten::expm1", ""},
{"aten::expand", ""},
{"aten::expand_as", ""},
{"aten::flatten", ""},
{"aten::flatten", "using_ints"},
{"aten::fill_", "Scalar"},
{"aten::fill_", "Tensor"},
{"aten::floor", ""},
@@ -730,13 +730,13 @@ const std::unordered_set<c10::OperatorName>& aten_ops_not_moved_to_c10_yet() {
{"aten::eye", "out"},
{"aten::eye", "m_out"},
#ifdef BUILD_NAMEDTENSOR
{"aten::flatten", ""},
{"aten::flatten", "named_out_dim"},
#endif
#ifdef BUILD_NAMEDTENSOR
{"aten::flatten", ""},
{"aten::flatten", "using_names"},
#endif
#ifdef BUILD_NAMEDTENSOR
{"aten::flatten", ""},
{"aten::flatten", "DimnameList"},
#endif
{"aten::floor_", ""},
{"aten::floor", "out"},
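These tables are keyed by c10::OperatorName, whose equality and hash cover both the operator name and the overload name, so once an overload is named the old key with an empty overload string no longer matches. A minimal sketch of that behaviour (the demo set and main() are illustrative, not part of this file; the operator_name.h include and its std::hash specialization are assumptions about the surrounding headers):

    #include <ATen/core/operator_name.h>  // c10::OperatorName (assumed to also provide std::hash for it)
    #include <unordered_set>
    #include <iostream>

    int main() {
      std::unordered_set<c10::OperatorName> ops;
      ops.insert(c10::OperatorName("aten::flatten", "using_ints"));
      // The overload name participates in the key, so the old empty-overload entry
      // and the new named entry are distinct keys in the set.
      std::cout << ops.count(c10::OperatorName("aten::flatten", "")) << "\n";           // 0
      std::cout << ops.count(c10::OperatorName("aten::flatten", "using_ints")) << "\n"; // 1
      return 0;
    }
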
8 changes: 4 additions & 4 deletions aten/src/ATen/core/TensorMethods.h
@@ -966,7 +966,7 @@ inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const {
#ifdef USE_STATIC_DISPATCH
return TypeDefault::flatten(const_cast<Tensor&>(*this), start_dim, end_dim);
#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::flatten", ""}).value();
+ static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::flatten", "using_ints"}).value();
return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(type_set()))
.callUnboxed<Tensor, const Tensor &, int64_t, int64_t>(const_cast<Tensor&>(*this), start_dim, end_dim);
#endif
@@ -976,7 +976,7 @@ inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, Dimname out_dim) const {
#ifdef USE_STATIC_DISPATCH
return TypeDefault::flatten(const_cast<Tensor&>(*this), start_dim, end_dim, out_dim);
#else
- static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor");
+ static auto table = globalATenDispatch().getOpTable("aten::flatten.named_out_dim(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor");
return table->getOp<Tensor (const Tensor &, int64_t, int64_t, Dimname)>(type_set())(const_cast<Tensor&>(*this), start_dim, end_dim, out_dim);
#endif
}
@@ -986,7 +986,7 @@ inline Tensor Tensor::flatten(Dimname start_dim, Dimname end_dim, Dimname out_dim) const {
#ifdef USE_STATIC_DISPATCH
return TypeDefault::flatten(const_cast<Tensor&>(*this), start_dim, end_dim, out_dim);
#else
- static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor");
+ static auto table = globalATenDispatch().getOpTable("aten::flatten.using_names(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor");
return table->getOp<Tensor (const Tensor &, Dimname, Dimname, Dimname)>(type_set())(const_cast<Tensor&>(*this), start_dim, end_dim, out_dim);
#endif
}
@@ -996,7 +996,7 @@ inline Tensor Tensor::flatten(DimnameList dims, Dimname out_dim) const {
#ifdef USE_STATIC_DISPATCH
return TypeDefault::flatten(const_cast<Tensor&>(*this), dims, out_dim);
#else
- static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor");
+ static auto table = globalATenDispatch().getOpTable("aten::flatten.DimnameList(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor");
return table->getOp<Tensor (const Tensor &, DimnameList, Dimname)>(type_set())(const_cast<Tensor&>(*this), dims, out_dim);
#endif
}
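Two lookup paths are touched here: the c10 dispatcher resolves operators by a {name, overload_name} pair via findSchema, while the legacy globalATenDispatch() table is keyed by the full schema string, where the overload appears as a dot suffix (aten::flatten.named_out_dim, and so on). A minimal sketch of the dispatcher-side lookup, assuming a linked ATen/libtorch build; the standalone main() is illustrative, the generated code above does the equivalent internally:

    #include <ATen/core/dispatch/Dispatcher.h>
    #include <iostream>

    int main() {
      auto& dispatcher = c10::Dispatcher::singleton();
      // findSchema now needs the non-empty overload name; {"aten::flatten", ""}
      // would no longer resolve to the int-based overload.
      auto op = dispatcher.findSchema({"aten::flatten", "using_ints"});
      std::cout << (op.has_value() ? "found aten::flatten.using_ints"
                                   : "schema not registered") << "\n";
      return 0;
    }
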
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -1060,20 +1060,20 @@
CPU: eye_out_cpu
CUDA: eye_out_cuda

- - func: flatten(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor
+ - func: flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor
use_c10_dispatcher: True
variants: function, method
named_guard: False

- - func: flatten(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor
+ - func: flatten.named_out_dim(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor
variants: function, method
named_guard: False

- - func: flatten(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor
+ - func: flatten.using_names(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor
variants: function, method
named_guard: False

- - func: flatten(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor
+ - func: flatten.DimnameList(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor
variants: function, method
named_guard: False

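In native_functions.yaml the dot suffix (flatten.using_ints, flatten.named_out_dim, flatten.using_names, flatten.DimnameList) becomes the overload name of the registered schema; previously all four flatten signatures registered under an empty overload name and could not be told apart by the c10 dispatcher. Call sites are unaffected, since C++ overload resolution on Tensor::flatten still selects the matching schema. A small usage sketch, assuming a linked libtorch (the main() and tensor shape are illustrative):

    #include <ATen/ATen.h>

    int main() {
      at::Tensor t = at::rand({2, 3, 4});
      // Plain integer arguments resolve to the schema now registered as
      // aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor.
      at::Tensor flat = t.flatten(0, -1);
      return flat.dim() == 1 ? 0 : 1;  // expect a 1-D tensor of 24 elements
    }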