Skip to content

Commit d3cbbca

Browse files
committed
Simplify copy kernel
Using the new type promotion and dynamic casting added to `TensorIterator`, the copy kernels could be greatly simplified. **Script:** ```python import torch import timeit import pandas import itertools from tqdm import tqdm import math print(torch.__version__) print() _10M = 10 * 1024 ** 2 d = {} for from_, to in tqdm(itertools.product(torch.testing.get_all_dtypes(), repeat=2)): if from_ not in d: d[from_] = {} a = torch.zeros(_10M, dtype=from_) min_ = math.inf for i in range(100): start = timeit.default_timer() a.to(to) end = timeit.default_timer() elapsed = end - start if elapsed < min_: min_ = elapsed d[from_][to] = int(elapsed * 1000 * 1000) pandas.DataFrame(d) ``` **Before:** ![image](https://user-images.githubusercontent.com/1032377/67171274-2e93d000-f36b-11e9-8fa0-91edd7dbc8ec.png) **After:** ![image](https://user-images.githubusercontent.com/1032377/67171200-d361dd80-f36a-11e9-9b22-66292e395a09.png) [ghstack-poisoned]
1 parent a3a32ff commit d3cbbca

File tree

5 files changed

+40
-72
lines changed

5 files changed

+40
-72
lines changed

aten/src/ATen/Dispatch.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,11 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
221221
const auto& SCALAR_TYPE C10_UNUSED = TYPE; \
222222
switch (TYPE) { \
223223
AT_QINT_PRIVATE_CASE_TYPE( \
224-
at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
224+
at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
225225
AT_QINT_PRIVATE_CASE_TYPE( \
226-
at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
226+
at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
227227
AT_QINT_PRIVATE_CASE_TYPE( \
228-
at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
228+
at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
229229
default: \
230230
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
231231
} \
@@ -351,6 +351,29 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
351351
} \
352352
}()
353353

354+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND_QINTS_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
355+
[&] { \
356+
switch (TYPE) { \
357+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
358+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
359+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
360+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
361+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
362+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
363+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
364+
AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
365+
AT_PRIVATE_CASE_TYPE(at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
366+
AT_QINT_PRIVATE_CASE_TYPE(at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
367+
AT_QINT_PRIVATE_CASE_TYPE(at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
368+
AT_QINT_PRIVATE_CASE_TYPE(at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
369+
AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), __VA_ARGS__) \
370+
AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), __VA_ARGS__) \
371+
AT_PRIVATE_CASE_TYPE(SCALARTYPE3, decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), __VA_ARGS__) \
372+
default: \
373+
AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \
374+
} \
375+
}()
376+
354377
// ----------------------------------------------------------------------------
355378
// DEPRECATED MACROS, DON'T USE THESE
356379
// ----------------------------------------------------------------------------

aten/src/ATen/native/Copy.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,12 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
124124

125125
auto iter = TensorIterator();
126126
iter.set_check_mem_overlap(true);
127+
iter.dont_compute_common_dtype();
127128
iter.add_output(self);
128129
iter.add_input(src);
129130
iter.dont_resize_outputs();
130-
iter.dont_compute_common_dtype();
131131
iter.build();
132+
iter.set_common_dtype(iter.dtype());
132133

133134
if (iter.numel() == 0) {
134135
return self;

aten/src/ATen/native/TensorIterator.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,10 @@ struct CAFFE2_API TensorIterator {
318318
resize_outputs_ = false;
319319
}
320320

321+
void set_common_dtype(ScalarType dtype) {
322+
common_dtype_ = dtype;
323+
}
324+
321325
void build();
322326

323327
protected:

aten/src/ATen/native/cpu/CopyKernel.cpp

Lines changed: 7 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,69 +4,20 @@
44
#include <ATen/native/Copy.h>
55
#include <ATen/native/TensorIterator.h>
66
#include <ATen/native/cpu/Loops.h>
7-
#include <c10/util/TypeCast.h>
87

98
namespace at {
109
namespace native {
1110
namespace {
1211

13-
template <typename self_T>
14-
void copy_kernel_cast(TensorIterator& iter) {
15-
if (isComplexType(iter.dtype(1))) {
16-
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(1), "copy_kernel_cast", [&] {
17-
cpu_kernel(iter, [=](scalar_t a) -> self_T {
18-
return c10::static_cast_with_inter_type<self_T>(std::real(a));
19-
});
20-
});
21-
}
22-
else {
23-
AT_DISPATCH_ALL_TYPES_AND3(
24-
ScalarType::Half,
25-
ScalarType::Bool,
26-
ScalarType::BFloat16,
27-
iter.dtype(1),
28-
"copy_kernel_cast",
29-
[&] {
30-
cpu_kernel(iter, [=](scalar_t a) -> self_T {
31-
return c10::static_cast_with_inter_type<self_T>(a);
32-
});
33-
});
34-
}
35-
}
36-
3712
static void copy_kernel(TensorIterator& iter, bool non_blocking) {
38-
ScalarType dtype = iter.dtype(0);
39-
if (dtype == iter.dtype(1)) {
40-
if (dtype == ScalarType::Half) {
41-
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
42-
} else if (dtype == ScalarType::BFloat16) {
43-
cpu_kernel(iter, [=](at::BFloat16 a) -> at::BFloat16 { return a; });
44-
} else if (isQIntType(dtype)) {
45-
AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
46-
cpu_kernel(
47-
iter,
48-
[=](scalar_t a) -> scalar_t {return a; });
49-
});
50-
} else if (isComplexType(dtype)) {
51-
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
52-
cpu_kernel(
53-
iter,
54-
[=](scalar_t a) -> scalar_t { return a; });
55-
});
56-
} else {
57-
AT_DISPATCH_ALL_TYPES_AND(
58-
ScalarType::Bool, dtype, "copy_kernel", [&] {
59-
cpu_kernel_vec(
60-
iter,
61-
[=](scalar_t a) -> scalar_t { return a; },
62-
[=](Vec256<scalar_t> a) { return a; });
63-
});
64-
}
65-
} else {
66-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
67-
copy_kernel_cast<scalar_t>(iter);
13+
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND_QINTS_AND3(
14+
ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.common_dtype(), "copy_",
15+
[&] {
16+
cpu_kernel_vec(
17+
iter,
18+
[](scalar_t a) -> scalar_t { return a; },
19+
[](Vec256<scalar_t> a) { return a; });
6820
});
69-
}
7021
}
7122

7223
} // anonymous namespace

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,12 @@
88
#include <ATen/native/TensorIterator.h>
99
#include <ATen/native/cuda/Loops.cuh>
1010
#include <THC/THC.h>
11-
#include <c10/util/TypeCast.h>
1211

1312
namespace at {
1413
namespace native {
1514

1615
using namespace at::cuda;
1716

18-
template <typename dst_t, typename src_t>
19-
void copy_kernel_impl(TensorIterator& iter) {
20-
gpu_kernel(iter, []GPU_LAMBDA(src_t x) -> dst_t {
21-
return c10::static_cast_with_inter_type<dst_t>(x);
22-
});
23-
}
24-
2517
// device-to-device copy, does type conversion
2618
static void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
2719
int64_t numel = iter.numel();
@@ -67,10 +59,7 @@ static void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
6759
copy_stream));
6860
} else {
6961
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(0), "copy_", [&] {
70-
using dst_t = scalar_t;
71-
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(1), "copy_", [&] {
72-
copy_kernel_impl<dst_t, scalar_t>(iter);
73-
});
62+
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
7463
});
7564
}
7665

0 commit comments

Comments
 (0)