Commit 27ecdf9

Update on "Simplify copy kernel"
Using the new type promotion and dynamic casting added to `TensorIterator`, the copy kernels can be greatly simplified. For benchmarks, see #28352 (comment). [ghstack-poisoned]
2 parents: 25e7b33 + bb3ad0c
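For context, here is a minimal sketch (not part of this commit; shapes and dtypes are arbitrary) of the kind of cross-dtype copy that now flows through TensorIterator's type promotion and dynamic casting instead of a separate conversion step inside the copy kernel:

#include <ATen/ATen.h>

int main() {
  // Destination and source deliberately have different dtypes,
  // so copy_ must cast int32 elements to float.
  at::Tensor dst = at::empty({4}, at::kFloat);
  at::Tensor src = at::arange(4, at::kInt);

  // After this change, the cast happens inside the copy kernel via
  // TensorIterator's dynamic casting rather than through an explicit
  // intermediate conversion.
  dst.copy_(src);
  return 0;
}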

4 files changed: 8 additions and 15 deletions

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 2 additions & 4 deletions
@@ -235,10 +235,8 @@ void TensorIterator::compute_types() {
       }
     }

-    if (op.tensor.defined()) {
-      if (op.tensor.scalar_type() != common_dtype_) {
-        has_promotion_ = true;
-      }
+    if (op.tensor.defined() && op.tensor.scalar_type() != common_dtype_) {
+      have_differing_types_ = true;
     }

     if (op.tensor.defined() && op.device != op.tensor.device()) {
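
The renamed flag records only the fact that some defined operand's dtype differs from the computed common dtype; whether a cast is actually required is decided later from the common-dtype strategy. A simplified, standalone restatement of that check (illustrative types, not the real TensorIterator code):

#include <vector>

enum class ScalarType { Float, Int };

struct Operand {
  bool defined;
  ScalarType dtype;
};

// True if any defined operand's dtype differs from the common dtype.
bool have_differing_types(const std::vector<Operand>& ops, ScalarType common_dtype) {
  bool differing = false;
  for (const auto& op : ops) {
    if (op.defined && op.dtype != common_dtype) {
      differing = true;
    }
  }
  return differing;
}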

aten/src/ATen/native/TensorIterator.h

Lines changed: 3 additions & 7 deletions
@@ -287,8 +287,8 @@ struct CAFFE2_API TensorIterator {
   /// CUDA reductions.
   bool is_final_output() const { return final_output_; }

-  bool has_promotion() const {
-    return has_promotion_;
+  bool needs_dynamic_casting() const {
+    return (common_dtype_strategy_ != CommonDTypeStrategy::NONE) && have_differing_types_;
   }

   void set_check_mem_overlap(bool check_mem_overlap) {
@@ -330,10 +330,6 @@
     resize_outputs_ = false;
   }

-  void set_common_dtype(ScalarType dtype) {
-    common_dtype_ = dtype;
-  }
-
   void build();

 protected:
@@ -370,7 +366,7 @@
   bool promote_gpu_output_dtypes_ = false;
   bool final_output_ = true;
   bool check_mem_overlap_ = false;
-  bool has_promotion_ = false;
+  bool have_differing_types_ = false;
 };

 /// A container-like struct that acts as if it contains splits of a
 /// TensorIterator that can use 32-bit indexing. Taken together the splits cover
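
needs_dynamic_casting() is true only when a common-dtype strategy is in effect and at least one operand's dtype differs from the common dtype. A hedged sketch of how a kernel entry point might branch on it (illustrative caller and helper names, not code from this commit):

#include <ATen/native/TensorIterator.h>

// Illustrative only: pick between a plain element-wise kernel and one that
// casts each operand to/from the common dtype on load/store.
void launch_elementwise(at::TensorIterator& iter) {
  if (iter.needs_dynamic_casting()) {
    // Casting path: per-element loads/stores go through the stored dtypes.
    // launch_casting_kernel(iter);   // hypothetical helper
  } else {
    // Fast path: all operands already have the common dtype.
    // launch_plain_kernel(iter);     // hypothetical helper
  }
}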

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 1 addition & 2 deletions
@@ -60,9 +60,8 @@ static void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
   } else {
     // this is done intentionally done after build because copy has a "promotion"
     // rule that always "promote" to target dtype.
-    iter.set_common_dtype(iter.dtype());
     iter.promote_common_dtype();
-    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(0), "copy_", [&] {
+    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(0), "copy_", [&] {
       gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
     });
   }
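
The dispatch macro now covers BFloat16 as well. A small illustrative use of the same macro on a CPU tensor (hypothetical function, current ATen API assumed; not part of this commit):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Illustrative only: fill a contiguous CPU tensor with ones, dispatching over
// the same dtype set the copy kernel now handles, including BFloat16.
void fill_with_one(at::Tensor& t) {
  AT_DISPATCH_ALL_TYPES_AND3(at::kHalf, at::kBool, at::kBFloat16,
                             t.scalar_type(), "fill_with_one", [&] {
    auto* data = t.data_ptr<scalar_t>();
    for (int64_t i = 0; i < t.numel(); ++i) {
      data[i] = static_cast<scalar_t>(1);
    }
  });
}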

aten/src/ATen/native/cuda/Loops.cuh

Lines changed: 2 additions & 2 deletions
@@ -158,7 +158,7 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
       strides[i] = inner_strides[i];
     }

-    if (iter.has_promotion()) {
+    if (iter.needs_dynamic_casting()) {
       launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
         void* out = data[0] + strides[0] * idx;
         arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
@@ -172,7 +172,7 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     }
   } else {
     auto offset_calc = make_offset_calculator<traits::arity + 1>(iter);
-    if (iter.has_promotion()) {
+    if (iter.needs_dynamic_casting()) {
       launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
         auto offsets = offset_calc.get(idx);
         void* out = data[0] + offsets[0];
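
Both launch paths in gpu_kernel_impl now branch on the renamed predicate. Conceptually, the casting path loads each input through its stored dtype, evaluates the functor in the compute type, and casts the result back to the output dtype. A much-simplified, CPU-side sketch of that idea with hypothetical helpers (not the actual fetch/cast machinery used here):

#include <cstdint>

enum class DType { Float, Int };

// Hypothetical stand-ins for the real dynamic-cast load/store helpers.
static float load_as_float(const void* p, DType t) {
  return t == DType::Float
      ? *static_cast<const float*>(p)
      : static_cast<float>(*static_cast<const int32_t*>(p));
}

static void store_from_float(void* p, DType t, float v) {
  if (t == DType::Float) {
    *static_cast<float*>(p) = v;
  } else {
    *static_cast<int32_t*>(p) = static_cast<int32_t>(v);
  }
}

// One element of a casting copy: read src in its dtype, write dst in its dtype.
static void copy_element_with_cast(void* dst, DType dst_t,
                                   const void* src, DType src_t) {
  store_from_float(dst, dst_t, load_as_float(src, src_t));
}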
