#include <omp.h>
#endif

+namespace {
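+// Reduction modes; these integer values match the `mode` argument the
+// Python side passes down ('sum' = 0, 'mean' = 1, 'max' = 2).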
+const int MODE_SUM = 0;
+const int MODE_MEAN = 1;
+const int MODE_MAX = 2;
+}
+
namespace at {
namespace native {

@@ -50,7 +56,7 @@ static void index_select_add(const Tensor &select_indices,
  auto src_data = src.data<T>();
  auto output_data = output.data<T>();
  auto numel = add_indices.numel();
-  int64_t ddim = src.sizes()[1];
+  int64_t ddim = src.size(1);
  for (int64_t i = 0; i < numel; i++) {
    axpy<T>(ddim, 1, src_data + ddim * select_indices_data[i], 1,
            output_data + ddim * add_indices_data[i], 1);
@@ -60,11 +66,11 @@ static void index_select_add(const Tensor &select_indices,
static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                          const int64_t mode, Tensor &bag_size) {
  if (mode == 1) { // MODE_MEAN
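+    // Bag i's size is offsets[i+1] - offsets[i]; the last bag runs to the
+    // end of indices. The single-bag case is handled in apply_bag_size.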
-    if (offsets.sizes()[0] != 1) {
-      bag_size.slice(0, 0, bag_size.sizes()[0] - 1, 1) =
-          offsets.slice(0, 1, offsets.sizes()[0], 1) -
-          offsets.slice(0, 0, offsets.sizes()[0] - 1, 1);
-      bag_size[-1] = indices.sizes()[0] - offsets[-1];
+    if (offsets.size(0) != 1) {
+      bag_size.slice(0, 0, bag_size.size(0) - 1, 1) =
+          offsets.slice(0, 1, offsets.size(0), 1) -
+          offsets.slice(0, 0, offsets.size(0) - 1, 1);
+      bag_size[-1] = indices.size(0) - offsets[-1];
    }
  }
}
@@ -73,8 +79,8 @@ static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
                             const int64_t mode, Tensor &output,
                             const Tensor &bag_size) {
  if (mode == 1) { // MODE_MEAN
-    if (offsets.sizes()[0] == 1) {
-      auto bag_size_ = indices.sizes()[0];
+    if (offsets.size(0) == 1) {
+      auto bag_size_ = indices.size(0);
      output /= bag_size_;
    } else {
      auto bag_size_ =
@@ -90,8 +96,8 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
                                      Tensor &output, const Tensor &offset2bag,
                                      const Tensor &bag_size) {
  if (mode == 1) { // MODE_MEAN
-    if (offsets.sizes()[0] == 1) {
-      auto bag_size_ = indices.sizes()[0];
+    if (offsets.size(0) == 1) {
+      auto bag_size_ = indices.size(0);
      output /= bag_size_;
    } else {
      auto inv_bag_size_ = (1 / bag_size.toType(output.type()))
@@ -103,7 +109,48 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
  return output;
}

-std::tuple<Tensor, Tensor, Tensor>
+
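+// MODE_MAX forward: reduce each bag by an elementwise max over the embedding
+// rows of its words, recording in max_indices which word produced each
+// per-dimension maximum so the backward pass can route gradients to it.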
+template <typename scalar_t>
+std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
+    const Tensor &weight, const Tensor &indices, const Tensor &offset2bag,
+    const Tensor &output, const Tensor &bag_size, const Tensor &offsets) {
+
+  auto max_indices = at::zeros(indices.type(), {offsets.size(0), weight.size(1)});
+
+  int64_t numel = indices.numel();
+  int64_t dims = weight.size(1);
+  auto indices_data = indices.data<int64_t>();
+  auto offset2bag_data = offset2bag.data<int64_t>();
+
+  auto max_indices_data = max_indices.data<int64_t>();
+  auto max_indices_stride = max_indices.stride(0);
+
+  auto weight_data = weight.data<scalar_t>();
+  auto output_data = output.data<scalar_t>();
+  auto weight_stride = weight.stride(0);
+  auto output_stride = output.stride(0);
+
+  for (int i = 0; i < numel; i++) {
+    auto bag = offset2bag_data[i];
+    auto word_idx = indices_data[i];
+
+    for (int dim = 0; dim < dims; dim++) {
+      auto& current_item = output_data[output_stride * bag + dim];
+      auto weight_item = weight_data[weight_stride * word_idx + dim];
+
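+      // output is zero-initialized, so the first word seen for a bag must
+      // overwrite unconditionally; comparing against the zero fill would be
+      // wrong whenever all of a bag's weights are negative.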
+      bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
+
+      if (is_first_for_bag || weight_item > current_item) {
+        current_item = weight_item;
+        max_indices_data[max_indices_stride * bag + dim] = word_idx;
+      }
+    }
+  }
+
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
+}
+
+std::tuple<Tensor, Tensor, Tensor, Tensor>
embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
                  const Tensor &offsets__, const bool scale_grad_by_freq,
                  const int64_t mode, bool sparse) {
@@ -118,23 +165,34 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,

  auto bag_size = at::zeros(indices.type(), offsets.sizes());
  auto offset2bag =
-      at::zeros(indices__.type(), {indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
+      at::zeros(indices__.type(), {indices.size(0)}); // offset2bag = [0 0 0 0 0]
  make_offset2bag(offsets, indices, offset2bag);
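+  // After make_offset2bag, entry i holds the bag that index i falls into,
+  // e.g. offsets = [0, 3] with 5 indices gives offset2bag = [0, 0, 0, 1, 1].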
-  auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
-  if (weight.type().scalarType() == kFloat) {
-    index_select_add<float>(indices, offset2bag, weight, output);
-  } else if (weight.type().scalarType() == kDouble) {
-    index_select_add<double>(indices, offset2bag, weight, output);
+  auto output = at::zeros(weight.type(), {offsets.size(0), weight.size(1)});
+
+  if (mode == MODE_MEAN || mode == MODE_SUM) {
+    if (weight.type().scalarType() == kFloat) {
+      index_select_add<float>(indices, offset2bag, weight, output);
+    } else if (weight.type().scalarType() == kDouble) {
+      index_select_add<double>(indices, offset2bag, weight, output);
+    }
+    make_bag_size(offsets, indices, mode, bag_size);
+    auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
+    return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
+  } else { // MODE_MAX
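+    // AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda once per
+    // floating scalar type; scalar_t inside it is the macro-provided alias.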
+    return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        weight.type(), "embedding_bag_cpu_max", [&]() {
+          return embedding_bag_cpu_max<scalar_t>(weight, indices, offset2bag, output, bag_size, offsets);
+        }
+    );
  }
-  make_bag_size(offsets, indices, mode, bag_size);
-  auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
-  return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
}

Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
                              const Tensor &offsets__,
                              const Tensor &offset2bag__,
-                              const Tensor &bag_size_, int64_t num_weights,
+                              const Tensor &bag_size_,
+                              const Tensor &max_indices_,
+                              int64_t num_weights,
                              bool scale_grad_by_freq, int64_t mode,
                              bool sparse) {
  auto indices_arg = TensorArg(indices__, "indices__", 1);
@@ -153,15 +211,16 @@ Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__,
        scale_grad_by_freq, mode);
  } else {
    return at::embedding_bag_dense_backward(
-        grad_, indices, offsets, offset2bag__, bag_size_, num_weights,
+        grad_, indices, offsets, offset2bag__, bag_size_, max_indices_, num_weights,
        scale_grad_by_freq, mode);
  }
}

Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
                                  const Tensor &offsets__,
                                  const Tensor &offset2bag__,
-                                  const Tensor &bag_size_, int64_t num_weights,
+                                  const Tensor &bag_size_,
+                                  const Tensor &max_indices_, int64_t num_weights,
                                  bool scale_grad_by_freq, int64_t mode) {
  auto grad = grad_.contiguous();
  auto grad_arg = TensorArg(grad, "grad_", 1);
@@ -196,6 +255,9 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
    counts[indices_data[i]]++;
  }

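+  // Allocated before the mode branch below so that both the MODE_SUM/MEAN
+  // path and the MODE_MAX path can scatter into it.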
+  auto index_grad_weight =
+      at::zeros(grad.type(), {num_weights, grad.size(1)}).contiguous();
+
  std::vector<int64_t> counts_uniq;
  counts_uniq.reserve(num_weights);
  int64_t o = 0;
@@ -207,43 +269,46 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
    o++;
  }

-  auto index_grad_weight =
-      at::zeros(grad.type(), {num_weights, grad.sizes()[1]}).contiguous();
-
-#pragma omp parallel for if (numel > 1000)
-  for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
-    int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
-    int64_t index = indices_data[start];
-    for (int64_t j = start; j < counts_uniq[i]; j++) {
-      int64_t source = offset2bag_data[j];
-      double scale = 1.0;
-      if (scale_grad_by_freq) {
-        scale /= counts[indices_data[i]];
-      }
-      if (mode == 1) { // MODE_MEAN
-        if (offsets_.sizes()[0] == 1) {
-          auto bag_size = indices.sizes()[0];
-          scale /= bag_size;
-        } else {
-          if (source == offsets_.sizes()[0] - 1) {
-            scale /= indices.sizes()[0] - offsets_data[offsets_.sizes()[0] - 1];
-          } else {
-            scale /= offsets_data[source + 1] - offsets_data[source];
+  if (mode == MODE_MEAN || mode == MODE_SUM) {
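+    // Parallelized over unique indices: one thread owns every contribution to
+    // a given destination row, so the axpy accumulation below needs no atomics.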
+#pragma omp parallel for if (numel > 1000)
+    for (int64_t i = 0; i < (int64_t)counts_uniq.size(); i++) {
+      int64_t start = i == 0 ? 0 : counts_uniq[i - 1];
+      int64_t index = indices_data[start];
+      for (int64_t j = start; j < counts_uniq[i]; j++) {
+        int64_t source = offset2bag_data[j];
+        double scale = 1.0;
+        if (scale_grad_by_freq) {
+          scale /= counts[indices_data[i]];
+        }
+        if (mode == 1) { // MODE_MEAN
+          if (offsets_.size(0) == 1) {
+            auto bag_size = indices.size(0);
+            scale /= bag_size;
+          } else {
+            if (source == offsets_.size(0) - 1) {
+              scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
+            } else {
+              scale /= offsets_data[source + 1] - offsets_data[source];
+            }
+          }
+        }
+        int64_t ddim = grad.size(1);
+        if (grad.type().scalarType() == kFloat) {
+          auto igwd = index_grad_weight.data<float>();
+          auto gd = grad.data<float>();
+          axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
+                      igwd + ddim * index, 1);
+        } else if (grad.type().scalarType() == kDouble) {
+          auto igwd = index_grad_weight.data<double>();
+          auto gd = grad.data<double>();
+          axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
+                      igwd + ddim * index, 1);
        }
      }
-      }
-      int64_t ddim = grad.sizes()[1];
-      if (grad.type().scalarType() == kFloat) {
-        auto igwd = index_grad_weight.data<float>();
-        auto gd = grad.data<float>();
-        axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
-                    igwd + ddim * index, 1);
-      } else if (grad.type().scalarType() == kDouble) {
-        auto igwd = index_grad_weight.data<double>();
-        auto gd = grad.data<double>();
-        axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
-                    igwd + ddim * index, 1);
-      }
+    }
+  } else if (mode == MODE_MAX) {
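+    // Scatter each bag's gradient row into the weight rows that won the max,
+    // one feature column at a time, using the recorded max_indices.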
+    for (int64_t dim = 0; dim < grad.size(1); dim++) {
+      index_grad_weight.select(1, dim).index_add_(0, max_indices_.select(1, dim), grad_.select(1, dim));
    }
  }
