Skip to content

Commit e84ac97

Browse files
committed
Update on "Simplify copy kernel"
Using the new type promotion and dynamic casting added to `TensorIterator`, the copy kernels could be greatly simplified. For benchmarks, see #28352 (comment) [ghstack-poisoned]
2 parents 0cf634c + 66826d1 commit e84ac97

File tree

112 files changed

+1736
-1178
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

112 files changed

+1736
-1178
lines changed

.jenkins/caffe2/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
135135
# default pip version is too old(9.0.2), unable to support tag `manylinux2010`.
136136
# Fix the pip error: Couldn't find a version that satisfies the requirement
137137
sudo pip install --upgrade pip
138-
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==0.5.0.dev1012
138+
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==0.5.0.dev1020
139139
fi
140140
"$ROOT_DIR/scripts/onnx/test.sh"
141141
fi

android/pytorch_android/src/main/cpp/pytorch_jni.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,6 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
615615
}
616616
auto output = [&]() {
617617
torch::autograd::AutoGradMode guard(false);
618-
at::AutoNonVariableTypeMode non_var_type_mode(true);
619618
return module_.forward(std::move(inputs));
620619
}();
621620
return JIValue::newJIValueFromAtIValue(output);
@@ -638,7 +637,6 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
638637
if (auto method = module_.find_method(methodName)) {
639638
auto output = [&]() {
640639
torch::autograd::AutoGradMode guard(false);
641-
at::AutoNonVariableTypeMode non_var_type_mode(true);
642640
return (*method)(std::move(inputs));
643641
}();
644642
return JIValue::newJIValueFromAtIValue(output);

aten/src/ATen/Declarations.cwrap

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -762,20 +762,6 @@
762762
output: True
763763
- THTensor* self
764764
]]
765-
[[
766-
name: _th_log1p
767-
cname: log1p
768-
types:
769-
- floating_point
770-
backends:
771-
- CUDA
772-
variants: function
773-
return: argument 0
774-
arguments:
775-
- arg: THTensor* result
776-
output: True
777-
- THTensor* self
778-
]]
779765
[[
780766
name: _th_exp
781767
cname: exp
@@ -944,20 +930,6 @@
944930
output: True
945931
- THTensor* self
946932
]]
947-
[[
948-
name: _th_sqrt
949-
cname: sqrt
950-
types:
951-
- floating_point
952-
backends:
953-
- CUDA
954-
variants: function
955-
return: argument 0
956-
arguments:
957-
- arg: THTensor* result
958-
output: True
959-
- THTensor* self
960-
]]
961933
[[
962934
name: _th_frac_
963935
types:

aten/src/ATen/core/ATenDispatch.h

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <mutex>
1515
#include <ATen/core/interned_strings.h>
1616
#include <ATen/core/stack.h>
17+
#include <torch/csrc/jit/script/function_schema_parser.h>
1718

1819
// TODO: Rewrite this comment
1920
//
@@ -74,7 +75,7 @@ namespace detail {
7475
}
7576
}
7677

77-
using FallbackBoxedFunction = void(const char* schema, torch::jit::Stack*);
78+
using FallbackBoxedFunction = void(const c10::FunctionSchema& schema, torch::jit::Stack*);
7879

7980
// Assume T is decayed
8081
template <typename T>
@@ -129,9 +130,19 @@ class CAFFE2_API ATenOpTable {
129130

130131
C10_NORETURN void reportError(TensorTypeId tid) const;
131132

133+
const FunctionSchema& function_schema() const {
134+
std::lock_guard<std::mutex> lock(mutex_);
135+
if (!parsed_schema_.has_value()) {
136+
parsed_schema_ = torch::jit::parseSchema(schema_);
137+
}
138+
return *parsed_schema_;
139+
}
140+
132141
friend class ATenDispatch;
133142

134143
std::string schema_;
144+
mutable c10::optional<c10::FunctionSchema> parsed_schema_ = c10::nullopt;
145+
mutable std::mutex mutex_;
135146
void* function_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
136147
};
137148

@@ -141,9 +152,9 @@ class CAFFE2_API ATenDispatch {
141152
ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) {
142153
std::lock_guard<std::mutex> lock(mutex_);
143154
if (op_tables_.find(schema) == op_tables_.end()) {
144-
op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
155+
op_tables_.insert(std::make_pair(schema, c10::guts::make_unique<ATenOpTable>(schema)));
145156
}
146-
op_tables_.at(schema).registerOp(id, reinterpret_cast<void*>(fn));
157+
op_tables_.at(schema)->registerOp(id, reinterpret_cast<void*>(fn));
147158
return *this;
148159
}
149160

@@ -157,23 +168,23 @@ class CAFFE2_API ATenDispatch {
157168
auto iter = op_tables_.find(schema);
158169
TORCH_CHECK(iter != op_tables_.end(),
159170
"No functions are registered for schema ", schema);
160-
return &iter->second;
171+
return iter->second.get();
161172
}
162173

163174
FallbackBoxedFunction* getFallbackBoxedOp(TensorTypeId tid) const {
164175
return boxed_fallback_table_[static_cast<size_t>(tid)];
165176
}
166177

167178
private:
168-
std::unordered_map<std::string, ATenOpTable> op_tables_;
179+
std::unordered_map<std::string, std::unique_ptr<ATenOpTable>> op_tables_;
169180
FallbackBoxedFunction* boxed_fallback_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
170181
std::mutex mutex_;
171182
};
172183

