Skip to content

Commit e2c9d55

Browse files
ukoxyztensorflower-gardener
authored andcommitted
[XLA] Enable rewriting sparse gradient accumulation with reshapes.
A change in AlgebraicSimplifier. Moving the reshapes can help enable the existing rewrite rule for sparse gradient accumulation. PiperOrigin-RevId: 387692854 Change-Id: I0ec3c8f0bfa1f685a1718309fffaa21ac80e5ecd
1 parent a6e0220 commit e2c9d55

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

tensorflow/compiler/xla/service/algebraic_simplifier.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4055,6 +4055,60 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
40554055
}
40564056
}
40574057

4058+
// Moves the reshape in reshape(dus(...), x, ...)) before dus so that it can
4059+
// enable other optimizations, e.g., merging with broadcast, and sparse update
4060+
// (add(x, dus(broadcast(0), y, ...)) -> dus(x, add(ds(x), y), ...)).
4061+
if (!options_.is_layout_sensitive()) {
4062+
bool trivial_reshape;
4063+
std::vector<int64> deleted_dims;
4064+
std::vector<int64> inserted_dims;
4065+
4066+
HloInstruction* dus;
4067+
HloInstruction* slice;
4068+
std::tie(trivial_reshape, deleted_dims, inserted_dims) =
4069+
reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
4070+
// 1-sized dimensions added and removed will be one sized in both the update
4071+
// slice and the dynamic-update-slice result.
4072+
if (trivial_reshape &&
4073+
Match(reshape->mutable_operand(0),
4074+
m::Op(&dus)
4075+
.WithOpcode(HloOpcode::kDynamicUpdateSlice)
4076+
.WithOperand(1, m::Op(&slice))) &&
4077+
!dus->has_sharding() && !dus->operand(0)->has_sharding()) {
4078+
auto new_operand =
4079+
computation_->AddInstruction(HloInstruction::CreateReshape(
4080+
reshape->shape(), dus->mutable_operand(0)));
4081+
std::vector<int64> new_slice_shape;
4082+
std::vector<HloInstruction*> new_dus_operands;
4083+
new_dus_operands.push_back(new_operand);
4084+
new_dus_operands.push_back(nullptr);
4085+
auto zero = MakeScalarLike(dus->mutable_operand(2), 0);
4086+
const Shape& old_slice_shape = dus->operand(1)->shape();
4087+
for (int64 i = 0; i <= old_slice_shape.rank(); ++i) {
4088+
if (absl::c_linear_search(deleted_dims, i)) {
4089+
continue;
4090+
}
4091+
if (absl::c_linear_search(inserted_dims, new_slice_shape.size())) {
4092+
new_slice_shape.push_back(1);
4093+
new_dus_operands.push_back(zero);
4094+
}
4095+
if (i < old_slice_shape.rank()) {
4096+
new_slice_shape.push_back(old_slice_shape.dimensions(i));
4097+
new_dus_operands.push_back(dus->mutable_operand(2 + i));
4098+
}
4099+
}
4100+
auto new_slice =
4101+
computation_->AddInstruction(HloInstruction::CreateReshape(
4102+
ShapeUtil::MakeShape(old_slice_shape.element_type(),
4103+
new_slice_shape),
4104+
slice));
4105+
new_dus_operands[1] = new_slice;
4106+
auto new_dus =
4107+
dus->CloneWithNewOperands(reshape->shape(), new_dus_operands);
4108+
return ReplaceWithNewInstruction(reshape, std::move(new_dus));
4109+
}
4110+
}
4111+
40584112
// Make this a bitcast if possible.
40594113
if (HloInstruction* bitcast_operand =
40604114
BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {

tensorflow/compiler/xla/service/algebraic_simplifier_test.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5562,6 +5562,41 @@ ENTRY AddBroadcastZeroWithDynamicSlice {
55625562
EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
55635563
}
55645564

5565+
TEST_F(AlgebraicSimplifierTest, AddDynamicUpdateSliceToAddSlice) {
5566+
const char* hlo_string = R"(
5567+
HloModule AddDynamicUpdateSliceToAddSlice
5568+
5569+
ENTRY AddDynamicUpdateSliceToAddSlice {
5570+
param0 = f32[1, 4,12,512] parameter(0)
5571+
constant = f32[] constant(0)
5572+
broadcast = f32[4,12,512] broadcast(constant), dimensions={}
5573+
param1 = f32[1,12,512] parameter(1)
5574+
param2 = s32[] parameter(2)
5575+
constant.1 = s32[] constant(0)
5576+
dynamic-update-slice = f32[4,12,512] dynamic-update-slice(
5577+
broadcast, param1, param2, constant.1, constant.1)
5578+
reshape = f32[1,4,12,512] reshape(dynamic-update-slice)
5579+
ROOT add = f32[1,4,12,512] add(param0, reshape)
5580+
}
5581+
)";
5582+
TF_ASSERT_OK_AND_ASSIGN(auto module,
5583+
ParseAndReturnVerifiedModule(hlo_string));
5584+
VLOG(2) << "Before rewrite reshape\n" << module->ToString();
5585+
AlgebraicSimplifier simplifier(default_options_);
5586+
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
5587+
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
5588+
VLOG(2) << "After rewrite to add slice\n" << module->ToString();
5589+
auto root = module->entry_computation()->root_instruction();
5590+
EXPECT_THAT(
5591+
root,
5592+
GmockMatch(m::DynamicUpdateSlice(
5593+
m::Parameter(0),
5594+
m::Add(m::DynamicSlice(m::Parameter(0), m::Constant(),
5595+
m::Parameter(2), m::Constant(), m::Constant()),
5596+
m::Reshape(m::Parameter(1))),
5597+
m::Constant(), m::Parameter(2), m::Constant(), m::Constant())));
5598+
}
5599+
55655600
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
55665601
const char* hlo_string = R"(
55675602
HloModule ConstScalarMultiply

tensorflow/compiler/xla/service/pattern_matcher.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2227,6 +2227,7 @@ XLA_VARIADIC_OP_PATTERN(AfterAll);
22272227
XLA_VARIADIC_OP_PATTERN(Concatenate);
22282228
XLA_VARIADIC_OP_PATTERN(Conditional);
22292229
XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2230+
XLA_VARIADIC_OP_PATTERN(DynamicUpdateSlice)
22302231
XLA_VARIADIC_OP_PATTERN(Fusion);
22312232
XLA_VARIADIC_OP_PATTERN(Map)
22322233
XLA_VARIADIC_OP_PATTERN(Reduce);

0 commit comments

Comments
 (0)