Skip to content

Commit dd612d0

Browse files
committed
Update on "Simplify copy kernel"
Using the new type promotion and dynamic casting added to `TensorIterator`, the copy kernels can be greatly simplified. For benchmarks, see #28352 (comment) [ghstack-poisoned]
1 parent 44b398e commit dd612d0

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ void TensorIterator::compute_types() {
189189
auto operands = compute_common_dtype_only_for_inputs ? at::ArrayRef<OperandInfo>(operands_).slice(noutputs()) : operands_;
190190
auto common_type = compute_common_type_(operands);
191191
auto common_device = std::get<0>(common_type);
192-
bool common_device_is_cuda = common_device.is_cuda();
192+
common_device_is_cuda = common_device.is_cuda();
193193
common_dtype_ = std::get<1>(common_type);
194194
may_have_differing_types_ = !std::get<2>(common_type);
195195
bool has_cpu_scalar = false;

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ static void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
5858
cudaMemcpyDeviceToDevice,
5959
copy_stream));
6060
} else {
61+
// this is intentionally done after build because copy has a "promotion"
62+
// rule that always "promotes" to the target dtype.
6163
iter.set_common_dtype(iter.dtype());
64+
iter.promote_common_dtype();
6265
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(0), "copy_", [&] {
6366
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
6467
});

0 commit comments

Comments
 (0)