173184
CAFFE2_API ATenDispatch& globalATenDispatch();
174185

175186
template<class Result, class... Args>
176-
Result callBoxedFallback(const char* schema, FallbackBoxedFunction* boxed_fallback_fn, Args&&... args,
187+
Result callBoxedFallback(const c10::FunctionSchema& schema, FallbackBoxedFunction* boxed_fallback_fn, Args&&... args,
177188
// NB: enable_if must occur in function parameter, because MSVC
178189
// doesn't like it when it's a template argument next to
179190
// a parameter pack
@@ -189,7 +200,7 @@ Result callBoxedFallback(const char* schema, FallbackBoxedFunction* boxed_fallba
189200

190201
template<
191202
class Result, class... Args>
192-
Result callBoxedFallback(const char* schema, FallbackBoxedFunction* boxed_fallback_fn, Args&&... args,
203+
Result callBoxedFallback(const c10::FunctionSchema& schema, FallbackBoxedFunction* boxed_fallback_fn, Args&&... args,
193204
typename c10::guts::enable_if_t<
194205
supports_boxed_fallback<Result, Args...>::value,
195206
std::nullptr_t
@@ -232,7 +243,7 @@ Result ATenOpTable::callUnboxed(Args... args) const {
232243
auto* boxed_fallback_fn = globalATenDispatch().getFallbackBoxedOp(tid);
233244
if (C10_UNLIKELY(boxed_fallback_fn)) {
234245
if (supports_boxed_fallback<Result, Args...>::value) {
235-
return callBoxedFallback<Result, Args...>(schema_.c_str(), boxed_fallback_fn, std::forward<Args>(args)...);
246+
return callBoxedFallback<Result, Args...>(function_schema(), boxed_fallback_fn, std::forward<Args>(args)...);
236247
} else {
237248
TORCH_INTERNAL_ASSERT(0, schema_, " does not support boxed fallback, but boxed fallback for ", tid, " was available");
238249
}

aten/src/ATen/cpu/vec256/vec256.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,27 @@ inline interleave2<double>(const Vec256<double>& a, const Vec256<double>& b) {
145145
// swap lanes:
146146
// a_swapped = {a0, a1, b0, b1}
147147
// b_swapped = {a2, a3, b2, b3}
148+
#if __cpp_binary_literals >= 201304L
149+
auto a_swapped = _mm256_permute2f128_pd(a, b, 0b0100000);
150+
auto b_swapped = _mm256_permute2f128_pd(a, b, 0b0110001);
151+
#else // TODO Remove else case once switch to C++14 is finished
148152
static constexpr int swap_ctrl_a = 0 | (2 << 4); // 0, 2. 4 bits apart
149153
static constexpr int swap_ctrl_b = 1 | (3 << 4); // 1, 3. 4 bits apart
150154
auto a_swapped = _mm256_permute2f128_pd(a, b, swap_ctrl_a);
151155
auto b_swapped = _mm256_permute2f128_pd(a, b, swap_ctrl_b);
156+
#endif
152157

153158
// group cols crossing lanes:
154159
// return {a0, b0, a1, b1}
155160
// {a2, b2, a3, b3}
161+
#if __cpp_binary_literals >= 201304L
162+
return std::make_pair(_mm256_permute4x64_pd(a_swapped, 0b11011000),
163+
_mm256_permute4x64_pd(b_swapped, 0b11011000));
164+
#else // TODO Remove else case once switch to C++14 is finished
156165
static constexpr int group_ctrl = 0 | (2 << 2) | (1 << 4) | (3 << 6); // 0, 2, 1, 3
157166
return std::make_pair(_mm256_permute4x64_pd(a_swapped, group_ctrl),
158167
_mm256_permute4x64_pd(b_swapped, group_ctrl));
168+
#endif
159169
}
160170

161171
template <>
@@ -169,10 +179,15 @@ inline interleave2<float>(const Vec256<float>& a, const Vec256<float>& b) {
169179
// a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
170180
// b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
171181
// TODO: can we support caching this?
182+
#if __cpp_binary_literals >= 201304L
183+
auto a_swapped = _mm256_permute2f128_ps(a, b, 0b0100000);
184+
auto b_swapped = _mm256_permute2f128_ps(a, b, 0b0110001);
185+
#else // TODO Remove else case once switch to C++14 is finished
172186
static constexpr int swap_ctrl_a = 0 | (2 << 4); // 0, 2. 4 bits apart
173187
static constexpr int swap_ctrl_b = 1 | (3 << 4); // 1, 3. 4 bits apart
174188
auto a_swapped = _mm256_permute2f128_ps(a, b, swap_ctrl_a);
175189
auto b_swapped = _mm256_permute2f128_ps(a, b, swap_ctrl_b);
190+
#endif
176191

177192
// group cols crossing lanes:
178193
// return {a0, b0, a1, b1, a2, b2, a3, b3}
@@ -194,17 +209,27 @@ inline deinterleave2<double>(const Vec256<double>& a, const Vec256<double>& b) {
194209
// group cols crossing lanes:
195210
// a_grouped = {a0, a1, b0, b1}
196211
// b_grouped = {a2, a3, b2, b3}
212+
#if __cpp_binary_literals >= 201304L
213+
auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000);
214+
auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000);
215+
#else // TODO Remove else case once switch to C++14 is finished
197216
static constexpr int group_ctrl = 0 | (2 << 2) | (1 << 4) | (3 << 6); // 0, 2, 1, 3
198217
auto a_grouped = _mm256_permute4x64_pd(a, group_ctrl);
199218
auto b_grouped = _mm256_permute4x64_pd(b, group_ctrl);
219+
#endif
200220

201221
// swap lanes:
202222
// return {a0, a1, a2, a3}
203223
// {b0, b1, b2, b3}
224+
#if __cpp_binary_literals >= 201304L
225+
return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0100000),
226+
_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0110001));
227+
#else // TODO Remove else case once switch to C++14 is finished
204228
static constexpr int swap_ctrl_a = 0 | (2 << 4); // 0, 2. 4 bits apart
205229
static constexpr int swap_ctrl_b = 1 | (3 << 4); // 1, 3. 4 bits apart
206230
return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, swap_ctrl_a),
207231
_mm256_permute2f128_pd(a_grouped, b_grouped, swap_ctrl_b));
232+
#endif
208233
}
209234

210235
template <>
@@ -225,10 +250,15 @@ inline deinterleave2<float>(const Vec256<float>& a, const Vec256<float>& b) {
225250
// swap lanes:
226251
// return {a0, a1, a2, a3, a4, a5, a6, a7}
227252
// {b0, b1, b2, b3, b4, b5, b6, b7}
253+
#if __cpp_binary_literals >= 201304L
254+
return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0100000),
255+
_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001));
256+
#else // TODO Remove else case once switch to C++14 is finished
228257
static constexpr int swap_ctrl_a = 0 | (2 << 4); // 0, 2. 4 bits apart
229258
static constexpr int swap_ctrl_b = 1 | (3 << 4); // 1, 3. 4 bits apart
230259
return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, swap_ctrl_a),
231260
_mm256_permute2f128_ps(a_grouped, b_grouped, swap_ctrl_b));
261+
#endif
232262
}
233263

