@@ -24,6 +24,39 @@ static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
2424 offset2bag = offset2bag.cumsum (0 ); // offset2bag = [0 0 1 1 2]
2525}
2626
27+ template <typename T>
28+ static void axpy (int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);
29+ template <>
30+ void axpy<float >(int64_t n, float a, float *x, int64_t incx,
31+ float *y, int64_t incy) {
32+ THFloatBlas_axpy (n, a, x, incx, y, incy);
33+ }
34+ template <>
35+ void axpy<double >(int64_t n, double a, double *x, int64_t incx,
36+ double *y, int64_t incy) {
37+ THDoubleBlas_axpy (n, a, x, incx, y, incy);
38+ }
39+
40+ // This function combines index_select (using select_indices as the index) and
41+ // index_add (using add_indices as the index), without creating an intermediary
42+ // tensor to hold the selected embeddings
43+ template <typename T>
44+ static void index_select_add (const Tensor &select_indices,
45+ const Tensor &add_indices,
46+ const Tensor &src,
47+ Tensor &output) {
48+ auto add_indices_data = add_indices.data <int64_t >();
49+ auto select_indices_data = select_indices.data <int64_t >();
50+ auto src_data = src.data <T>();
51+ auto output_data = output.data <T>();
52+ auto numel = add_indices.numel ();
53+ int64_t ddim = src.sizes ()[1 ];
54+ for (int64_t i = 0 ; i < numel; i++) {
55+ axpy<T>(ddim, 1 , src_data + ddim * select_indices_data[i], 1 ,
56+ output_data + ddim * add_indices_data[i], 1 );
57+ }
58+ }
59+
2760static void make_bag_size (const Tensor &offsets, const Tensor &indices,
2861 const int64_t mode, Tensor &bag_size) {
2962 if (mode == 1 ) { // MODE_MEAN
@@ -61,10 +94,10 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
6194 auto bag_size_ = indices.sizes ()[0 ];
6295 output /= bag_size_;
6396 } else {
64- auto bag_size_ = bag_size.toType (output.type ())
65- .unsqueeze (1 )
66- .index_select (0 , offset2bag);
67- output /= bag_size_ ;
97+ auto inv_bag_size_ = ( 1 / bag_size.toType (output.type () ))
98+ .unsqueeze (1 )
99+ .index_select (0 , offset2bag);
100+ output *= inv_bag_size_ ;
68101 }
69102 }
70103 return output;
@@ -80,14 +113,19 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
80113 checkScalarType (" embedding_bag" , offsets_arg, kLong );
81114 Tensor indices = indices__.contiguous ();
82115 Tensor offsets = offsets__.contiguous ();
116+ auto weight_arg = TensorArg (weight, " weight" , 1 );
117+ checkScalarTypes (" embedding_bag" , weight_arg, {kFloat , kDouble });
83118
84119 auto bag_size = at::zeros (indices.type (), offsets.sizes ());
85120 auto offset2bag =
86121 at::zeros (indices__.type (), {indices.sizes ()[0 ]}); // offset2bag = [0 0 0 0 0]
87122 make_offset2bag (offsets, indices, offset2bag);
88123 auto output = at::zeros (weight.type (), {offsets.sizes ()[0 ], weight.sizes ()[1 ]});
89- auto index_output = weight.index_select (0 , indices);
90- output.index_add_ (0 , offset2bag, index_output);
124+ if (weight.type ().scalarType () == kFloat ) {
125+ index_select_add<float >(indices, offset2bag, weight, output);
126+ } else if (weight.type ().scalarType () == kDouble ) {
127+ index_select_add<double >(indices, offset2bag, weight, output);
128+ }
91129 make_bag_size (offsets, indices, mode, bag_size);
92130 auto ret = apply_bag_size (offsets, indices, mode, output, bag_size);
93131 return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
@@ -126,6 +164,8 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
126164 const Tensor &bag_size_, int64_t num_weights,
127165 bool scale_grad_by_freq, int64_t mode) {
128166 auto grad = grad_.contiguous ();
167+ auto grad_arg = TensorArg (grad, " grad_" , 1 );
168+ checkScalarTypes (" embedding_bag" , grad_arg, {kFloat , kDouble });
129169 auto indices_arg = TensorArg (indices__, " indices__" , 1 );
130170 checkScalarType (" embedding_bag" , indices_arg, kLong );
131171 auto offsets_arg = TensorArg (offsets__, " offsets__" , 1 );
@@ -196,15 +236,13 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
196236 if (grad.type ().scalarType () == kFloat ) {
197237 auto igwd = index_grad_weight.data <float >();
198238 auto gd = grad.data <float >();
199- THFloatBlas_axpy (ddim, (float )scale, gd + ddim * source, 1 ,
200- igwd + ddim * index, 1 );
239+ axpy< float > (ddim, (float )scale, gd + ddim * source, 1 ,
240+ igwd + ddim * index, 1 );
201241 } else if (grad.type ().scalarType () == kDouble ) {
202242 auto igwd = index_grad_weight.data <double >();
203243 auto gd = grad.data <double >();
204- THDoubleBlas_axpy (ddim, (double )scale, gd + ddim * source, 1 ,
205- igwd + ddim * index, 1 );
206- } else {
207- index_grad_weight[index].add_ (grad[source], scale);
244+ axpy<double >(ddim, (double )scale, gd + ddim * source, 1 ,
245+ igwd + ddim * index, 1 );
208246 }
209247 }
210248 }
0 commit comments