@@ -177,7 +177,7 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
177177 }
178178 make_bag_size (offsets, indices, mode, bag_size);
179179 auto ret = apply_bag_size (offsets, indices, mode, output, bag_size);
180- return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
180+ return std::tuple<Tensor, Tensor, Tensor, Tensor >(ret, offset2bag, bag_size , bag_size);
181181 } else { // MODE_MAX
182182 return AT_DISPATCH_FLOATING_TYPES_AND_HALF (
183183 weight.type (), " embedding_bag_cpu_max" , [&]() {
@@ -296,37 +296,19 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
296296 if (grad.type ().scalarType () == kFloat ) {
297297 auto igwd = index_grad_weight.data <float >();
298298 auto gd = grad.data <float >();
299- THFloatBlas_axpy (ddim, (float )scale, gd + ddim * source, 1 ,
300- igwd + ddim * index, 1 );
299+ axpy< float > (ddim, (float )scale, gd + ddim * source, 1 ,
300+ igwd + ddim * index, 1 );
301301 } else if (grad.type ().scalarType () == kDouble ) {
302302 auto igwd = index_grad_weight.data <double >();
303303 auto gd = grad.data <double >();
304- THDoubleBlas_axpy (ddim, (double )scale, gd + ddim * source, 1 ,
305- igwd + ddim * index, 1 );
306- } else {
307- index_grad_weight[index].add_ (grad[source], scale);
304+ axpy<double >(ddim, (double )scale, gd + ddim * source, 1 ,
305+ igwd + ddim * index, 1 );
308306 }
309307 }
310- <<<<<<< HEAD
311- }
312- int64_t ddim = grad.sizes ()[1 ];
313- if (grad.type ().scalarType () == kFloat ) {
314- auto igwd = index_grad_weight.data <float >();
315- auto gd = grad.data <float >();
316- axpy<float >(ddim, (float )scale, gd + ddim * source, 1 ,
317- igwd + ddim * index, 1 );
318- } else if (grad.type ().scalarType () == kDouble ) {
319- auto igwd = index_grad_weight.data <double >();
320- auto gd = grad.data <double >();
321- axpy<double >(ddim, (double )scale, gd + ddim * source, 1 ,
322- igwd + ddim * index, 1 );
323- }
324- =======
325308 }
326309 } else if (mode == MODE_MAX) {
327310 for (int64_t dim = 0 ; dim < grad.sizes ()[1 ]; dim++) {
328311 index_grad_weight.select (1 , dim).index_add_ (0 , max_indices_.select (1 , dim), grad_.select (1 , dim));
329- >>>>>>> Add max mode support to EmbeddingBag
330312 }
331313 }
332314
0 commit comments