234264
#endif // defined(__AVX2__)

aten/src/ATen/function_wrapper.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def TypedDict(name, attrs, total=True): # type: ignore
174174
#ifdef USE_STATIC_DISPATCH
175175
${static_dispatch_method_body}
176176
#else
177-
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::${name}", "${overload_name}"}).value();
177+
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::${operator_name}", "${overload_name}"}).value();
178178
return c10::Dispatcher::singleton().callUnboxedOnly<${formals_types_with_return}>(
179179
op, impl::dispatchTypeId(${inferred_type_set})${method_actuals_with_comma_prefix});
180180
#endif
@@ -185,7 +185,7 @@ def TypedDict(name, attrs, total=True): # type: ignore
185185
#ifdef USE_STATIC_DISPATCH
186186
${static_dispatch_method_body}
187187
#else
188-
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::${name}", "${overload_name}"}).value();
188+
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::${operator_name}", "${overload_name}"}).value();
189189
return c10::Dispatcher::singleton().callUnboxed<${formals_types_with_return}>(
190190
op, impl::dispatchTypeId(${inferred_type_set})${method_actuals_with_comma_prefix});
191191
#endif
@@ -217,7 +217,7 @@ def TypedDict(name, attrs, total=True): # type: ignore
217217
${static_dispatch_function_body}
218218
#else
219219
static c10::OperatorHandle op = c10::Dispatcher::singleton()
220-
.findSchema({"aten::${name}", "${overload_name}"}).value();
220+
.findSchema({"aten::${operator_name}", "${overload_name}"}).value();
221221
return c10::Dispatcher::singleton().callUnboxedOnly<${formals_types_with_return}>(
222222
op, impl::dispatchTypeId(${inferred_type_set})${native_actuals_with_comma_prefix});
223223
#endif
@@ -230,7 +230,7 @@ def TypedDict(name, attrs, total=True): # type: ignore
230230
${static_dispatch_function_body}
231231
#else
232232
static c10::OperatorHandle op = c10::Dispatcher::singleton()
233-
.findSchema({"aten::${name}", "${overload_name}"}).value();
233+
.findSchema({"aten::${operator_name}", "${overload_name}"}).value();
234234
return c10::Dispatcher::singleton().callUnboxed<${formals_types_with_return}>(
235235
op, impl::dispatchTypeId(${inferred_type_set})${native_actuals_with_comma_prefix});
236236
#endif
@@ -239,10 +239,17 @@ def TypedDict(name, attrs, total=True): # type: ignore
239239

