88
99#include < array>
1010#include < string>
11+ #include < type_traits>
1112
1213namespace fbgemm {
1314
15+ template <int N, int ... Vals>
16+ constexpr
17+ typename std::enable_if<N == sizeof ...(Vals), std::array<int , N>>::type
18+ array_of_ones () {
19+ return std::array<int , N>{{Vals...}};
20+ }
21+
22+ template <int N, int ... Vals>
23+ constexpr
24+ typename std::enable_if<N != sizeof ...(Vals), std::array<int , N>>::type
25+ array_of_ones () {
26+ return array_of_ones<N, Vals..., 1 >();
27+ }
28+
1429/* *
1530 * @brief A struct to conveniently store all convolution parameters.
1631 */
@@ -34,7 +49,6 @@ struct conv_param_t {
3449
3550 /* *
3651 * @brief Constructor for initializing the convolution parameters.
37- * TODO: Dilation is not handled correctly.
3852 */
3953 conv_param_t (
4054 int mb,
@@ -45,7 +59,7 @@ struct conv_param_t {
4559 std::array<int , SPATIAL_DIM> k,
4660 std::array<int , SPATIAL_DIM> strd,
4761 std::array<int , SPATIAL_DIM * 2 > pd,
48- std::array<int , SPATIAL_DIM> dilations = {} )
62+ std::array<int , SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>() )
4963 : MB(mb),
5064 IC (ic),
5165 OC(oc),
@@ -66,17 +80,6 @@ struct conv_param_t {
6680 " does not divide number of output channels = " + std::to_string (oc));
6781 }
6882
69- bool dilation_unset = true ;
70- for (int d = 0 ; d < SPATIAL_DIM; ++d) {
71- if (dilation[d] != 0 ) {
72- dilation_unset = false ;
73- break ;
74- }
75- }
76- if (dilation_unset) {
77- dilation.fill (1 );
78- }
79-
8083 for (int d = 0 ; d < SPATIAL_DIM; ++d) {
8184 IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
8285 OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1 ) - 1 ) / stride[d] + 1 ;
@@ -115,14 +118,14 @@ struct conv_param_t {
115118 }
116119 for (int d = 0 ; d < SPATIAL_DIM * 2 ; ++d) {
117120 out += " pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + " :" +
118- std::to_string (pad[d]);
119- if (d < SPATIAL_DIM * 2 - 1 ) {
120- out += " , " ;
121- }
121+ std::to_string (pad[d]) + " , " ;
122122 }
123123 for (int d = 0 ; d < SPATIAL_DIM; ++d) {
124124 out += " dilation_" + dim_string[3 - SPATIAL_DIM + d] + " :" +
125- std::to_string (dilation[d]) + " , " ;
125+ std::to_string (dilation[d]);
126+ if (d < SPATIAL_DIM - 1 ) {
127+ out += " , " ;
128+ }
126129 }
127130 } else {
128131 for (int d = 0 ; d < SPATIAL_DIM; ++d) {
0 commit comments