@@ -25,8 +25,9 @@ namespace native {
2525
// make_offset2bag, shown as a diff hunk: rows marked "-" are the previous
// statements and rows marked "+" are their replacements.
//
// Post-change behavior visible here: adds a tensor of ones into `offset2bag`
// along dim 0 at the positions listed in `offsets`, subtracts 1 from element 0,
// and then rebinds `offset2bag` to its cumulative sum along dim 0. The trailing
// comments trace an example: [1 0 1 0 1] -> [0 0 1 0 1] -> [0 0 1 1 2].
//
// `indices` is accepted but not referenced in the visible body.
//
// NOTE(review): the switch from index_fill_/assignment to index_add_/decrement
// presumably lets repeated offsets (empty bags) accumulate instead of
// overwriting each other -- confirm against the index_add_ documentation.
2626static void make_offset2bag (const Tensor &offsets, const Tensor &indices,
2727 Tensor &offset2bag) {
28- offset2bag.index_fill_ (0 , offsets, 1 ); // offset2bag = [1 0 1 0 1]
29- offset2bag[0 ] = 0 ; // offset2bag = [0 0 1 0 1]
28+ offset2bag.index_add_ (
29+ 0 , offsets, at::ones_like (offsets)); // offset2bag = [1 0 1 0 1]
30+ offset2bag[0 ] -= 1 ; // offset2bag = [0 0 1 0 1]
3031 offset2bag = offset2bag.cumsum (0 ); // offset2bag = [0 0 1 1 2]
3132}
3233
@@ -65,13 +66,14 @@ static void index_select_add(const Tensor &select_indices,
6566
// make_bag_size, shown as a diff hunk: rows marked "-" are the previous
// statements and rows marked "+" are their replacements.
//
// Post-change behavior visible here: when `mode` is 1 (labelled MODE_MEAN in the
// removed row) or 2 (MODE_MAX per the added comment), fills `bag_size` from
// `offsets`:
//   - if `offsets` has a size other than 1 along dim 0, every element of
//     `bag_size` except the last is set to the difference between consecutive
//     entries of `offsets`;
//   - the last element is set to indices.size(0) minus the last offset.
// For any other `mode` the function leaves `bag_size` untouched.
//
// The change moves the last-element assignment out of the
// `offsets.size(0) != 1` branch, so it now also runs when there is exactly one
// offset.
//
// NOTE(review): with an empty `offsets`, `offsets[-1]` / `bag_size[-1]` would
// index an empty tensor -- confirm callers never pass zero bags.
6667static void make_bag_size (const Tensor &offsets, const Tensor &indices,
6768 const int64_t mode, Tensor &bag_size) {
68- if (mode == 1 ) { // MODE_MEAN
69+ if (mode == 1 || mode == 2 ) {
70+ // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards)
6971 if (offsets.size (0 ) != 1 ) {
7072 bag_size.slice (0 , 0 , bag_size.size (0 ) - 1 , 1 ) =
7173 offsets.slice (0 , 1 , offsets.size (0 ), 1 ) -
7274 offsets.slice (0 , 0 , offsets.size (0 ) - 1 , 1 );
73- bag_size[-1 ] = indices.size (0 ) - offsets[-1 ];
7475 }
76+ bag_size[-1 ] = indices.size (0 ) - offsets[-1 ];
7577 }
7678}
7779
@@ -83,8 +85,12 @@ static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
8385 auto bag_size_ = indices.size (0 );
8486 output /= bag_size_;
8587 } else {
86- auto bag_size_ =
87- bag_size.toType (output.type ()).unsqueeze (1 ).expand_as (output);
88+ // Avoid dividing by 0 for empty bags.
89+ // Instead we want empty bags to return all 0s
90+ auto bag_size_ = at::max (bag_size, at::ones_like (bag_size))
91+ .toType (output.type ())
92+ .unsqueeze (1 )
93+ .expand_as (output);
8894 output /= bag_size_;
8995 }
9096 }
@@ -113,7 +119,7 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
113119template <typename scalar_t >
114120std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max (
115121 const Tensor& weight, const Tensor &indices, const Tensor& offset2bag, const Tensor& output, const Tensor& bag_size, const Tensor& offsets) {
116-
122+
117123 auto max_indices = at::zeros (indices.type (), {offsets.size (0 ), weight.size (1 )});
118124
119125 int64_t numel = indices.numel ();
@@ -125,9 +131,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
125131 auto max_indices_stride = max_indices.stride (0 );
126132
127133 auto weight_data = weight.data <scalar_t >();
128- auto output_data = output.data <scalar_t >();
129- auto weight_stride = weight.stride (0 );
130- auto output_stride = output.stride (0 );
134+ auto output_data = output.data <scalar_t >();
135+ auto weight_stride = weight.stride (0 );
136+ auto output_stride = output.stride (0 );
131137
132138 for (int i = 0 ; i < numel; i++) {
133139 auto bag = offset2bag_data[i];
@@ -137,7 +143,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
137143 for (int dim = 0 ; dim < dims; dim++) {
138144 auto & current_item = output_data[output_stride * bag + dim];
139145 auto weight_item = weight_data[weight_stride * word_idx + dim];
140-
146+
141147 bool is_first_for_bag = (i == 0 ) || offset2bag_data[i - 1 ] != bag;
142148
143149 if (is_first_for_bag || weight_item > current_item) {
@@ -164,9 +170,19 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
164170 checkScalarTypes (" embedding_bag" , weight_arg, {kFloat , kDouble });
165171
166172 auto bag_size = at::zeros (indices.type (), offsets.sizes ());
167- auto offset2bag =
168- at::zeros (indices__.type (), {indices.size (0 )}); // offset2bag = [0 0 0 0 0]
173+ make_bag_size (offsets, indices, mode, bag_size);
174+
175+ // If the last entries are empty, then the last offsets are irrelevant as they
176+ // won't change anything in the assignment of ID -> bag, but index_add would
177+ // throw an out-of-bounds error. So to keep it simple we just add one more
178+ // entry to the end, then get rid of it after make_offset2bag.
179+ auto offset2bag = at::zeros (
180+ indices__.type (), {indices.sizes ()[0 ] + 1 }); // offset2bag = [0 0 0 0 0]
181+
169182 make_offset2bag (offsets, indices, offset2bag);
183+
184+ offset2bag.resize_ ({indices.sizes ()[0 ]});
185+
170186 auto output = at::zeros (weight.type (), {offsets.size (0 ), weight.size (1 )});
171187
172188 if (mode == MODE_MEAN || mode == MODE_SUM) {
@@ -175,7 +191,6 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
175191 } else if (weight.type ().scalarType () == kDouble ) {
176192 index_select_add<double >(indices, offset2bag, weight, output);
177193 }
178- make_bag_size (offsets, indices, mode, bag_size);
179194 auto ret = apply_bag_size (offsets, indices, mode, output, bag_size);
180195 return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
181196 } else { // MODE_MAX
@@ -190,7 +205,7 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
190205Tensor embedding_bag_backward (const Tensor &grad_, const Tensor &indices__,
191206 const Tensor &offsets__,
192207 const Tensor &offset2bag__,
193- const Tensor &bag_size_,
208+ const Tensor &bag_size_,
194209 const Tensor &max_indices_,
195210 int64_t num_weights,
196211 bool scale_grad_by_freq, int64_t mode,
@@ -305,10 +320,14 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
305320 igwd + ddim * index, 1 );
306321 }
307322 }
308- }
323+ }
309324 } else if (mode == MODE_MAX) {
325+ auto nonempty_max_indices = max_indices_.index_select (0 , bag_size_.nonzero ().view (-1 ));
326+ auto nonempty_grad = grad_.index_select (0 , bag_size_.nonzero ().view (-1 ));
327+
310328 for (int64_t dim = 0 ; dim < grad.size (1 ); dim++) {
311- index_grad_weight.select (1 , dim).index_add_ (0 , max_indices_.select (1 , dim), grad_.select (1 , dim));
329+ index_grad_weight.select (1 , dim).index_add_ (
330+ 0 , nonempty_max_indices.select (1 , dim), nonempty_grad.select (1 , dim));
312331 }
313332 }
314333
0 commit comments