240240
# In order to rely on the linker to strip unused ops, it requires us to dispatch statically
241241
# in Functions.h and TensorMethods.h.
242+
#
243+
# NB: The default body also needs to apply a variable guard, as in some
244+
# situations what we think is a default body actually does have an
245+
# explicit derivative, and thereby would have gotten unwrapped by
246+
# the time you get to the implementation.
242247
STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\
248+
at::AutoNonVariableTypeMode _var_guard(true);
243249
${return_call} TypeDefault::${native_type_method_dispatch}(${native_arguments});
244250
""")
245251
STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\
252+
at::AutoNonVariableTypeMode _var_guard(true);
246253
switch(tensorTypeIdToBackend(impl::dispatchTypeId(${type_set}))) {
247254
${static_dispatch_function_switches}
248255
default:
@@ -272,6 +279,32 @@ def TypedDict(name, attrs, total=True): # type: ignore
272279
#endif
273280
}
274281
""")
282+
C10_UNBOXEDONLY_FACTORY_DEFINITION = CodeTemplate("""\
283+
static inline ${return_type} ${api_name}(${formals}) {
284+
#ifdef USE_STATIC_DISPATCH
285+
${static_dispatch_function_body}
286+
#else
287+
globalLegacyTypeDispatch().initForTensorTypeSet(${inferred_type_set});
288+
static c10::OperatorHandle op = c10::Dispatcher::singleton()
289+
.findSchema({"aten::${operator_name}", "${overload_name}"}).value();
290+
return c10::Dispatcher::singleton().callUnboxedOnly<${formals_types_with_return}>(
291+
op, impl::dispatchTypeId(${inferred_type_set})${native_actuals_with_comma_prefix});
292+
#endif
293+
}
294+
""")
295+
C10_FACTORY_DEFINITION = CodeTemplate("""\
296+
static inline ${return_type} ${api_name}(${formals}) {
297+
#ifdef USE_STATIC_DISPATCH
298+
${static_dispatch_function_body}
299+
#else
300+
globalLegacyTypeDispatch().initForTensorTypeSet(${inferred_type_set});
301+
static c10::OperatorHandle op = c10::Dispatcher::singleton()
302+
.findSchema({"aten::${operator_name}", "${overload_name}"}).value();
303+
return c10::Dispatcher::singleton().callUnboxed<${formals_types_with_return}>(
304+
op, impl::dispatchTypeId(${inferred_type_set})${native_actuals_with_comma_prefix});
305+
#endif
306+
}
307+
""")
275308

