 #include "ATen/NativeFunctions.h"
 #include "ATen/WrapDimUtils.h"
 #include "ATen/WrapDimUtilsMulti.h"
+#include "ReduceOpsUtils.h"
 #include "cpu/ReduceOpsKernel.h"
 
 #include <algorithm>
 #include <functional>
+#include <limits>
 #include <numeric>
 #include <vector>
 #include <map>
@@ -94,10 +96,12 @@ static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
9496 " Can only calculate the mean of floating types. Got " ,
9597 at::toString (scalarType),
9698 " instead." );
97- Tensor result = at::native::sum (self);
98- if (self.numel () > 0 )
99- result.div_ (self.numel ());
100- return result;
99+ if (self.numel () > 0 ) {
100+ Tensor result = at::native::sum (self);
101+ return result.div_ (self.numel ());
102+ } else {
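+    // the mean of an empty tensor is undefined; return NaN, matching NumPy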
+    return self.type().scalarTensor(std::numeric_limits<double>::quiet_NaN());
+  }
 }
 
 Tensor mean(const Tensor &self, ScalarType dtype) {
@@ -154,32 +158,6 @@ Tensor _prod_cpu(const Tensor &self) {
 
 // DIM REDUCE #################################################################
 
-static bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
-                                      int64_t ident) {
-  if (self.numel() == 1 && self.ndimension() == 0) {
-    result.resize_({});
-    result.fill_(self);
-    return true;
-  }
-  // Return identity
-  if (self.numel() == 0 && self.ndimension() == 1) {
-    result.resize_({0});
-    result.fill_(ident);
-    return true;
-  }
-  return false;
-}
-
-static Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
-                                int64_t dim) {
-  IntList self_sizes = self.sizes();
-  std::vector<int64_t> result_sizes;
-  result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
-  result_sizes[dim] = 1;
-  result.resize_(result_sizes);
-  return result;
-}
-
 static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
                                bool keepdim, optional<ScalarType> dtype) {
   ScalarType scalarType = result.type().scalarType();
@@ -192,7 +170,12 @@ static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
       result, self.toType(result.type().scalarType()), dim, keepdim);
   if (result.numel() > 0 && self.ndimension() > 0) {
     int64_t numel = self.size(dim);
-    result.div_(numel);
+    if (numel > 0) {
+      result.div_(numel);
+    } else {
+      // NumPy equivalent: the mean over an empty dim is NaN
+      result.fill_(std::numeric_limits<double>::quiet_NaN());
+    }
   }
   return result;
 }
@@ -235,7 +218,7 @@ Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtyp
 Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
                      bool keepdim) {
   int64_t dim = maybe_wrap_dim(dim_, self.dim());
-  if (_dimreduce_return_trivial(result, self, 0))
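+  // _dimreduce_return_trivial (presumably now in ReduceOpsUtils.h, per the new
+  // include above) takes dim/keepdim so it can shape the fast-path result;
+  // 0 is the identity of sum.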
+  if (_dimreduce_return_trivial(result, self, 0, dim, keepdim))
     return result;
   if (self.is_contiguous() && result.is_contiguous()) {
     _dimreduce_setup(result, self, dim);
@@ -273,7 +256,7 @@ Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dty
 Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
                       bool keepdim) {
   int64_t dim = maybe_wrap_dim(dim_, self.dim());
-  if (_dimreduce_return_trivial(result, self, 1))
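+  // same trivial-shape fast path as _sum_out_cpu; 1 is the identity of prod.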
+  if (_dimreduce_return_trivial(result, self, 1, dim, keepdim))
     return result;
   if (self.is_contiguous() && result.is_contiguous()) {
     _dimreduce_setup(result, self, dim);
@@ -294,7 +277,12 @@ static inline Tensor mean(const Tensor &self, int64_t dim, bool keepdim, optiona
   Tensor result = at::native::sum(self, dim, keepdim);
   if (result.numel() > 0 && self.ndimension() > 0) {
     int64_t numel = self.size(dim);
-    result.div_(numel);
+    if (numel > 0) {
+      result.div_(numel);
+    } else {
+      // NumPy equivalent: the mean over an empty dim is NaN
+      result.fill_(std::numeric_limits<double>::quiet_NaN());
+    }
   }
   return result;
 }
@@ -357,10 +345,15 @@ Tensor _prod(const Tensor &self, int64_t dim_, bool keepdim) {
 
 Tensor& logsumexp_out(Tensor& result, const Tensor &self, int64_t dim_, bool keepdim) {
   int64_t dim = maybe_wrap_dim(dim_, self.dim());
-  auto maxes = at::max_values(self, dim, true);
-  result = at::where((maxes == INFINITY).__or__(maxes == -INFINITY),
-                     maxes,
-                     maxes + at::log(at::sum(at::exp(self - maxes), dim, true)));
+  // can't take max of empty tensor.
+  if (self.numel() != 0) {
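+    // shift by the per-slice max so exp() cannot overflow:
+    // logsumexp(x) = max(x) + log(sum(exp(x - max(x))))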
+    auto maxes = at::max_values(self, dim, true);
+    result = at::where((maxes == INFINITY).__or__(maxes == -INFINITY),
+                       maxes,
+                       maxes + at::log(at::sum(at::exp(self - maxes), dim, true)));
+  } else {
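+    // the sum over an empty dim is its identity 0, so this yields log(0) = -inf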
+    result = at::log(at::sum(at::exp(self), dim, true));
+  }
   if (!keepdim)
     result.squeeze_(dim);
   return result;
@@ -588,4 +581,89 @@ Tensor& _sum_out(Tensor &result, const Tensor &self, IntList dims, bool keepdim)
   return reduce_multi_associative_out<_sum, _sum_out>(result, self, dims, keepdim);
 }
 
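+// The dim reductions below share one pattern: validate the backend and dtype,
+// wrap the dim, fast-path trivial shapes through _dimreduce_return_trivial
+// with the reduction's identity, and otherwise defer to the _th_* kernel.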
+Tensor norm(const Tensor& self, Scalar p, int64_t dim, bool keepdim) {
+  Tensor result = self.type().tensor();
+  return at::native::norm_out(result, self, p, dim, keepdim);
+}
+
+Tensor &norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool keepdim) {
+  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
+           "norm only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
+  AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
+  dim = maybe_wrap_dim(dim, self.dim());
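+  // sum(|x|^p) over an empty slice is 0, so 0 serves as the identity for norm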
+  if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
+    return result;
+  } else {
+    return at::_th_norm_out(result, self, p, dim, keepdim);
+  }
+}
+
+Tensor all(const Tensor& self, int64_t dim, bool keepdim) {
+  Tensor result = self.type().tensor();
+  return at::native::all_out(result, self, dim, keepdim);
+}
+
+Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
+  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
+           "all only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
+  AT_CHECK(self.type().scalarType() == at::ScalarType::Byte, "all only supports torch.uint8 dtype");
+  dim = maybe_wrap_dim(dim, self.dim());
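+  // all() over an empty slice is vacuously true, so its identity is 1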
+  if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) {
+    return result;
+  } else {
+    return at::_th_all_out(result, self, dim, keepdim);
+  }
+}
+
+Tensor any(const Tensor& self, int64_t dim, bool keepdim) {
+  Tensor result = self.type().tensor();
+  return at::native::any_out(result, self, dim, keepdim);
+}
+
+Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
+  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
+           "any only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
+  AT_CHECK(self.type().scalarType() == at::ScalarType::Byte, "any only supports torch.uint8 dtype");
+  dim = maybe_wrap_dim(dim, self.dim());
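+  // any() over an empty slice is false, so its identity is 0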
+  if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
+    return result;
+  } else {
+    return at::_th_any_out(result, self, dim, keepdim);
+  }
+}
+
+Tensor var(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
+  Tensor result = self.type().tensor();
+  return at::native::var_out(result, self, dim, unbiased, keepdim);
+}
+
+Tensor &var_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
+  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
+           "var only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
+  AT_CHECK(at::isFloatingType(self.type().scalarType()), "var only supports floating-point dtypes");
+  dim = maybe_wrap_dim(dim, self.dim());
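+  // unlike sum/prod, var has no finite identity, so NaN is passed for the trivial cases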
+  if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
+    return result;
+  } else {
+    return at::_th_var_out(result, self, dim, unbiased, keepdim);
+  }
+}
+
+Tensor std(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
+  Tensor result = self.type().tensor();
+  return at::native::std_out(result, self, dim, unbiased, keepdim);
+}
+
+Tensor &std_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
+  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
+           "std only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend()));
+  AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
+  dim = maybe_wrap_dim(dim, self.dim());
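+  // as in var_out, NaN is passed as the identity for the trivial cases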
+  if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
+    return result;
+  } else {
+    return at::_th_std_out(result, self, dim, unbiased, keepdim);
+  }
+}
+
 }} // namespace at::native