11#include " ATen/ATen.h"
2- #include " ATen/TensorUtils.h"
3- #include " ATen/NativeFunctions.h"
42
5- #include < sstream>
6- #include < vector>
3+ #include " ATen/Error.h"
4+ #include " ATen/NativeFunctions.h"
5+ #include " ATen/TensorUtils.h"
76
7+ #include < tuple>
88
99namespace at { namespace native {
1010
11- static void check1d (const char * name, IntList x) {
12- if (x.size () != 1 ) {
13- std::ostringstream ss;
14- ss << " max_pool1d() argument '" << name << " ' should contain one int (got "
15- << x.size () << " )" ;
16- throw std::runtime_error (ss.str ());
17- }
11+ static void check1d (
12+ const char * function_name,
13+ const char * argument_name,
14+ IntList x) {
15+ AT_CHECK (
16+ x.size () == 1 ,
17+ function_name, " () argument '" , argument_name,
18+ " ' should contain one int (got " , x.size (), " )" );
1819}
1920
2021Tensor adaptive_avg_pool1d (const Tensor & self, IntList output_size) {
2122 checkDim (" adaptive_avg_pool1d" , TensorArg (self, " self" , 1 ), 3 );
22- check1d (" output_size" , output_size);
23+ check1d (" adaptive_avg_pool1d " , " output_size" , output_size);
2324
2425 auto output = at::adaptive_avg_pool2d (
2526 self.unsqueeze (2 ),
@@ -30,7 +31,7 @@ Tensor adaptive_avg_pool1d(const Tensor & self, IntList output_size) {
3031
3132std::tuple<Tensor,Tensor> adaptive_max_pool1d (const Tensor & self, IntList output_size) {
3233 checkDim (" adaptive_max_pool1d" , TensorArg (self, " self" , 1 ), 3 );
33- check1d (" output_size" , output_size);
34+ check1d (" adaptive_max_pool1d " , " output_size" , output_size);
3435
3536 Tensor output, indices;
3637 std::tie (output, indices) = at::adaptive_max_pool2d (
@@ -48,10 +49,10 @@ std::tuple<Tensor,Tensor> max_pool1d(
4849 stride = kernel_size;
4950 }
5051 checkDim (" max_pool1d" , TensorArg (self, " self" , 1 ), 3 );
51- check1d (" kernel_size" , kernel_size);
52- check1d (" stride" , stride);
53- check1d (" padding" , padding);
54- check1d (" dilation" , dilation);
52+ check1d (" max_pool1d " , " kernel_size" , kernel_size);
53+ check1d (" max_pool1d " , " stride" , stride);
54+ check1d (" max_pool1d " , " padding" , padding);
55+ check1d (" max_pool1d " , " dilation" , dilation);
5556
5657 Tensor output, indices;
5758 std::tie (output, indices) = at::max_pool2d (
@@ -65,4 +66,30 @@ std::tuple<Tensor,Tensor> max_pool1d(
6566 return std::make_tuple (output.squeeze (2 ), indices.squeeze (2 ));
6667}
6768
68- }} // namespace at::native
69+ Tensor avg_pool1d (
70+ const Tensor& self,
71+ IntList kernel_size,
72+ IntList stride,
73+ IntList padding,
74+ bool ceil_mode,
75+ bool count_include_pad) {
76+ if (stride.empty ()) {
77+ stride = kernel_size;
78+ }
79+ checkDim (" avg_pool1d" , TensorArg (self, " self" , 1 ), 3 );
80+ check1d (" avg_pool1d" , " kernel_size" , kernel_size);
81+ check1d (" avg_pool1d" , " stride" , stride);
82+ check1d (" avg_pool1d" , " padding" , padding);
83+
84+ auto output = at::avg_pool2d (
85+ self.unsqueeze (2 ),
86+ {1 , kernel_size[0 ]},
87+ {1 , stride[0 ]},
88+ {0 , padding[0 ]},
89+ ceil_mode,
90+ count_include_pad);
91+
92+ return output.squeeze (2 );
93+ }
94+ } // namespace native
95+ } // namespace at
0 commit comments