
Commit 4c23c34

ifedan authored and facebook-github-bot committed
Computing var/stddev and mean at the same time (#18731)
Summary: The current variance kernels already compute the mean at the same time. We often want both statistics together, so it seems reasonable to have a kwarg/function that lets us get both values without launching an extra kernel.

Pull Request resolved: #18731
Differential Revision: D14726082
Pulled By: ifedan
fbshipit-source-id: 473cba0227b69eb2240dca5e61a8f4366df0e029
1 parent 08bdd69 commit 4c23c34

19 files changed: +539 -56 lines

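Before the per-file diffs, a bit of context on why no extra kernel is needed: the variance kernels are Welford-style reductions that already carry a running mean alongside the second moment, so the mean falls out of the same pass. A minimal standalone sketch of that single-pass computation (plain C++ for illustration; welford_var_mean is a made-up name, not the ATen kernel itself):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// One pass over the data yields both the mean and M2 (sum of squared deviations),
// which is exactly why var_mean/std_mean need no second kernel launch.
struct MeanVar { double mean; double var; };

static MeanVar welford_var_mean(const std::vector<double>& xs, bool unbiased) {
  double mean = 0.0, m2 = 0.0;
  int64_t n = 0;
  for (double x : xs) {
    ++n;
    double delta = x - mean;
    mean += delta / n;
    m2 += delta * (x - mean);  // uses the updated mean
  }
  double divisor = unbiased ? (n - 1) : n;  // mirrors the divisor logic in WelfordOps::project
  return { mean, divisor > 0 ? m2 / divisor : NAN };
}

int main() {
  MeanVar r = welford_var_mean({1.0, 2.0, 3.0, 4.0}, /*unbiased=*/true);
  std::printf("mean=%f var=%f\n", r.mean, r.var);  // mean=2.500000 var=1.666667
}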
aten/src/ATen/core/aten_interned_strings.h

Lines changed: 4 additions & 0 deletions
@@ -257,6 +257,8 @@ _(aten, cosh) \
 _(aten, cosine_embedding_loss) \
 _(aten, cosine_similarity) \
 _(aten, cross) \
+_(aten, std_mean) \
+_(aten, var_mean) \
 _(aten, ctc_loss) \
 _(aten, cudnn_affine_grid_generator) \
 _(aten, cudnn_affine_grid_generator_backward) \
@@ -905,6 +907,8 @@ _(attr, padding_value) \
 _(attr, params) \
 _(attr, pdist) \
 _(attr, cdist) \
+_(attr, std_mean) \
+_(attr, var_mean) \
 _(attr, periodic) \
 _(attr, pivot) \
 _(attr, pivots) \

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 97 additions & 0 deletions
@@ -114,6 +114,41 @@ static std::unique_ptr<TensorIterator> make_reduction(
   return TensorIterator::reduce_op(viewed_result, self.to(dtype));
 }
 
+static std::unique_ptr<TensorIterator> make_reduction(
+    const char* name, Tensor& result1, Tensor& result2, const Tensor& self, IntArrayRef dim,
+    bool keepdim, ScalarType dtype)
+{
+  // check that result type and dtype match if provided
+  for (const Tensor *t: {&result1, &result2}) {
+    const Tensor& result = *t;
+    AT_CHECK(
+        !result.defined() || result.type().scalarType() == dtype,
+        name, ": provided dtype must match dtype of result. Got ",
+        toString(result.type().scalarType()),
+        " and ",
+        toString(dtype),
+        ".");
+  }
+
+  int64_t ndim = self.dim();
+  DimMask mask = make_dim_mask(dim, ndim);
+  allocate_reduction_result(result1, self, mask, keepdim, dtype);
+  auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);
+
+  allocate_reduction_result(result2, self, mask, keepdim, dtype);
+  auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);
+
+  // special case for type promotion in mixed precision, improves computational
+  // efficiency.
+  // We don't generalize this to common mismatched input/output types to avoid cross
+  // product of templated kernel launches.
+  if (self.type().scalarType() == dtype ||
+      (self.is_cuda() && self.type().scalarType() == kHalf && dtype == kFloat)) {
+    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
+  }
+  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype));
+}
+
 static inline int64_t n_dim_size(const Tensor& self, IntArrayRef dim) {
   int64_t numel = 1;
   for (auto d : dim) {
@@ -611,6 +646,68 @@ static Tensor &std_var_out(Tensor &result, const Tensor &self, IntArrayRef dim,
   return result;
 }
 
+static std::tuple<Tensor&,Tensor&> std_var_mean_out(const char* fname, Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim, bool take_sqrt) {
+  AT_ASSERT(result1.defined() && result2.defined());
+  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
+           fname, " only support CPU and CUDA backend, got: ", toString(self.type().backend()));
+  AT_CHECK(at::isFloatingType(self.type().scalarType()), fname, " only support floating-point dtypes");
+  AT_CHECK(result1.type().scalarType() == result2.type().scalarType(),
+           "provided by result1 dtype must match dtype of result2. Got ",
+           toString(result1.type().scalarType()),
+           " and ",
+           toString(result2.type().scalarType()),
+           ".");
+  ScalarType dtype = get_dtype(result1, self, {}, true);
+  auto iter = make_reduction(fname, result1, result2, self, dim, keepdim, dtype);
+  if (iter->numel() == 0) {
+    result1.fill_(NAN);
+    result2.fill_(NAN);
+  } else {
+    std_var_stub(iter->device_type(), *iter, unbiased, take_sqrt);
+  }
+  return std::tuple<Tensor&, Tensor&>(result1, result2);
+}
+
+std::tuple<Tensor&,Tensor&> var_mean_out(Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim) {
+  return std_var_mean_out("var_mean", result1, result2, self, dim, unbiased, keepdim, false);
+}
+
+std::tuple<Tensor&,Tensor&> std_mean_out(Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim) {
+  return std_var_mean_out("std_mean", result1, result2, self, dim, unbiased, keepdim, true);
+}
+
+std::tuple<Tensor&,Tensor&> var_mean_out(Tensor &result1, Tensor &result2, const Tensor &self, bool unbiased) {
+  return std_var_mean_out("var_mean", result1, result2, self, {}, unbiased, false, false);
+}
+
+std::tuple<Tensor&,Tensor&> std_mean_out(Tensor &result1, Tensor &result2, const Tensor &self, bool unbiased) {
+  return std_var_mean_out("std_mean", result1, result2, self, {}, unbiased, false, true);
+}
+
+std::tuple<Tensor,Tensor> var_mean(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
+  Tensor result1 = at::empty({0}, self.options());
+  Tensor result2 = at::empty({0}, self.options());
+  return at::native::var_mean_out(result1, result2, self, dim, unbiased, keepdim);
+}
+
+std::tuple<Tensor,Tensor> std_mean(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
+  Tensor result1 = at::empty({0}, self.options());
+  Tensor result2 = at::empty({0}, self.options());
+  return at::native::std_mean_out(result1, result2, self, dim, unbiased, keepdim);
+}
+
+std::tuple<Tensor,Tensor> std_mean(const Tensor& self, bool unbiased) {
+  Tensor result1 = at::empty({0}, self.options());
+  Tensor result2 = at::empty({0}, self.options());
+  return at::native::std_mean_out(result1, result2, self, unbiased);
+}
+
+std::tuple<Tensor,Tensor> var_mean(const Tensor& self, bool unbiased) {
+  Tensor result1 = at::empty({0}, self.options());
+  Tensor result2 = at::empty({0}, self.options());
+  return at::native::var_mean_out(result1, result2, self, unbiased);
+}
+
 Tensor var(const Tensor& self, bool unbiased) {
   TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
               "var only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));

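For a sense of the resulting API surface, here is a hedged caller-side sketch based only on the signatures added above. It assumes the declarations are visible to the caller (e.g. through the generated native-function headers); how the pair is ultimately exposed as a public binding is handled by the other files in this commit and is not shown here.

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>  // assumed to declare the at::native entry points
#include <tuple>

void var_mean_example(const at::Tensor& x) {
  // First element is the variance (or stddev for std_mean), second is the mean,
  // matching the (result1, result2) order in std_var_mean_out above.
  at::Tensor var, mean;
  std::tie(var, mean) = at::native::var_mean(x, /*dim=*/{0},
                                             /*unbiased=*/true,
                                             /*keepdim=*/false);

  // Full-reduction overload: both outputs are 0-dim tensors.
  at::Tensor stddev, mean_all;
  std::tie(stddev, mean_all) = at::native::std_mean(x, /*unbiased=*/true);
}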
aten/src/ATen/native/SharedReduceOps.h

Lines changed: 11 additions & 3 deletions
@@ -6,9 +6,11 @@
 #if defined(__CUDACC__)
 #include <THC/THCDeviceUtils.cuh>
 #include <ATen/native/cuda/DeviceSqrt.cuh>
+#include <thrust/tuple.h>
 #elif defined(__HIPCC__)
 #include <THH/THHDeviceUtils.cuh>
 #include <ATen/native/hip/DeviceSqrt.cuh>
+#include <thrust/tuple.h>
 #else
 #include <cmath>
 #define device_sqrt std::sqrt
@@ -42,7 +44,7 @@ struct WelfordData {
 };
 
 
-template <typename scalar_t, typename acc_scalar_t, typename index_t, typename combine_t>
+template <typename scalar_t, typename acc_scalar_t, typename index_t, typename combine_t, typename res_t>
 struct WelfordOps {
   bool unbiased;
   bool take_sqrt;
@@ -80,12 +82,18 @@ struct WelfordOps {
       new_count
     };
   }
-  inline C10_DEVICE scalar_t project(acc_t acc) const {
+  inline C10_DEVICE res_t project(acc_t acc) const {
+    auto mean = acc.mean;
     combine_t divisor = unbiased ? (acc.nf - 1) : acc.nf;
     auto ret = (divisor > 0) ?
       (take_sqrt ? device_sqrt(acc.m2 / divisor) : (acc.m2 / divisor))
       : NAN;
-    return (scalar_t) ret;
+#if defined(__CUDACC__) || defined(__HIPCC__)
+    thrust::tuple<scalar_t, scalar_t> results((scalar_t) ret, (scalar_t) mean);
+#else
+    std::tuple<scalar_t, scalar_t> results{(scalar_t) ret, (scalar_t) mean};
+#endif
+    return results;
   }
 #if defined(__CUDACC__) || defined(__HIPCC__)
   inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 22 additions & 0 deletions
@@ -505,6 +505,28 @@ std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out, const Ten
   return builder.build();
 }
 
+std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tensor& a) {
+  AT_ASSERT(out1.defined());
+  AT_ASSERT(out2.defined());
+  AT_CHECK((!a.is_cuda() && !out1.is_cuda() && !out2.is_cuda()) || (a.device() == out1.device() && out1.device() == out2.device()),
+           "reduce_op(): expected input and both outputs to be on same device, but input is on ", a.device(),
+           ", output1 is on ", out1.device(), " and output2 is on", out2.device());
+  AT_CHECK(out1.dim() == out2.dim(), "reduce_op(): expected both outputs to have same number of dims, but output1 has ", out1.dim(),
+           " and output2 has ", out2.dim());
+  AT_CHECK(out1.sizes() == out2.sizes(), "reduce_op(): expected both outputs to have same sizes, but output1 has ", out1.sizes(),
+           " and output2 has ", out2.sizes());
+  AT_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
+           " and output2 has ", out2.strides());
+  auto builder = TensorIterator::Builder();
+  builder.add_output(out1);
+  builder.add_output(out2);
+  builder.add_input(a);
+  builder.iter_->promote_gpu_output_dtypes_ = true;
+  builder.iter_->resize_outputs_ = false;
+  builder.iter_->is_reduction_ = true;
+  return builder.build();
+}
+
 void TensorIterator::mark_outputs() {
   for (int i = 0; i < num_outputs_; i++) {
     operands_[i].is_output = true;

aten/src/ATen/native/TensorIterator.h

Lines changed: 8 additions & 0 deletions
@@ -148,11 +148,14 @@ struct CAFFE2_API TensorIterator {
   static std::unique_ptr<TensorIterator> unary_op(Tensor& out, const Tensor& a);
   static std::unique_ptr<TensorIterator> nullary_op(Tensor& out);
   static std::unique_ptr<TensorIterator> reduce_op(Tensor& out, const Tensor& a);
+  static std::unique_ptr<TensorIterator> reduce_op(Tensor& out1, Tensor& out2, const Tensor& a);
 
   int ndim() const { return shape_.size(); }
   IntArrayRef shape() const { return shape_; }
   int64_t numel() const;
   int ntensors() const { return operands_.size(); }
+  int noutputs() const { return num_outputs_; }
+  int ninputs() const { return ntensors() - noutputs(); }
 
   /// number of elements in the output operand. this is the same as numel() for
   /// operations that are not reductions.
@@ -182,6 +185,11 @@ struct CAFFE2_API TensorIterator {
     return operands_[arg].tensor;
  }
 
+  Tensor input(int arg=0) const {
+    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
+    return operands_[num_outputs_ + arg].tensor;
+  }
+
  /// Removes an operand from this iterator
  void remove_operand(int arg);
  /// Removes a dimension from this iterator

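The new noutputs()/ninputs()/input() accessors lean on TensorIterator's existing convention that output operands are stored before input operands, so the inputs are just a fixed offset into the operand list. A toy illustration of that indexing convention (hypothetical stand-in type, not the real OperandInfo/TensorIterator):

#include <cassert>
#include <vector>

// Operands are laid out as [out0, out1, ..., in0, in1, ...]; the accessors
// below use the same index arithmetic as the new TensorIterator methods.
struct OperandListSketch {
  std::vector<int> operands;  // placeholder for the real OperandInfo entries
  int num_outputs;

  int ntensors() const { return (int)operands.size(); }
  int noutputs() const { return num_outputs; }
  int ninputs()  const { return ntensors() - noutputs(); }
  int output(int arg = 0) const { assert(arg < noutputs()); return operands[arg]; }
  int input(int arg = 0)  const { assert(arg < ninputs());  return operands[num_outputs + arg]; }
};

int main() {
  OperandListSketch it{{10, 11, 20}, /*num_outputs=*/2};  // two outputs, one input
  assert(it.ninputs() == 1 && it.input(0) == 20 && it.output(1) == 11);
}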
aten/src/ATen/native/TensorIteratorReduce.cpp

Lines changed: 7 additions & 6 deletions
@@ -27,13 +27,13 @@ void TensorIterator::parallel_reduce(const loop2d_t& loop) {
 }
 
 static bool use_two_pass_reduction(TensorIterator& iter) {
-  return iter.tensor(0).numel() == 1;
+  return iter.output(0).numel() == 1;
 }
 
 static void two_pass_reduction(TensorIterator& iter, const loop2d_t& loop) {
   int max_threads = at::get_num_threads();
 
-  auto& dst = iter.tensor(0);
+  auto dst = iter.output(0);
   auto buffer_shape = DimVector(dst.sizes());
   buffer_shape.insert(buffer_shape.begin(), max_threads);
   auto buffer = at::empty(buffer_shape, dst.options());
@@ -47,7 +47,7 @@ static void two_pass_reduction(TensorIterator& iter, const loop2d_t& loop) {
     auto slice = buffer[thread_num];
     slice.copy_(dst);
 
-    auto sub_iter = TensorIterator::reduce_op(slice, iter.tensor(1));
+    auto sub_iter = TensorIterator::reduce_op(slice, iter.input(0));
     sub_iter->serial_for_each(loop, {begin, end});
   });
 
@@ -117,13 +117,14 @@ static void parallel_dim_reduction(TensorIterator& iter, const loop2d_t& loop) {
 }
 
 void TensorIterator::foreach_reduced_elt(const loop_subiter_t &loop, bool parallelize) {
-  AT_ASSERT(ntensors() == 2 && num_outputs_ == 1);
+  AT_ASSERT(ninputs() == 1);
+  AT_ASSERT(noutputs() >= 1);
 
   auto shape = this->shape();
-  if (tensor(0).numel() == 0) {
+  if (output(0).numel() == 0) {
    return;
  }
-  if (tensor(0).numel() == 1) {
+  if (output(0).numel() == 1) {
    loop(*this);
  }
  else if (numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||

aten/src/ATen/native/cpu/Reduce.h

Lines changed: 47 additions & 12 deletions
@@ -25,6 +25,44 @@ static inline bool is_outer_reduction(const int64_t* strides) {
     strides[3] == sizeof(typename traits::arg2_t);
 }
 
+template<typename traits, typename res_t>
+static void set_result(const int index, const res_t result, const TensorIterator &iter, const int num_outputs) {
+  static_assert(std::is_same<res_t, typename traits::arg2_t>::value, "data types must match");
+  if (index < num_outputs) {
+    char *out = (char *) iter.data_ptr(index);
+    *(res_t *) out = result;
+  }
+}
+
+template<typename traits, typename res_t>
+static void set_results(const res_t result, const TensorIterator &iter, const int num_outputs) {
+  AT_ASSERT(num_outputs == 1);
+  set_result<traits>(0, result, iter, num_outputs);
+}
+
+template<typename traits, std::size_t i = 0, typename... tuple_t>
+static inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
+for_each_in_tuple(const std::tuple<tuple_t...>& t, const TensorIterator &iter, const int num_outputs) {
+  return i;
+}
+
+template<typename traits, std::size_t i = 0, typename... tuple_t>
+static inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
+for_each_in_tuple(const std::tuple<tuple_t...>& t, const TensorIterator &iter, const int num_outputs) {
+  if (i < num_outputs) {
+    set_result<traits>(i, std::get<i>(t), iter, num_outputs);
+    return for_each_in_tuple<traits, i + 1, tuple_t...>(t, iter, num_outputs);
+  }
+  return i;
+}
+
+template<typename traits, typename... res_t>
+static void set_results(const std::tuple<res_t...>& result, const TensorIterator &iter, const int num_outputs) {
+  AT_ASSERT(num_outputs >= 1);
+  std::size_t result_size = for_each_in_tuple<traits>(result, iter, num_outputs);
+  AT_ASSERT(num_outputs == result_size);
+}
+
 template <typename T, typename... Args>
 struct all_same : c10::guts::conjunction<
   std::is_same<T, Args>...
@@ -64,7 +102,7 @@ void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) {
   using c_traits = binary_function_traits<cf_t>;
   using p_traits = unary_function_traits<pf_t>;
   using acc_t = typename p_traits::arg1_t;
-  using data_t = typename p_traits::result_type;
+  using data_t = typename r_traits::arg2_t;
   static_assert(
     all_same<
       acc_t,
@@ -75,19 +113,17 @@ void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) {
       typename c_traits::arg2_t,
      typename c_traits::result_type>::value,
    "all accumulate types must match");
-  static_assert(
-    std::is_same<data_t, typename r_traits::arg2_t>::value,
-    "all data types must match");
  static_assert(
    std::is_default_constructible<acc_t>::value,
    "the accumulate type must be default-constructible"
  );
-  iter.foreach_reduced_elt([&](TensorIterator &sub_iter) {
-    auto reduction_body = [&](acc_t acc, int64_t begin, int64_t end) -> acc_t {
-      sub_iter.serial_for_each([&acc, &ops](int ntensors, char** data, const int64_t* strides, int64_t size) {
-        AT_ASSERT(ntensors == 2);
-        char *in = data[1];
-        int64_t stride = strides[1];
+  const int num_outputs = iter.noutputs();
+  iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIterator &sub_iter) {
+    auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t {
+      sub_iter.serial_for_each([&acc, &ops, num_outputs](int ntensors, char** data, const int64_t* strides, int64_t size) {
+        AT_ASSERT(ntensors - num_outputs == 1);
+        char *in = data[ntensors - 1];
+        int64_t stride = strides[ntensors - 1];
        for (int64_t i = 0; i < size; ++i) {
          acc = ops.reduce(acc, *(data_t*)in);
          in += stride;
@@ -118,8 +154,7 @@ void binary_kernel_reduce(TensorIterator& iter, ops_t ops, init_t init) {
        total_acc = ops.combine(total_acc, buffer[i]);
      }
    }
-    char *out = (char *)sub_iter.data_ptr(0);
-    *(data_t*)out = ops.project(total_acc);
+    set_results<r_traits>(ops.project(total_acc), sub_iter, num_outputs);
   });
 }
 

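The set_results/for_each_in_tuple helpers above use a common C++11 pattern: recursive function templates selected via std::enable_if to visit std::tuple elements by compile-time index, stopping after the first num_outputs entries. A standalone sketch of the same pattern with toy types (copy_prefix is a made-up name, independent of TensorIterator):

#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <vector>

// Base case: past the last tuple element, report how many were visited.
template <std::size_t i = 0, typename... Ts>
typename std::enable_if<i == sizeof...(Ts), std::size_t>::type
copy_prefix(const std::tuple<Ts...>&, std::vector<double>&, std::size_t) {
  return i;
}

// Recursive case: copy element i only if it lies within the requested prefix.
template <std::size_t i = 0, typename... Ts>
typename std::enable_if<i < sizeof...(Ts), std::size_t>::type
copy_prefix(const std::tuple<Ts...>& t, std::vector<double>& out, std::size_t n) {
  if (i < n) {
    out.push_back(std::get<i>(t));
    return copy_prefix<i + 1, Ts...>(t, out, n);
  }
  return i;
}

int main() {
  std::vector<double> out;
  // Only the first two of three results are "outputs", mirroring num_outputs.
  std::size_t written = copy_prefix(std::make_tuple(1.5, 2.5, 3.5), out, 2);
  assert(written == 2 && out.size() == 2 && out[1] == 2.5);
}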
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_s
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "std_cpu", [&] {
     binary_kernel_reduce(
       iter,
-      WelfordOps<scalar_t, double, int64_t, double> { unbiased, take_sqrt },
+      WelfordOps<scalar_t, double, int64_t, double, std::tuple<scalar_t, scalar_t>> { unbiased, take_sqrt },
       WelfordData<double, int64_t, double>()
     );
   });
