Skip to content

Commit 39ac734

Browse files
A. Unique TensorFlower authored and Vijay Vasudevan committed
Adds support for SQRTN combiner.
The implementation divides the weighted sum by sqrt(n), or by sqrt(sum(x^2)) over the weights x when sp_weights are specified on embedding_lookup_sparse. It also implements math_ops.sparse_segment_sqrtn and its corresponding gradient, math_ops.sparse_segment_sqrtn_grad. Change: 111889513
1 parent 5853ad9 commit 39ac734

8 files changed

Lines changed: 252 additions & 42 deletions

File tree

tensorflow/core/kernels/segment_reduction_ops.cc

Lines changed: 68 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -250,8 +250,8 @@ template <typename Device, class T>
250250
class SparseSegmentReductionOpBase : public OpKernel {
251251
public:
252252
explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
253-
bool is_mean)
254-
: OpKernel(context), is_mean_(is_mean) {}
253+
bool is_mean, bool is_sqrtn)
254+
: OpKernel(context), is_mean_(is_mean), is_sqrtn_(is_sqrtn) {}
255255

256256
void Compute(OpKernelContext* context) override {
257257
const Tensor& input = context->input(0);
@@ -309,7 +309,13 @@ class SparseSegmentReductionOpBase : public OpKernel {
309309
out = I(0);
310310
} else {
311311
int r = num % 8;
312-
T m = (is_mean_ && (num < 10)) ? num : 1;
312+
T m = 1;
313+
if (is_mean_ && (num < 10)) {
314+
m = num;
315+
}
316+
if (is_sqrtn_ && (num < 10)) {
317+
m = sqrt(num);
318+
}
313319
switch (r) {
314320
case 2:
315321
out = (I(0) + I(1)) / m;
@@ -348,30 +354,45 @@ class SparseSegmentReductionOpBase : public OpKernel {
348354
if (is_mean_ && num >= 10) {
349355
out = out / static_cast<T>(num);
350356
}
357+
if (is_sqrtn_ && num >= 10) {
358+
out = out / static_cast<T>(sqrt(num));
359+
}
351360
}
352361
start = end;
353362
++end;
354363
}
355364
}
356365

357366
private:
358-
bool is_mean_;
367+
const bool is_mean_;
368+
const bool is_sqrtn_;
359369
};
360370

361371
template <typename Device, class T>
362372
class SparseSegmentReductionMeanOp
363373
: public SparseSegmentReductionOpBase<Device, T> {
364374
public:
365375
explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
366-
: SparseSegmentReductionOpBase<Device, T>(context, true /*is_mean*/) {}
376+
: SparseSegmentReductionOpBase<Device, T>(context, true /*is_mean*/,
377+
false /*is_sqrtn*/) {}
378+
};
379+
380+
template <typename Device, class T>
381+
class SparseSegmentReductionSqrtNOp
382+
: public SparseSegmentReductionOpBase<Device, T> {
383+
public:
384+
explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
385+
: SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/,
386+
true /*is_sqrtn*/) {}
367387
};
368388

369389
template <typename Device, class T>
370390
class SparseSegmentReductionSumOp
371391
: public SparseSegmentReductionOpBase<Device, T> {
372392
public:
373393
explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
374-
: SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/) {}
394+
: SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/,
395+
false /*is_sqrtn*/) {}
375396
};
376397

377398
#define REGISTER_CPU_SPARSE_KERNELS(type) \
@@ -390,11 +411,19 @@ REGISTER_CPU_SPARSE_KERNELS(float);
390411
REGISTER_CPU_SPARSE_KERNELS(double);
391412
#undef REGISTER_CPU_SPARSE_KERNELS
392413

