Skip to content

Commit a00e63e

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
PR #37252: [xla:gpu] Unify buffer_uses and buffers beteen Thunk and Command
Imported from GitHub PR #37252 One more step towards unifying commands and thunks Copybara import of the project: -- f84f5b0 by Eugene Zhulenev <ezhulenev@openxla.org>: [xla:gpu] Unify buffer_uses and buffers beteen Thunk and Command Merging this change closes #37252 COPYBARA_INTEGRATE_REVIEW=#37252 from ezhulenev:unify-cmd-and-thunk-0 f84f5b0 PiperOrigin-RevId: 865535090
1 parent 8effab0 commit a00e63e

File tree

8 files changed

+140
-137
lines changed

8 files changed

+140
-137
lines changed

xla/backends/gpu/runtime/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ xla_test(
246246
":shaped_slice",
247247
":thunk",
248248
"//xla:shape_util",
249+
"//xla:xla_data_proto_cc",
249250
"//xla/runtime:buffer_use",
250251
"//xla/service:buffer_assignment",
251252
"//xla/service:executable",
@@ -2781,6 +2782,7 @@ cc_library(
27812782
"//xla/ffi:execution_context",
27822783
"//xla/hlo/ir:hlo",
27832784
"//xla/runtime:buffer_use",
2785+
"//xla/runtime:resource_use",
27842786
"//xla/service:buffer_assignment",
27852787
"//xla/service:executable",
27862788
"//xla/service/gpu:backend_configs_cc",

xla/backends/gpu/runtime/command.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ bool IsCollectiveCommand(CommandType type);
122122
// done with a state manager.
123123
class Command {
124124
public:
125-
using BufferUseVector = absl::InlinedVector<BufferUse, 4>;
125+
using BufferUses = Thunk::BufferUses;
126126
using ResourceUseVector = absl::InlinedVector<ResourceUse, 1>;
127127

128128
public:
@@ -146,7 +146,7 @@ class Command {
146146
// rely on this information to skip unnecessary updates.
147147
std::optional<std::vector<BufferAllocation::Index>> updated_allocs;
148148

149-
// A flag indicating whether we record comands at command buffer thunk
149+
// A flag indicating whether we record commands at command buffer thunk
150150
// initialization time.
151151
bool is_initialization = false;
152152

@@ -213,7 +213,7 @@ class Command {
213213

214214
// Returns true if command supports loop unroll, the while loop can be
215215
// unrolled only if it has pre-known trip count and also all commands from the
216-
// body commands are unrollable..
216+
// body commands are unrollable.
217217
virtual bool support_loop_unroll() { return true; }
218218

219219
// This is only true for DynamicSliceCopyFusionCmd when offset is dependents
@@ -224,7 +224,7 @@ class Command {
224224

225225
// Returns all buffers used by the cmd. These will be used to track cmd
226226
// updates, thus they need to be consistent across calls to the function.
227-
virtual BufferUseVector buffers() const { return {}; }
227+
virtual BufferUses buffer_uses() const { return {}; }
228228

229229
std::shared_ptr<Resource> token() const { return token_; }
230230

xla/backends/gpu/runtime/command_buffer_cmd.cc

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ static absl::StatusOr<const se::CommandBuffer::Command*> Handle(
225225
//===----------------------------------------------------------------------===//
226226

227227
TracedCommandBuffer::TracedCommandBuffer(const Command* trace_cmd,
228-
Command::BufferUseVector buffers,
228+
Command::BufferUses buffers,
229229
int64_t capacity)
230230
: trace_cmd_(trace_cmd), capacity_(capacity), entries_(capacity) {
231231
CHECK_GT(capacity, 0) << "capacity must be larger than 0"; // NOLINT
@@ -319,7 +319,8 @@ TracedCommandBufferCmd::RecordTracedCommand(
319319
this, command_buffer, [&] {
320320
const auto& debug_options = xla::GetDebugOptionsFromFlags();
321321
return std::make_unique<TracedCommandBuffer>(
322-
this, buffers(), debug_options.xla_cmd_buffer_trace_cache_size());
322+
this, buffer_uses(),
323+
debug_options.xla_cmd_buffer_trace_cache_size());
323324
});
324325

325326
TF_ASSIGN_OR_RETURN(
@@ -367,7 +368,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> EmptyCmd::Record(
367368
ComputationIdCmd::ComputationIdCmd(BufferAllocation::Slice dest, Kind kind)
368369
: Command(CommandType::kComputationIdCmd), dest_(dest), kind_(kind) {}
369370

370-
Command::BufferUseVector ComputationIdCmd::buffers() const {
371+
Command::BufferUses ComputationIdCmd::buffer_uses() const {
371372
return {BufferUse::Write(dest_, ShapeUtil::MakeShape(S32, {}))};
372373
}
373374

@@ -506,8 +507,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> LaunchCmd::Record(
506507
});
507508
}
508509

509-
Command::BufferUseVector LaunchCmd::buffers() const {
510-
BufferUseVector buffers;
510+
Command::BufferUses LaunchCmd::buffer_uses() const {
511+
BufferUses buffers;
511512
for (int32_t i = 0; i < args_.size(); ++i) {
512513
buffers.emplace_back(args_[i].slice, args_access_[i], args_[i].shape);
513514
}
@@ -587,8 +588,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CustomKernelLaunchCmd::Record(
587588
});
588589
}
589590

590-
Command::BufferUseVector CustomKernelLaunchCmd::buffers() const {
591-
BufferUseVector buffers;
591+
Command::BufferUses CustomKernelLaunchCmd::buffer_uses() const {
592+
BufferUses buffers;
592593
for (int32_t i = 0; i < args_.size(); ++i) {
593594
buffers.emplace_back(args_[i].slice, args_access_[i], args_[i].shape);
594595
}
@@ -643,7 +644,7 @@ MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params,
643644
});
644645
}
645646

646-
Command::BufferUseVector MemcpyDeviceToDeviceCmd::buffers() const {
647+
Command::BufferUses MemcpyDeviceToDeviceCmd::buffer_uses() const {
647648
return {BufferUse::Write(dst_.slice, dst_.shape),
648649
BufferUse::Read(src_.slice, src_.shape)};
649650
}
@@ -683,7 +684,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> MemzeroCmd::Record(
683684
});
684685
}
685686

686-
Command::BufferUseVector MemzeroCmd::buffers() const {
687+
Command::BufferUses MemzeroCmd::buffer_uses() const {
687688
return {BufferUse::Write(dst_.slice, dst_.shape)};
688689
}
689690

@@ -725,7 +726,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> Memset32Cmd::Record(
725726
});
726727
}
727728

728-
Command::BufferUseVector Memset32Cmd::buffers() const {
729+
Command::BufferUses Memset32Cmd::buffer_uses() const {
729730
return {BufferUse::Write(dst_, ShapeUtil::MakeShape(U32, {}))};
730731
}
731732

@@ -743,8 +744,9 @@ bool ChildCmd::requires_initialization() {
743744

744745
bool ChildCmd::force_update() { return child_commands_.force_update(); }
745746

746-
Command::BufferUseVector ChildCmd::buffers() const {
747-
return {child_commands_.buffers().begin(), child_commands_.buffers().end()};
747+
Command::BufferUses ChildCmd::buffer_uses() const {
748+
return {child_commands_.buffer_uses().begin(),
749+
child_commands_.buffer_uses().end()};
748750
}
749751

750752
absl::Status ChildCmd::Initialize(const Thunk::InitializeParams& params) {
@@ -844,11 +846,11 @@ bool CaseCmd::force_update() {
844846
[](const auto& seq) { return seq.force_update(); });
845847
}
846848

847-
Command::BufferUseVector CaseCmd::buffers() const {
849+
Command::BufferUses CaseCmd::buffer_uses() const {
848850
absl::flat_hash_set<BufferUse> buffers;
849851
buffers.emplace(BufferUse::Read(index_.slice, index_.shape));
850852
for (auto& branch : branches_) {
851-
buffers.insert(branch.buffers().begin(), branch.buffers().end());
853+
buffers.insert(branch.buffer_uses().begin(), branch.buffer_uses().end());
852854
}
853855
return {buffers.begin(), buffers.end()};
854856
}
@@ -993,13 +995,13 @@ bool WhileCmd::force_update() {
993995
return cond_commands_.force_update() || body_commands_.force_update();
994996
}
995997

996-
Command::BufferUseVector WhileCmd::buffers() const {
998+
Command::BufferUses WhileCmd::buffer_uses() const {
997999
absl::flat_hash_set<BufferUse> buffers;
9981000
buffers.emplace(BufferUse::Read(pred_, ShapeUtil::MakeShape(PRED, {})));
999-
buffers.insert(cond_commands_.buffers().begin(),
1000-
cond_commands_.buffers().end());
1001-
buffers.insert(body_commands_.buffers().begin(),
1002-
body_commands_.buffers().end());
1001+
buffers.insert(cond_commands_.buffer_uses().begin(),
1002+
cond_commands_.buffer_uses().end());
1003+
buffers.insert(body_commands_.buffer_uses().begin(),
1004+
body_commands_.buffer_uses().end());
10031005
return {buffers.begin(), buffers.end()};
10041006
}
10051007

@@ -1058,8 +1060,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> GemmCmd::Record(
10581060
});
10591061
}
10601062

1061-
Command::BufferUseVector GemmCmd::buffers() const {
1062-
Command::BufferUseVector res{
1063+
Command::BufferUses GemmCmd::buffer_uses() const {
1064+
Command::BufferUses res{
10631065
BufferUse::Read(lhs_buffer_, config_.lhs_layout.ToShape()),
10641066
BufferUse::Read(rhs_buffer_, config_.rhs_layout.ToShape()),
10651067
BufferUse::Write(output_buffer_, config_.output_layout.ToShape()),
@@ -1123,8 +1125,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CublasLtCmd::Record(
11231125
});
11241126
}
11251127

1126-
Command::BufferUseVector CublasLtCmd::buffers() const {
1127-
BufferUseVector buffer_usage;
1128+
Command::BufferUses CublasLtCmd::buffer_uses() const {
1129+
BufferUses buffer_usage;
11281130
buffer_usage.reserve(13);
11291131
buffer_usage.push_back(BufferUse::Read(a_.slice, a_.shape));
11301132
buffer_usage.push_back(BufferUse::Read(b_.slice, b_.shape));
@@ -1213,8 +1215,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CuDnnCmd::Record(
12131215
});
12141216
}
12151217

1216-
Command::BufferUseVector CuDnnCmd::buffers() const {
1217-
Command::BufferUseVector buffer_usage;
1218+
Command::BufferUses CuDnnCmd::buffer_uses() const {
1219+
Command::BufferUses buffer_usage;
12181220
buffer_usage.reserve(args_.size());
12191221
for (int i = 0; i < args_.size() - 1; ++i) {
12201222
buffer_usage.push_back(BufferUse::Read(args_[i].slice, args_[i].shape));
@@ -1388,8 +1390,8 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params,
13881390
});
13891391
}
13901392

1391-
Command::BufferUseVector CustomCallCmd::buffers() const {
1392-
Command::BufferUseVector buffer_usage;
1393+
Command::BufferUses CustomCallCmd::buffer_uses() const {
1394+
Command::BufferUses buffer_usage;
13931395
for (auto& slices : {operands_, results_}) {
13941396
for (const std::optional<ShapedSlice>& slice : slices) {
13951397
if (slice.has_value()) {
@@ -1536,8 +1538,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllReduceCmd::Record(
15361538
});
15371539
}
15381540

1539-
Command::BufferUseVector AllReduceCmd::buffers() const {
1540-
BufferUseVector buffer_usage;
1541+
Command::BufferUses AllReduceCmd::buffer_uses() const {
1542+
BufferUses buffer_usage;
15411543
for (auto& buffer : buffers_) {
15421544
buffer_usage.emplace_back(BufferUse::Read(buffer.source_buffer.slice,
15431545
buffer.source_buffer.shape));
@@ -1606,8 +1608,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> ReduceScatterCmd::Record(
16061608
});
16071609
}
16081610

1609-
Command::BufferUseVector ReduceScatterCmd::buffers() const {
1610-
BufferUseVector buffer_usage;
1611+
Command::BufferUses ReduceScatterCmd::buffer_uses() const {
1612+
BufferUses buffer_usage;
16111613
for (auto& buffer : buffers_) {
16121614
buffer_usage.emplace_back(BufferUse::Read(buffer.source_buffer.slice,
16131615
buffer.source_buffer.shape));
@@ -1677,8 +1679,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllToAllCmd::Record(
16771679
});
16781680
}
16791681

1680-
Command::BufferUseVector AllToAllCmd::buffers() const {
1681-
BufferUseVector buffer_usage;
1682+
Command::BufferUses AllToAllCmd::buffer_uses() const {
1683+
BufferUses buffer_usage;
16821684
for (auto& buffer : buffers_) {
16831685
buffer_usage.emplace_back(BufferUse::Read(buffer.source_buffer.slice,
16841686
buffer.source_buffer.shape));
@@ -1744,8 +1746,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllGatherCmd::Record(
17441746
});
17451747
}
17461748

1747-
Command::BufferUseVector AllGatherCmd::buffers() const {
1748-
BufferUseVector buffer_usage;
1749+
Command::BufferUses AllGatherCmd::buffer_uses() const {
1750+
BufferUses buffer_usage;
17491751
for (auto& buffer : buffers_) {
17501752
buffer_usage.emplace_back(BufferUse::Read(buffer.source_buffer.slice,
17511753
buffer.source_buffer.shape));
@@ -1811,8 +1813,8 @@ CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params,
18111813
});
18121814
}
18131815

1814-
Command::BufferUseVector CollectiveBroadcastCmd::buffers() const {
1815-
BufferUseVector buffer_usage;
1816+
Command::BufferUses CollectiveBroadcastCmd::buffer_uses() const {
1817+
BufferUses buffer_usage;
18161818
for (auto& buffer : buffers_) {
18171819
buffer_usage.emplace_back(BufferUse::Read(buffer.source_buffer.slice,
18181820
buffer.source_buffer.shape));
@@ -1921,8 +1923,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> RecvCmd::Record(
19211923
std::move(record_action), command_buffer, trace);
19221924
}
19231925

1924-
Command::BufferUseVector RecvCmd::buffers() const {
1925-
BufferUseVector buffer_usage;
1926+
Command::BufferUses RecvCmd::buffer_uses() const {
1927+
BufferUses buffer_usage;
19261928
buffer_usage.emplace_back(BufferUse::Read(buffer_.source_buffer.slice,
19271929
buffer_.source_buffer.shape));
19281930
buffer_usage.emplace_back(BufferUse::Write(buffer_.destination_buffer.slice,
@@ -2029,8 +2031,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> SendCmd::Record(
20292031
std::move(record_action), command_buffer, trace);
20302032
}
20312033

2032-
Command::BufferUseVector SendCmd::buffers() const {
2033-
BufferUseVector buffer_usage;
2034+
Command::BufferUses SendCmd::buffer_uses() const {
2035+
BufferUses buffer_usage;
20342036
buffer_usage.emplace_back(BufferUse::Read(buffer_.source_buffer.slice,
20352037
buffer_.source_buffer.shape));
20362038
buffer_usage.emplace_back(BufferUse::Write(buffer_.destination_buffer.slice,
@@ -2111,8 +2113,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CollectivePermuteCmd::Record(
21112113
});
21122114
}
21132115

2114-
Command::BufferUseVector CollectivePermuteCmd::buffers() const {
2115-
BufferUseVector buffer_usage;
2116+
Command::BufferUses CollectivePermuteCmd::buffer_uses() const {
2117+
BufferUses buffer_usage;
21162118
for (const CollectiveThunk::Buffer& buffer : buffers_) {
21172119
buffer_usage.emplace_back(BufferUse::Read(buffer.source_buffer.slice,
21182120
buffer.source_buffer.shape));
@@ -2426,9 +2428,9 @@ absl::StatusOr<const se::CommandBuffer::Command*> DynamicSliceFusionCmd::Record(
24262428
});
24272429
}
24282430

2429-
Command::BufferUseVector DynamicSliceFusionCmd::buffers() const {
2430-
Command::BufferUseVector buffers;
2431-
auto embed_buffers = embedded_commands_.buffers();
2431+
Command::BufferUses DynamicSliceFusionCmd::buffer_uses() const {
2432+
Command::BufferUses buffers;
2433+
auto embed_buffers = embedded_commands_.buffer_uses();
24322434
for (const BufferUse& buffer_usage : embed_buffers) {
24332435
buffers.emplace_back(
24342436
*embeded_to_origin_slice_map_.at(buffer_usage.slice().index()),
@@ -2502,8 +2504,8 @@ DynamicSliceCopyFusionCmd::Record(const Thunk::ExecuteParams& execute_params,
25022504
});
25032505
}
25042506

2505-
Command::BufferUseVector DynamicSliceCopyFusionCmd::buffers() const {
2506-
Command::BufferUseVector buffers;
2507+
Command::BufferUses DynamicSliceCopyFusionCmd::buffer_uses() const {
2508+
Command::BufferUses buffers;
25072509
buffers.emplace_back(
25082510
BufferUse::Read(source_buffer_.slice, source_buffer_.shape));
25092511
buffers.emplace_back(

0 commit comments

Comments
 (0)