
Commit b9f099e

zasdfgbnm authored and facebook-github-bot committed
Make TensorIterator stop promoting types by copying (#28427)
Summary:
Pull Request resolved: #28427

Fixes: #26401

This PR fixes the issue by using the newly added dynamic cast inside `TensorIterator` so that, instead of converting the type at the beginning (which generates extra kernel launches), the `TensorIterator` does a load-cast-compute-store for each element while looping. So there is only one read and one write of memory.

**nvprof:**
```python
import torch

_100M = 100 * 1024 ** 2
r = torch.randn(_100M, dtype=torch.float32, device='cuda')
d = torch.randn(_100M, dtype=torch.float64, device='cuda')
torch.cuda.synchronize()
torch.cuda.profiler.start()
r.add_(d)
torch.cuda.profiler.stop()
torch.cuda.synchronize()
```
```
==11407== NVPROF is profiling process 11407, command: /home/xgao/anaconda3/bin/python simple.py
==11407== Profiling application: /home/xgao/anaconda3/bin/python simple.py
==11407== Profiling result:
            Type  Time(%)      Time  Calls       Avg       Min       Max  Name
 GPU activities:  100.00%  2.0611ms      1  2.0611ms  2.0611ms  2.0611ms  _ZN2at6native18elementwise_kernelILi512ELi1EZNS0_15gpu_kernel_implIZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE1_clEvEUlddE_EEvS4_RKT_EUliE_EEviT1_
      API calls:  100.00%  1.05006s      1  1.05006s  1.05006s  1.05006s  cudaLaunchKernel
                    0.00%  2.7740us      2  1.3870us     673ns  2.1010us  cudaGetDevice
                    0.00%  2.3730us      1  2.3730us  2.3730us  2.3730us  cudaSetDevice
                    0.00%     830ns      1     830ns     830ns     830ns  cudaGetLastError
```

**benchmark**
```python
import torch
print(torch.__version__)
print(torch.version.git_version)

_100M = 100 * 1024 ** 2
r = torch.randn(_100M, dtype=torch.float32, device='cuda')
d = torch.randn(_100M, dtype=torch.float64, device='cuda')
torch.cuda.synchronize()
%timeit r.add_(d); torch.cuda.synchronize()
```
original
```
1.4.0a0+7d277b0
7d277b0
6.83 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
after
```
1.4.0a0+f0f2f65
f0f2f65
2.08 ms ± 139 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

For more benchmarks, see: #28344

Test Plan: Imported from OSS

Differential Revision: D18170997

Pulled By: ezyang

fbshipit-source-id: 9c82c1c89583f3e6202c5d790b9b73ad9f960fad
1 parent 688a9db commit b9f099e
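
As a rough mental model of the change described above, here is a minimal, self-contained plain-C++ sketch (hypothetical function names, exact promotion and precision rules elided; not the actual TensorIterator/CUDA implementation). The old path materializes a converted temporary of the mismatched operand before computing, while the new path casts each element on the fly inside the compute loop, so every operand is read once and the output written once:

```cpp
#include <cstddef>
#include <vector>

// Old behaviour (simplified): first materialize a casted temporary of the
// mismatched operand (an extra full read and write), then run the
// element-wise loop over that temporary.
void add_with_copy(std::vector<float>& r, const std::vector<double>& d) {
  std::vector<float> tmp(d.size());                 // extra allocation
  for (std::size_t i = 0; i < d.size(); ++i) {
    tmp[i] = static_cast<float>(d[i]);              // extra pass over the data
  }
  for (std::size_t i = 0; i < r.size(); ++i) {
    r[i] += tmp[i];
  }
}

// New behaviour (simplified): one fused pass that loads, casts, computes and
// stores each element, so each operand is touched exactly once.
void add_with_dynamic_cast(std::vector<float>& r, const std::vector<double>& d) {
  for (std::size_t i = 0; i < r.size(); ++i) {
    double a = static_cast<double>(r[i]);           // load + cast
    double sum = a + d[i];                          // compute
    r[i] = static_cast<float>(sum);                 // cast + store
  }
}
```

In the actual change, this per-element cast is done inside the CUDA kernel via the dynamic-cast helpers used in `Loops.cuh` below.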

File tree

5 files changed: 93 additions & 44 deletions

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 19 additions & 12 deletions

@@ -148,7 +148,7 @@ static void validate_dtype(OperandInfo& op, ScalarType common_dtype, CommonDType
   }
 }
 
-static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype) {
+static void maybe_copy_casting_to_common_dtype(OperandInfo& op, ScalarType common_dtype) {
   if (op.tensor.defined() && op.tensor.scalar_type() != common_dtype)
   {
     op.dtype = common_dtype;
@@ -165,7 +165,7 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype)
 void TensorIterator::compute_types() {
   bool missing_dtypes = false;
   bool missing_output_dtypes = false;
-  ScalarType common_dtype = dtype();
+  common_dtype_ = dtype();
   for (auto& op : operands_) {
     if (!op.tensor.defined() && !op.is_type_defined()) {
       missing_dtypes = true;
@@ -183,31 +183,33 @@ void TensorIterator::compute_types() {
   bool compute_common_dtype_only_for_inputs = (common_dtype_strategy_ == CommonDTypeStrategy::PROMOTE_INPUTS);
 
   bool may_have_differing_types = true;
+  bool common_device_is_cuda = false;
 
   if (missing_dtypes || compute_common_dtype) {
     auto operands = compute_common_dtype_only_for_inputs ? at::ArrayRef<OperandInfo>(operands_).slice(noutputs()) : operands_;
     auto common_type = compute_common_type_(operands);
     auto common_device = std::get<0>(common_type);
-    common_dtype = std::get<1>(common_type);
+    common_device_is_cuda = common_device.is_cuda();
+    common_dtype_ = std::get<1>(common_type);
     may_have_differing_types = !std::get<2>(common_type);
     bool has_cpu_scalar = false;
     for (auto& op : operands_) {
       if (!op.is_type_defined()) {
         op.device = common_device;
-        op.dtype = common_dtype;
+        op.dtype = common_dtype_;
       } else if (compute_common_dtype &&
-                 (op.device != common_device || op.dtype != common_dtype)) {
+                 (op.device != common_device || op.dtype != common_dtype_)) {
         if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
-            common_device.is_cuda() && op.tensor.device().is_cpu() &&
+            common_device_is_cuda && op.tensor.device().is_cpu() &&
             !has_cpu_scalar) {
           // don't cast CPU scalars in CUDA ops that directly support them.
           op.device = op.tensor.device();
           op.dtype = op.tensor.scalar_type();
           has_cpu_scalar = true;
         } else if (promote_gpu_output_dtypes_ && op.tensor.defined() &&
             !op.is_output &&
-            op.tensor.scalar_type() == kHalf && common_dtype == kFloat &&
-            op.tensor.device().is_cuda() && common_device.is_cuda()) {
+            op.tensor.scalar_type() == kHalf && common_dtype_ == kFloat &&
+            op.tensor.device().is_cuda() && common_device_is_cuda) {
           // allow input tensor type upcasting for fp16 to fp32 in fused kernel
           // on GPU
           op.device = op.tensor.device();
@@ -217,7 +219,7 @@ void TensorIterator::compute_types() {
           if (compute_common_dtype_only_for_inputs && op.is_output) {
             op.dtype = op.tensor.scalar_type();
           } else {
-            op.dtype = common_dtype;
+            op.dtype = common_dtype_;
           }
         }
       }
@@ -226,12 +228,17 @@ void TensorIterator::compute_types() {
 
   for (auto &op : operands_) {
     if (may_have_differing_types) {
-      validate_dtype(op, common_dtype, common_dtype_strategy_);
-      if (compute_common_dtype && (!compute_common_dtype_only_for_inputs || !op.is_output)) {
-        maybe_promote_common_dtype(op, common_dtype);
+      validate_dtype(op, common_dtype_, common_dtype_strategy_);
+      bool cast_by_copy = compute_common_dtype && !common_device_is_cuda && (!compute_common_dtype_only_for_inputs || !op.is_output);
+      if (cast_by_copy) {
+        maybe_copy_casting_to_common_dtype(op, common_dtype_);
       }
     }
 
+    if (op.tensor.defined() && op.tensor.scalar_type() != common_dtype_) {
+      have_differing_types_ = true;
+    }
+
    if (op.tensor.defined() && op.device != op.tensor.device()) {
      if (op.is_output) {
        AT_ERROR("output with device ", op.tensor.device(),

aten/src/ATen/native/TensorIterator.h

Lines changed: 7 additions & 0 deletions

@@ -191,6 +191,7 @@ struct CAFFE2_API TensorIterator {
   IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; }
   void* data_ptr(int arg) const;
   ScalarType dtype(int arg=0) const { return operands_[arg].tensor.scalar_type(); }
+  ScalarType common_dtype() const { return common_dtype_; }
   ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].dtype; }
   Device device(int arg=0) const { return operands_[arg].device; }
   DeviceType device_type(int arg=0) const { return device(arg).type(); }
@@ -286,6 +287,10 @@ struct CAFFE2_API TensorIterator {
   /// CUDA reductions.
   bool is_final_output() const { return final_output_; }
 
+  bool needs_dynamic_casting() const {
+    return (common_dtype_strategy_ != CommonDTypeStrategy::NONE) && have_differing_types_;
+  }
+
   void set_check_mem_overlap(bool check_mem_overlap) {
     check_mem_overlap_ = check_mem_overlap;
   }
@@ -352,6 +357,7 @@ struct CAFFE2_API TensorIterator {
   SmallVector<OperandInfo, 4> operands_;
   int num_outputs_ = 0;
   CommonDTypeStrategy common_dtype_strategy_ = CommonDTypeStrategy::CHECK;
+  ScalarType common_dtype_ = ScalarType::Undefined;
   bool has_coalesced_dimensions_ = false;
   bool accumulate_ = false;
   bool resize_outputs_ = true;
@@ -360,6 +366,7 @@ struct CAFFE2_API TensorIterator {
   bool promote_gpu_output_dtypes_ = false;
   bool final_output_ = true;
   bool check_mem_overlap_ = false;
+  bool have_differing_types_ = false;
 };
 /// A container-like struct that acts as if it contains splits of a
 /// TensorIterator that can use 32-bit indexing. Taken together the splits cover

aten/src/ATen/native/cuda/BinaryOpsKernel.cu

Lines changed: 21 additions & 21 deletions

@@ -13,7 +13,7 @@
 namespace at { namespace native {
 
 void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
-  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "add_cuda/sub_cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
     auto alpha = alpha_scalar.to<scalar_t>();
     gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
       return a + alpha * b;
@@ -26,19 +26,19 @@ static void sub_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
 }
 
 void div_kernel_cuda(TensorIterator& iter) {
-  if (!isIntegralType(iter.dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) {
+  if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) {
     // optimization for floating-point types: if the second operand is a CPU
     // scalar, compute a * reciprocal(b). Note that this may lose one bit of
     // precision compared to computing the division.
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "div_cuda", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "div_cuda", [&]() {
       auto inv_b = scalar_t(1.0 / iter.scalar_value<scalar_t>(2));
       iter.remove_operand(2);
       gpu_kernel(iter, [inv_b]GPU_LAMBDA(scalar_t a) -> scalar_t {
         return a * inv_b;
       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "div_cuda", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "div_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a / b;
      });
@@ -47,13 +47,13 @@ void div_kernel_cuda(TensorIterator& iter) {
 }
 
 void mul_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
     gpu_kernel_with_scalars(iter, []GPU_LAMBDA(bool a, bool b) -> bool {
       return a && b;
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "mul_cuda", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "mul_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a * b;
      });
@@ -62,22 +62,22 @@ void mul_kernel_cuda(TensorIterator& iter) {
 }
 
 void atan2_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "atan2_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "atan2_cuda", [&]() {
     gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
       return ::atan2(a, b);
     });
   });
 }
 
 void logical_xor_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "logical_xor_cuda", [&]() {
       gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
         return bool(a) != bool(b);
       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "logical_xor_cuda", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "logical_xor_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return static_cast<scalar_t>(bool(a) != bool(b));
      });
@@ -86,14 +86,14 @@ void logical_xor_kernel_cuda(TensorIterator& iter) {
 }
 
 void lt_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "lt_cpu", [&]() {
       gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
         return a < b;
       });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "lt_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "lt_cpu", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a < b;
      });
@@ -102,14 +102,14 @@ void lt_kernel_cuda(TensorIterator& iter) {
 }
 
 void le_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "le_cpu", [&]() {
       gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a <= b;
      });
    });
  } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "le_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "le_cpu", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a <= b;
      });
@@ -118,14 +118,14 @@ void le_kernel_cuda(TensorIterator& iter) {
 }
 
 void gt_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "gt_cpu", [&]() {
       gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a > b;
      });
    });
  } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "gt_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "gt_cpu", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a > b;
      });
@@ -134,14 +134,14 @@ void gt_kernel_cuda(TensorIterator& iter) {
 }
 
 void ge_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "ge_cpu", [&]() {
       gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a >= b;
      });
    });
  } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "ge_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "ge_cpu", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a >= b;
      });
@@ -150,14 +150,14 @@ void ge_kernel_cuda(TensorIterator& iter) {
 }
 
 void eq_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "eq_cpu", [&]() {
       gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a == b;
      });
    });
  } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "eq_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "eq_cpu", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a == b;
      });
@@ -166,14 +166,14 @@ void eq_kernel_cuda(TensorIterator& iter) {
 }
 
 void ne_kernel_cuda(TensorIterator& iter) {
-  if (iter.dtype() == ScalarType::Bool) {
+  if (iter.common_dtype() == ScalarType::Bool) {
     AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "ne_cpu", [&]() {
       gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a != b;
      });
    });
  } else {
-    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "ne_cpu", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.common_dtype(), "ne_cpu", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return a != b;
      });

aten/src/ATen/native/cuda/Loops.cuh

Lines changed: 45 additions & 11 deletions

@@ -35,6 +35,7 @@
 #include <ATen/detail/FunctionTraits.h>
 #include <ATen/native/TensorIterator.h>
 #include <c10/macros/Macros.h>
+#include <c10/util/TypeCast.h>
 
 // Marks a lambda as executable on both the host and device. The __host__
 // attribute is important so that we can access static type information from
@@ -116,6 +117,20 @@ invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[]
   return invoke_impl<traits>(f, data, strides, i, Indices{});
 }
 
+template <typename traits, typename func_t, typename index_t, size_t... I>
+C10_HOST_DEVICE typename traits::result_type
+invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
+            c10::guts::index_sequence<I...>) {
+  return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(dtypes[I], data[I] + i * strides[I])...);
+}
+
+template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
+C10_HOST_DEVICE typename traits::result_type
+invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
+  using Indices = c10::guts::make_index_sequence<traits::arity>;
+  return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
+}
+
 template <typename func_t>
 void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
   using traits = function_traits<func_t>;
@@ -130,6 +145,10 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     data[i] = (char*)iter.data_ptr(i);
   }
 
+  at::detail::Array<ScalarType, ntensors> dtypes;
+  for (int i = 0; i < ntensors; i++) {
+    dtypes[i] = iter.tensor(i).scalar_type();
+  }
 
   int64_t numel = iter.numel();
   if (iter.is_trivial_1d()) {
@@ -138,19 +157,35 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     for (int i = 0; i < ntensors; i++) {
      strides[i] = inner_strides[i];
    }
-
 
-    launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-      arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
-      *out = invoke(f, &data.data[1], &strides.data[1], idx);
-    });
+    if (iter.needs_dynamic_casting()) {
+      launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        void* out = data[0] + strides[0] * idx;
+        arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else {
+      launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
+        *out = invoke(f, &data.data[1], &strides.data[1], idx);
+      });
+    }
   } else {
     auto offset_calc = make_offset_calculator<traits::arity + 1>(iter);
-    launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-      auto offsets = offset_calc.get(idx);
-      arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
-      *out = invoke(f, &data.data[1], &offsets.data[1], 1);
-    });
+    if (iter.needs_dynamic_casting()) {
+      launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        void* out = data[0] + offsets[0];
+        arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else {
+      launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = invoke(f, &data.data[1], &offsets.data[1], 1);
+      });
+    }
   }
 }
 
@@ -174,7 +209,6 @@ void gpu_kernel(TensorIterator& iter, const func_t& f) {
   }
 
   gpu_kernel_impl(iter, f);
-  iter.cast_outputs();
 }
 
 template <typename func_t>
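
The new `invoke` overload above relies on `c10::fetch_and_cast` and `c10::cast_and_store` from `c10/util/TypeCast.h`: fetch reads a value whose dtype is only known at runtime and converts it to the kernel's compile-time scalar type, and store converts the computed result back to the output tensor's runtime dtype. As a rough, host-only sketch of that pattern (hypothetical stand-in helpers with a toy `Dtype` enum, not the c10 implementations, which are also device-capable and cover all ScalarTypes):

```cpp
#include <cstdint>
#include <cstring>
#include <stdexcept>

enum class Dtype { Float, Double, Int };

// Read a value of runtime-known dtype from ptr and convert it to dest_t.
template <typename dest_t>
dest_t fetch_and_cast(Dtype dtype, const void* ptr) {
  switch (dtype) {
    case Dtype::Float:  { float v;   std::memcpy(&v, ptr, sizeof(v)); return static_cast<dest_t>(v); }
    case Dtype::Double: { double v;  std::memcpy(&v, ptr, sizeof(v)); return static_cast<dest_t>(v); }
    case Dtype::Int:    { int32_t v; std::memcpy(&v, ptr, sizeof(v)); return static_cast<dest_t>(v); }
  }
  throw std::runtime_error("unsupported dtype");
}

// Convert value to the runtime-known dtype of the destination and write it to ptr.
template <typename src_t>
void cast_and_store(Dtype dtype, void* ptr, src_t value) {
  switch (dtype) {
    case Dtype::Float:  { float v = static_cast<float>(value);     std::memcpy(ptr, &v, sizeof(v)); return; }
    case Dtype::Double: { double v = static_cast<double>(value);   std::memcpy(ptr, &v, sizeof(v)); return; }
    case Dtype::Int:    { int32_t v = static_cast<int32_t>(value); std::memcpy(ptr, &v, sizeof(v)); return; }
  }
  throw std::runtime_error("unsupported dtype");
}
```

With these two primitives, a kernel templated on a single `scalar_t` can consume and produce tensors of other dtypes without any preparatory copy, which is what the `needs_dynamic_casting()` branches of `gpu_kernel_impl` above do.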
