20 changes: 20 additions & 0 deletions aten/src/ATen/TensorUtils.cpp
@@ -162,6 +162,26 @@ void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType ty) {
  }
}

void checkScalarTypes(CheckedFrom c, const TensorArg& t,
                      at::ArrayRef<ScalarType> l) {
  if (std::find(l.begin(), l.end(), t->type().scalarType()) == l.end()) {
    std::ostringstream oss;
    oss << "Expected tensor for " << t << " to have one of the following "
        << "scalar types: ";
    size_t i = 0;
    for (auto ty : l) {
      if (i != 0) {
        oss << ", ";
      }
      oss << toString(ty);
      i++;
    }
    oss << "; but got " << t->toString()
        << " instead (while checking arguments for " << c << ")";
    throw std::runtime_error(oss.str());
  }
}

void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameType);
}
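For illustration (not part of the diff), a minimal sketch of how the new check reads at a call site, mirroring the call sites this PR adds in EmbeddingBag.cpp:

    // Reject tensors whose scalar type is neither float nor double.
    auto weight_arg = TensorArg(weight, "weight", 1);
    checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});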
2 changes: 1 addition & 1 deletion aten/src/ATen/TensorUtils.h
@@ -62,6 +62,7 @@ void checkNumel(CheckedFrom c, const TensorGeometryArg& t, int64_t numel);
void checkSameNumel(CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2);
void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
void checkScalarTypes(CheckedFrom c, const TensorArg& t, at::ArrayRef<ScalarType> l);
void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2);
void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2);
@@ -78,4 +79,3 @@ void * maybe_data_ptr(const Tensor& tensor);
void * maybe_data_ptr(const TensorArg& tensor);

}

62 changes: 50 additions & 12 deletions aten/src/ATen/native/EmbeddingBag.cpp
@@ -24,6 +24,39 @@ static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
  offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
}
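// For illustration (not part of this diff), assuming offsets = [0, 2, 4]
// over 5 indices: the (elided) lines above presumably scatter a 1 at the
// start of every bag except the first, giving [0, 0, 1, 0, 1], so the
// cumulative sum yields the per-index bag ids [0, 0, 1, 1, 2] quoted in
// the comment above.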

template<typename T>
static void axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);
template<>
void axpy<float>(int64_t n, float a, float *x, int64_t incx,
                 float *y, int64_t incy) {
  THFloatBlas_axpy(n, a, x, incx, y, incy);
}
template<>
void axpy<double>(int64_t n, double a, double *x, int64_t incx,
                  double *y, int64_t incy) {
  THDoubleBlas_axpy(n, a, x, incx, y, incy);
}
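// For illustration (not part of this diff): BLAS axpy computes the strided
// fused scale-and-accumulate y[i * incy] += a * x[i * incx] for i in [0, n).
// A naive scalar equivalent of the wrappers above:
//
//   for (int64_t i = 0; i < n; i++) {
//     y[i * incy] += a * x[i * incx];
//   }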

// This function combines index_select (using select_indices as the index)
// and index_add (using add_indices as the index), without creating an
// intermediate tensor to hold the selected embeddings
template<typename T>
static void index_select_add(const Tensor &select_indices,
                             const Tensor &add_indices,
                             const Tensor &src,
                             Tensor &output) {
  auto add_indices_data = add_indices.data<int64_t>();
  auto select_indices_data = select_indices.data<int64_t>();
  auto src_data = src.data<T>();
  auto output_data = output.data<T>();
  auto numel = add_indices.numel();
  int64_t ddim = src.sizes()[1];
  for (int64_t i = 0; i < numel; i++) {
    axpy<T>(ddim, 1, src_data + ddim * select_indices_data[i], 1,
            output_data + ddim * add_indices_data[i], 1);
  }
}
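// For illustration (not part of this diff): the fused loop above computes
// the same result as the two-step ATen formulation it replaces later in
// this PR, which materializes the selected rows in a temporary tensor:
//
//   auto selected = src.index_select(0, select_indices);
//   output.index_add_(0, add_indices, selected);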

static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                          const int64_t mode, Tensor &bag_size) {
  if (mode == 1) { // MODE_MEAN
@@ -61,10 +94,10 @@ static Tensor apply_bag_size_backward(const Tensor &offsets,
      auto bag_size_ = indices.sizes()[0];
      output /= bag_size_;
    } else {
-     auto bag_size_ = bag_size.toType(output.type())
-                          .unsqueeze(1)
-                          .index_select(0, offset2bag);
-     output /= bag_size_;
+     auto inv_bag_size_ = (1 / bag_size.toType(output.type()))
+                              .unsqueeze(1)
+                              .index_select(0, offset2bag);
+     output *= inv_bag_size_;
    }
  }
  return output;
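For illustration (not part of the diff): the MODE_MEAN branch above now multiplies each output row by a gathered reciprocal instead of dividing by a gathered bag size, presumably a strength reduction that cuts the divisions from one per output row to one per bag. A minimal standalone sketch of the equivalence (all names hypothetical):

    double bag_size[] = {2.0, 4.0};          // size of each bag
    int64_t offset2bag[] = {0, 0, 1, 1, 1};  // output row -> bag id
    double rows[] = {1.0, 1.0, 1.0, 1.0, 1.0};
    double inv[2];
    for (int b = 0; b < 2; b++)
      inv[b] = 1.0 / bag_size[b];            // one division per bag
    for (int i = 0; i < 5; i++)
      rows[i] *= inv[offset2bag[i]];         // same as rows[i] /= bag_size[offset2bag[i]]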
@@ -80,14 +113,19 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__,
  checkScalarType("embedding_bag", offsets_arg, kLong);
  Tensor indices = indices__.contiguous();
  Tensor offsets = offsets__.contiguous();
  auto weight_arg = TensorArg(weight, "weight", 1);
  checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});

  auto bag_size = at::zeros(indices.type(), offsets.sizes());
  auto offset2bag =
      at::zeros(indices__.type(), {indices.sizes()[0]}); // offset2bag = [0 0 0 0 0]
  make_offset2bag(offsets, indices, offset2bag);
  auto output = at::zeros(weight.type(), {offsets.sizes()[0], weight.sizes()[1]});
- auto index_output = weight.index_select(0, indices);
- output.index_add_(0, offset2bag, index_output);
+ if (weight.type().scalarType() == kFloat) {
+   index_select_add<float>(indices, offset2bag, weight, output);
+ } else if (weight.type().scalarType() == kDouble) {
+   index_select_add<double>(indices, offset2bag, weight, output);
+ }
  make_bag_size(offsets, indices, mode, bag_size);
  auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
  return std::tuple<Tensor, Tensor, Tensor>(ret, offset2bag, bag_size);
@@ -126,6 +164,8 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
                                  const Tensor &bag_size_, int64_t num_weights,
                                  bool scale_grad_by_freq, int64_t mode) {
  auto grad = grad_.contiguous();
  auto grad_arg = TensorArg(grad, "grad_", 1);
  checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble});
  auto indices_arg = TensorArg(indices__, "indices__", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets__, "offsets__", 1);
@@ -196,15 +236,13 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__,
        if (grad.type().scalarType() == kFloat) {
          auto igwd = index_grad_weight.data<float>();
          auto gd = grad.data<float>();
-         THFloatBlas_axpy(ddim, (float)scale, gd + ddim * source, 1,
-                          igwd + ddim * index, 1);
+         axpy<float>(ddim, (float)scale, gd + ddim * source, 1,
+                     igwd + ddim * index, 1);
        } else if (grad.type().scalarType() == kDouble) {
          auto igwd = index_grad_weight.data<double>();
          auto gd = grad.data<double>();
-         THDoubleBlas_axpy(ddim, (double)scale, gd + ddim * source, 1,
-                           igwd + ddim * index, 1);
-       } else {
-         index_grad_weight[index].add_(grad[source], scale);
+         axpy<double>(ddim, (double)scale, gd + ddim * source, 1,
+                      igwd + ddim * index, 1);
        }
      }
    }
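// Note (not part of this diff): the generic fallback removed above,
// index_grad_weight[index].add_(grad[source], scale), is no longer
// reachable, since checkScalarTypes at the top of
// embedding_bag_backward_cpu now restricts grad to kFloat and kDouble.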