276309
ZERO_DIM_CHECK = CodeTemplate("""\
277310
if (${check_name}.dim() == 0) {
@@ -880,7 +913,9 @@ def get_return_types(option):
880913

881914
def format_return_type(return_types):
882915
# type: (List[ReturnType]) -> str
883-
if len(return_types) == 1:
916+
if len(return_types) == 0:
917+
return "void"
918+
elif len(return_types) == 1:
884919
return return_types[0]['type']
885920
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
886921

@@ -1109,9 +1144,6 @@ def native_get_return_types(option):
11091144
if isinstance(t_raw, string_type):
11101145
t = t_raw
11111146
name = None
1112-
elif t_raw is None:
1113-
t = 'void'
1114-
name = None
11151147
else:
11161148
t = t_raw['type']
11171149
name = t_raw['name']
@@ -1142,7 +1174,6 @@ def process_native(option):
11421174
assert option['python_module'] == '' or option['python_module'] == 'nn', \
11431175
"Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
11441176
option['python_module'], option['name'])
1145-
11461177
formals = native_get_formals(option)
11471178
option['formals_list'] = formals
11481179
option['formals'] = [format_formal(f) for f in formals]
@@ -1263,8 +1294,16 @@ def gen_namespace_function(option, multidispatch_tensors):
12631294
option, native_arguments=option['native_actuals'])
12641295

12651296
if is_factory_method:
1266-
fn_definition = FACTORY_DEFINITION.substitute(
1267-
option, static_dispatch_function_body=static_dispatch_function_body)
1297+
if option['use_c10_dispatcher'] == 'no':
1298+
fn_definition = FACTORY_DEFINITION.substitute(
1299+
option, static_dispatch_function_body=static_dispatch_function_body)
1300+
elif option['use_c10_dispatcher'] == 'unboxed_only':
1301+
fn_definition = C10_UNBOXEDONLY_FACTORY_DEFINITION.substitute(
1302+
option, static_dispatch_function_body=static_dispatch_function_body)
1303+
else:
1304+
assert option['use_c10_dispatcher'] == 'full'
1305+
fn_definition = C10_FACTORY_DEFINITION.substitute(
1306+
option, static_dispatch_function_body=static_dispatch_function_body)
12681307
else:
12691308
if option['use_c10_dispatcher'] == 'no':
12701309
fn_definition = FUNCTION_DEFINITION.substitute(

0 commit comments

Comments
 (0)