414+
#define REGISTER_CPU_SPARSE_KERNELS(type) \
415+
REGISTER_KERNEL_BUILDER( \
416+
Name("SparseSegmentSqrtN").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
417+
SparseSegmentReductionSqrtNOp<CPUDevice, type>);
418+
REGISTER_CPU_SPARSE_KERNELS(float);
419+
REGISTER_CPU_SPARSE_KERNELS(double);
420+
#undef REGISTER_CPU_SPARSE_KERNELS
421+
393422
template <class T>
394-
class SparseSegmentMeanGradOp : public OpKernel {
423+
class SparseSegmentGradOpBase : public OpKernel {
395424
public:
396-
explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
397-
: OpKernel(context) {}
425+
explicit SparseSegmentGradOpBase(OpKernelConstruction* context, bool is_sqrtn)
426+
: OpKernel(context), is_sqrtn_(is_sqrtn) {}
398427

399428
void Compute(OpKernelContext* context) override {
400429
const Tensor& input = context->input(0);
@@ -437,7 +466,11 @@ class SparseSegmentMeanGradOp : public OpKernel {
437466
scaling[segment_vec(i)] += 1;
438467
}
439468
for (int i = 0; i < scaling.size(); ++i) {
440-
scaling[i] = 1.0 / std::max(scaling[i], 1.0);
469+
if (is_sqrtn_) {
470+
scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
471+
} else {
472+
scaling[i] = 1.0 / std::max(scaling[i], 1.0);
473+
}
441474
}
442475

443476
auto output_flat = output->flat_outer_dims<T>();
@@ -468,16 +501,40 @@ class SparseSegmentMeanGradOp : public OpKernel {
468501
is_modified[output_idx] = true;
469502
}
470503
}
504+
505+
private:
506+
const bool is_sqrtn_;
507+
};
508+
509+
template <class T>
510+
class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
511+
public:
512+
explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
513+
: SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
514+
};
515+
516+
template <class T>
517+
class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
518+
public:
519+
explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
520+
: SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
471521
};
472522

473523
#define REGISTER_CPU_SPARSE_KERNELS(type) \
474524
REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad") \
475525
.Device(DEVICE_CPU) \
476526
.TypeConstraint<type>("T"), \
477527
SparseSegmentMeanGradOp<type>);
478-
479528
REGISTER_CPU_SPARSE_KERNELS(float);
480529
REGISTER_CPU_SPARSE_KERNELS(double);
530+
#undef REGISTER_CPU_SPARSE_KERNELS
481531

532+
#define REGISTER_CPU_SPARSE_KERNELS(type) \
533+
REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad") \
534+
.Device(DEVICE_CPU) \
535+
.TypeConstraint<type>("T"), \
536+
SparseSegmentSqrtNGradOp<type>);
537+
REGISTER_CPU_SPARSE_KERNELS(float);
538+
REGISTER_CPU_SPARSE_KERNELS(double);
482539
#undef REGISTER_CPU_SPARSE_KERNELS
483540
} // namespace tensorflow

tensorflow/core/ops/math_grad.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -468,6 +468,7 @@ REGISTER_OP_GRADIENT("Mean", MeanGrad);
468468
// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
469469
// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
470470
// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
471+
// REGISTER_OP_GRADIENT("SparseSegmentSqrtN", SparseSegmentSqrtNGrad);
471472
// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
472473
// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
473474
// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);

tensorflow/core/ops/math_ops.cc

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -921,6 +921,49 @@ segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
921921
output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
922922
)doc");
923923

