Skip to content

Commit 0494e0a

Browse files
smessmerfacebook-github-bot
authored andcommitted
Back out "Revert D21581908: Move TensorOptions ops to c10" (#40595)
Summary: Pull Request resolved: #40595 ghstack-source-id: 106691774 Test Plan: waitforsandcastle Differential Revision: D22247729 fbshipit-source-id: 14745588cae267c1e0cc51cd9541a9b8abb830e5
1 parent b8f4f68 commit 0494e0a

26 files changed

+546
-64
lines changed

aten/src/ATen/common_with_cwrap.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# this code should be common among cwrap and ATen preprocessing
22
# for now, I have put it in one place but right now is copied out of cwrap
33

4+
import copy
45

56
def parse_arguments(args):
67
new_args = []
@@ -50,11 +51,16 @@ def set_declaration_defaults(declaration):
5051
declaration['unqual_operator_name_with_overload'] = ''
5152
# Simulate multiple dispatch, even if it's not necessary
5253
if 'options' not in declaration:
53-
declaration['options'] = [{'arguments': declaration['arguments']}]
54+
declaration['options'] = [{
55+
'arguments': copy.deepcopy(declaration['arguments']),
56+
'schema_order_arguments': copy.deepcopy(declaration['schema_order_arguments']),
57+
}]
5458
del declaration['arguments']
59+
del declaration['schema_order_arguments']
5560
# Parse arguments (some of them can be strings)
5661
for option in declaration['options']:
5762
option['arguments'] = parse_arguments(option['arguments'])
63+
option['schema_order_arguments'] = parse_arguments(option['schema_order_arguments'])
5864
# Propagate defaults from declaration to options
5965
for option in declaration['options']:
6066
for k, v in declaration.items():

aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ using supported_primitive_arg_types = guts::typelist::typelist<
4242
at::Tensor,
4343
at::Scalar,
4444
c10::QScheme,
45-
c10::ScalarType
45+
c10::ScalarType,
46+
c10::Device,
47+
c10::Layout,
48+
c10::MemoryFormat
4649
>;
4750

4851
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {

aten/src/ATen/core/jit_type.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,9 +1517,12 @@ namespace detail {
15171517
template <typename T>
15181518
struct getTypePtr_ final {
15191519
static TypePtr call() {
1520-
if (!isCustomClassRegistered<T>()) {
1521-
throw c10::Error("Type could not be converted to any of the known types.", "");
1522-
}
1520+
TORCH_CHECK(
1521+
isCustomClassRegistered<T>(),
1522+
"Type ",
1523+
c10::util::get_fully_qualified_type_name<T>(),
1524+
" could not be converted to any of the known types."
1525+
);
15231526
auto res = getCustomClassType<T>();
15241527
return std::dynamic_pointer_cast<Type>(std::move(res));
15251528
}
@@ -1557,6 +1560,24 @@ struct getTypePtr_<c10::ScalarType> final {
15571560
}
15581561
};
15591562
template <>
1563+
struct getTypePtr_<c10::Device> final {
1564+
static TypePtr call() {
1565+
return DeviceObjType::get();
1566+
}
1567+
};
1568+
template <>
1569+
struct getTypePtr_<c10::Layout> final {
1570+
static TypePtr call() {
1571+
return IntType::get();
1572+
}
1573+
};
1574+
template <>
1575+
struct getTypePtr_<c10::MemoryFormat> final {
1576+
static TypePtr call() {
1577+
return IntType::get();
1578+
}
1579+
};
1580+
template <>
15601581
struct getTypePtr_<bool> final {
15611582
static TypePtr call() {
15621583
return BoolType::get();
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#pragma once
2+
3+
#include <c10/util/Metaprogramming.h>
4+
#include <c10/util/TypeList.h>
5+
#include <c10/core/TensorOptions.h>
6+
#include <c10/core/CompileTimeFunctionPointer.h>
7+
8+
// This file defines hacky_wrapper_for_legacy_signatures, which takes a kernel written in a legacy way
9+
// (e.g. with TensorOptions packed) and wraps it into a kernel with the signature expected by
10+
// the PyTorch operator library. The intention is to ultimately rewrite kernels to take the new signature
11+
// and then delete this file. This transition process can happen kernel-by-kernel, since this wrapper
12+
// is a no-op for kernels that already have a non-legacy signature.
13+
14+
namespace c10 {
15+
namespace impl {
16+
17+
inline c10::optional<MemoryFormat> process_memory_format(const TensorOptions& options, c10::optional<MemoryFormat> memory_format) {
18+
TORCH_CHECK(
19+
!(options.has_memory_format() && memory_format.has_value()),
20+
"Cannot set memory_format both in TensorOptions and explicit argument; please delete "
21+
"the redundant setter.");
22+
if (memory_format.has_value()) {
23+
return memory_format;
24+
} else {
25+
return options.memory_format_opt();
26+
}
27+
}
28+
29+
namespace detail {
30+
31+
// with_scattered_tensor_options takes a function pointer that potentially takes a TensorOptions argument.
32+
// If it does, then it creates a new function pointer that takes scattered arguments, internally
33+
// gathers those arguments, and then calls the underlying function pointer. If the underlying
34+
// function pointer does not take a TensorOptions argument, it is passed through unmodified.
35+
36+
template<class Type, class Enable = void> struct is_tensoroptions_arg : std::false_type {};
37+
template<class Type> struct is_tensoroptions_arg<Type, std::enable_if_t<std::is_same<TensorOptions, std::decay_t<Type>>::value>> : std::true_type {};
38+
template<class Type>
39+
using is_tensoroptions_arg_t = typename is_tensoroptions_arg<Type>::type;
40+
41+
template<class FuncType>
42+
inline constexpr bool has_tensoroptions_arg() {
43+
using parameter_types = typename guts::infer_function_traits_t<FuncType>::parameter_types;
44+
constexpr size_t num_tensoroptions_args = guts::typelist::count_if<is_tensoroptions_arg_t, parameter_types>::value;
45+
static_assert(num_tensoroptions_args <= 1, "Function has multiple TensorOptions parameters. We support at most one.");
46+
return num_tensoroptions_args > 0;
47+
}
48+
49+
// sanity checks
50+
static_assert(has_tensoroptions_arg<int (int64_t, const TensorOptions&)>(), "");
51+
static_assert(has_tensoroptions_arg<int (int64_t, TensorOptions)>(), "");
52+
static_assert(!has_tensoroptions_arg<int (int64_t, std::string)>(), "");
53+
54+
template<class FuncPtr, class ParametersBeforeTensorOptions, class ParametersAfterTensorOptions> struct with_scattered_tensor_options_;
55+
56+
template<class FuncPtr, class Enable = void>
57+
struct with_scattered_tensor_options final {};
58+
59+
template<class UnderlyingFuncPtr>
60+
struct with_scattered_tensor_options<UnderlyingFuncPtr, std::enable_if_t<!has_tensoroptions_arg<typename UnderlyingFuncPtr::FuncType>()>> final {
61+
// FuncType does not have TensorOptions arguments.
62+
// Don't wrap anything but just return the base pointer.
63+
using FuncPtr = UnderlyingFuncPtr;
64+
};
65+
66+
template<class UnderlyingFuncPtr>
67+
struct with_scattered_tensor_options<UnderlyingFuncPtr, std::enable_if_t<has_tensoroptions_arg<typename UnderlyingFuncPtr::FuncType>()>> final {
68+
private:
69+
// FuncType has TensorOptions arguments.
70+
// Return a function pointer to a wrapper function that replaces those with expanded arguments.
71+
using gathered_parameter_types = typename guts::infer_function_traits_t<typename UnderlyingFuncPtr::FuncType>::parameter_types;
72+
static constexpr size_t tensoroptions_arg_index =
73+
guts::typelist::find_if<
74+
gathered_parameter_types,
75+
is_tensoroptions_arg_t
76+
>::value;
77+
78+
using parameters_before_tensoroptions =
79+
guts::typelist::take_t<gathered_parameter_types, tensoroptions_arg_index>;
80+
using parameters_after_tensoroptions =
81+
guts::typelist::drop_t<gathered_parameter_types, tensoroptions_arg_index + 1>;
82+
83+
using wrapper = with_scattered_tensor_options_<UnderlyingFuncPtr, parameters_before_tensoroptions, parameters_after_tensoroptions>;
84+
public:
85+
using FuncPtr = TORCH_FN_TYPE(&wrapper::wrapper);
86+
};
87+
88+
template<class FuncPtr, class... ParametersBeforeTensorOptions, class... ParametersAfterTensorOptions>
89+
struct with_scattered_tensor_options_<FuncPtr, guts::typelist::typelist<ParametersBeforeTensorOptions...>, guts::typelist::typelist<ParametersAfterTensorOptions...>> final {
90+
static decltype(auto) wrapper(
91+
ParametersBeforeTensorOptions... parameters_before,
92+
optional<ScalarType> scalar_type,
93+
optional<Layout> layout,
94+
optional<Device> device,
95+
optional<bool> pin_memory,
96+
ParametersAfterTensorOptions... parameters_after) {
97+
return (*FuncPtr::func_ptr())(
98+
std::forward<ParametersBeforeTensorOptions>(parameters_before)...,
99+
TensorOptions().dtype(scalar_type).device(device).layout(layout).pinned_memory(pin_memory),
100+
std::forward<ParametersAfterTensorOptions>(parameters_after)...
101+
);
102+
}
103+
};
104+
105+
}
106+
107+
template<class FuncPtr>
108+
constexpr auto hacky_wrapper_for_legacy_signatures(FuncPtr) {
109+
return typename detail::with_scattered_tensor_options<FuncPtr>::FuncPtr();
110+
};
111+
112+
}
113+
}

aten/src/ATen/cwrap_parser.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import yaml
2+
import copy
3+
24
try:
35
# use faster C loader if available
46
from yaml import CLoader as Loader
@@ -24,4 +26,13 @@ def parse(filename):
2426
declarations.append(declaration)
2527
elif in_declaration:
2628
declaration_lines.append(line)
29+
declarations = [process_declaration(declaration) for declaration in declarations]
2730
return declarations
31+
32+
def process_declaration(declaration):
33+
declaration = copy.deepcopy(declaration)
34+
if "arguments" in declaration:
35+
declaration["schema_order_arguments"] = copy.deepcopy(declaration["arguments"])
36+
if "options" in declaration:
37+
declaration["options"] = [process_declaration(option) for option in declaration["options"]]
38+
return declaration

0 commit comments

Comments
 (0)