Skip to content

Commit 2f1477d

Browse files
dskhudiafacebook-github-bot
authored andcommitted
Minor changes in initialization of dilation (#126)
Summary: Pull Request resolved: #126 Default value for dilation is in function definition itself. Reviewed By: protonu Differential Revision: D17371791 fbshipit-source-id: c3430dfa3faccf549dc066aa8dcd422b910dbcaa
1 parent c8cac64 commit 2f1477d

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

include/fbgemm/ConvUtils.h

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,24 @@
88

99
#include <array>
1010
#include <string>
11+
#include <type_traits>
1112

1213
namespace fbgemm {
1314

15+
template <int N, int... Vals>
16+
constexpr
17+
typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
18+
array_of_ones() {
19+
return std::array<int, N>{{Vals...}};
20+
}
21+
22+
template <int N, int... Vals>
23+
constexpr
24+
typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
25+
array_of_ones() {
26+
return array_of_ones<N, Vals..., 1>();
27+
}
28+
1429
/**
1530
* @brief A struct to conveniently store all convolution parameters.
1631
*/
@@ -34,7 +49,6 @@ struct conv_param_t {
3449

3550
/**
3651
* @brief Constructor for initializing the convolution parameters.
37-
* TODO: Dilation is not handled correctly.
3852
*/
3953
conv_param_t(
4054
int mb,
@@ -45,7 +59,7 @@ struct conv_param_t {
4559
std::array<int, SPATIAL_DIM> k,
4660
std::array<int, SPATIAL_DIM> strd,
4761
std::array<int, SPATIAL_DIM * 2> pd,
48-
std::array<int, SPATIAL_DIM> dilations = {})
62+
std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>())
4963
: MB(mb),
5064
IC(ic),
5165
OC(oc),
@@ -66,17 +80,6 @@ struct conv_param_t {
6680
" does not divide number of output channels = " + std::to_string(oc));
6781
}
6882

69-
bool dilation_unset = true;
70-
for (int d = 0; d < SPATIAL_DIM; ++d) {
71-
if (dilation[d] != 0) {
72-
dilation_unset = false;
73-
break;
74-
}
75-
}
76-
if (dilation_unset) {
77-
dilation.fill(1);
78-
}
79-
8083
for (int d = 0; d < SPATIAL_DIM; ++d) {
8184
IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
8285
OUT_DIM[d] = (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
@@ -115,14 +118,14 @@ struct conv_param_t {
115118
}
116119
for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
117120
out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" +
118-
std::to_string(pad[d]);
119-
if (d < SPATIAL_DIM * 2 - 1) {
120-
out += ", ";
121-
}
121+
std::to_string(pad[d]) + ", ";
122122
}
123123
for (int d = 0; d < SPATIAL_DIM; ++d) {
124124
out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
125-
std::to_string(dilation[d]) + ", ";
125+
std::to_string(dilation[d]);
126+
if (d < SPATIAL_DIM - 1) {
127+
out += ", ";
128+
}
126129
}
127130
} else {
128131
for (int d = 0; d < SPATIAL_DIM; ++d) {

0 commit comments

Comments
 (0)