924+
REGISTER_OP("SparseSegmentSqrtN")
925+
.Input("data: T")
926+
.Input("indices: int32")
927+
.Input("segment_ids: int32")
928+
.Output("output: T")
929+
.Attr("T: {float, double}")
930+
.Doc(R"doc(
931+
Computes the sum along sparse segments of a tensor divided by the sqrt of N.
932+
933+
N is the size of the segment being reduced.
934+
935+
Read [the section on
936+
Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
937+
of segments.
938+
939+
indices: A 1-D tensor. Has same rank as `segment_ids`.
940+
941+
segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
942+
943+
output: Has same shape as data, except for dimension 0 which
944+
has size `k`, the number of segments.
945+
946+
)doc");
947+
948+
REGISTER_OP("SparseSegmentSqrtNGrad")
949+
.Input("grad: T")
950+
.Input("indices: int32")
951+
.Input("segment_ids: int32")
952+
.Input("output_dim0: int32")
953+
.Output("output: T")
954+
.Attr("T: {float, double}")
955+
.Doc(R"doc(
956+
Computes gradients for SparseSegmentSqrtN.
957+
958+
Returns tensor "output" with same shape as grad, except for dimension 0 whose
959+
value is output_dim0.
960+
961+
grad: gradient propagated to the SparseSegmentSqrtN op.
962+
indices: indices passed to the corresponding SparseSegmentSqrtN op.
963+
segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op.
964+
output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op.
965+
)doc");
966+
924967
REGISTER_OP("All")
925968
.Input("input: bool")
926969
.Input("reduction_indices: int32")

tensorflow/core/ops/ops.pbtxt

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7894,6 +7894,79 @@ op {
78947894
summary: "Computes gradients for SparseSegmentMean."
78957895
description: "Returns tensor \"output\" with same shape as grad, except for dimension 0 whose\nvalue is output_dim0."
78967896
}
7897+
op {
7898+
name: "SparseSegmentSqrtN"
7899+
input_arg {
7900+
name: "data"
7901+
type_attr: "T"
7902+
}
7903+
input_arg {
7904+
name: "indices"
7905+
description: "A 1-D tensor. Has same rank as `segment_ids`."
7906+
type: DT_INT32
7907+
}
7908+
input_arg {
7909+
name: "segment_ids"
7910+
description: "A 1-D tensor. Values should be sorted and can be repeated."
7911+
type: DT_INT32
7912+
}
7913+
output_arg {
7914+
name: "output"
7915+
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
7916+
type_attr: "T"
7917+
}
7918+
attr {
7919+
name: "T"
7920+
type: "type"
7921+
allowed_values {
7922+
list {
7923+
type: DT_FLOAT
7924+
type: DT_DOUBLE
7925+
}
7926+
}
7927+
}
7928+
summary: "Computes the sum along sparse segments of a tensor divided by the sqrt of N."
7929+
description: "N is the size of the segment being reduced.\n\nRead [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments."
7930+
}
7931+
op {
7932+
name: "SparseSegmentSqrtNGrad"
7933+
input_arg {
7934+
name: "grad"
7935+
description: "gradient propagated to the SparseSegmentSqrtN op."
7936+
type_attr: "T"
7937+
}
7938+
input_arg {
7939+
name: "indices"
7940+
description: "indices passed to the corresponding SparseSegmentSqrtN op."
7941+
type: DT_INT32
7942+
}
7943+
input_arg {
7944+
name: "segment_ids"
7945+
description: "segment_ids passed to the corresponding SparseSegmentSqrtN op."
7946+
type: DT_INT32
7947+
}
7948+
input_arg {
7949+
name: "output_dim0"
7950+
description: "dimension 0 of \"data\" passed to SparseSegmentSqrtN op."
7951+
type: DT_INT32
7952+
}
7953+
output_arg {
7954+
name: "output"
7955+
type_attr: "T"
7956+
}
7957+
attr {
7958+
name: "T"
7959+
type: "type"
7960+
allowed_values {
7961+
list {
7962+
type: DT_FLOAT
7963+
type: DT_DOUBLE
7964+
}
7965+
}
7966+
}
7967+
summary: "Computes gradients for SparseSegmentSqrtN."
7968+
description: "Returns tensor \"output\" with same shape as grad, except for dimension 0 whose\nvalue is output_dim0."
7969+
}
78977970
op {
78987971
name: "SparseSegmentSum"
78997972
input_arg {

0 commit comments

Comments
 (0)