@@ -225,7 +225,7 @@ static absl::StatusOr<const se::CommandBuffer::Command*> Handle(
225225// ===----------------------------------------------------------------------===//
226226
227227TracedCommandBuffer::TracedCommandBuffer (const Command* trace_cmd,
228- Command::BufferUseVector buffers,
228+ Command::BufferUses buffers,
229229 int64_t capacity)
230230 : trace_cmd_(trace_cmd), capacity_(capacity), entries_(capacity) {
231231 CHECK_GT (capacity, 0 ) << " capacity must be larger than 0" ; // NOLINT
@@ -319,7 +319,8 @@ TracedCommandBufferCmd::RecordTracedCommand(
319319 this , command_buffer, [&] {
320320 const auto & debug_options = xla::GetDebugOptionsFromFlags ();
321321 return std::make_unique<TracedCommandBuffer>(
322- this , buffers (), debug_options.xla_cmd_buffer_trace_cache_size ());
322+ this , buffer_uses (),
323+ debug_options.xla_cmd_buffer_trace_cache_size ());
323324 });
324325
325326 TF_ASSIGN_OR_RETURN (
@@ -367,7 +368,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> EmptyCmd::Record(
367368ComputationIdCmd::ComputationIdCmd (BufferAllocation::Slice dest, Kind kind)
368369 : Command(CommandType::kComputationIdCmd ), dest_(dest), kind_(kind) {}
369370
370- Command::BufferUseVector ComputationIdCmd::buffers () const {
371+ Command::BufferUses ComputationIdCmd::buffer_uses () const {
371372 return {BufferUse::Write (dest_, ShapeUtil::MakeShape (S32, {}))};
372373}
373374
@@ -506,8 +507,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> LaunchCmd::Record(
506507 });
507508}
508509
509- Command::BufferUseVector LaunchCmd::buffers () const {
510- BufferUseVector buffers;
510+ Command::BufferUses LaunchCmd::buffer_uses () const {
511+ BufferUses buffers;
511512 for (int32_t i = 0 ; i < args_.size (); ++i) {
512513 buffers.emplace_back (args_[i].slice , args_access_[i], args_[i].shape );
513514 }
@@ -587,8 +588,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CustomKernelLaunchCmd::Record(
587588 });
588589}
589590
590- Command::BufferUseVector CustomKernelLaunchCmd::buffers () const {
591- BufferUseVector buffers;
591+ Command::BufferUses CustomKernelLaunchCmd::buffer_uses () const {
592+ BufferUses buffers;
592593 for (int32_t i = 0 ; i < args_.size (); ++i) {
593594 buffers.emplace_back (args_[i].slice , args_access_[i], args_[i].shape );
594595 }
@@ -643,7 +644,7 @@ MemcpyDeviceToDeviceCmd::Record(const Thunk::ExecuteParams& execute_params,
643644 });
644645}
645646
646- Command::BufferUseVector MemcpyDeviceToDeviceCmd::buffers () const {
647+ Command::BufferUses MemcpyDeviceToDeviceCmd::buffer_uses () const {
647648 return {BufferUse::Write (dst_.slice , dst_.shape ),
648649 BufferUse::Read (src_.slice , src_.shape )};
649650}
@@ -683,7 +684,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> MemzeroCmd::Record(
683684 });
684685}
685686
686- Command::BufferUseVector MemzeroCmd::buffers () const {
687+ Command::BufferUses MemzeroCmd::buffer_uses () const {
687688 return {BufferUse::Write (dst_.slice , dst_.shape )};
688689}
689690
@@ -725,7 +726,7 @@ absl::StatusOr<const se::CommandBuffer::Command*> Memset32Cmd::Record(
725726 });
726727}
727728
728- Command::BufferUseVector Memset32Cmd::buffers () const {
729+ Command::BufferUses Memset32Cmd::buffer_uses () const {
729730 return {BufferUse::Write (dst_, ShapeUtil::MakeShape (U32, {}))};
730731}
731732
@@ -743,8 +744,9 @@ bool ChildCmd::requires_initialization() {
743744
744745bool ChildCmd::force_update () { return child_commands_.force_update (); }
745746
746- Command::BufferUseVector ChildCmd::buffers () const {
747- return {child_commands_.buffers ().begin (), child_commands_.buffers ().end ()};
747+ Command::BufferUses ChildCmd::buffer_uses () const {
748+ return {child_commands_.buffer_uses ().begin (),
749+ child_commands_.buffer_uses ().end ()};
748750}
749751
750752absl::Status ChildCmd::Initialize (const Thunk::InitializeParams& params) {
@@ -844,11 +846,11 @@ bool CaseCmd::force_update() {
844846 [](const auto & seq) { return seq.force_update (); });
845847}
846848
847- Command::BufferUseVector CaseCmd::buffers () const {
849+ Command::BufferUses CaseCmd::buffer_uses () const {
848850 absl::flat_hash_set<BufferUse> buffers;
849851 buffers.emplace (BufferUse::Read (index_.slice , index_.shape ));
850852 for (auto & branch : branches_) {
851- buffers.insert (branch.buffers ().begin (), branch.buffers ().end ());
853+ buffers.insert (branch.buffer_uses ().begin (), branch.buffer_uses ().end ());
852854 }
853855 return {buffers.begin (), buffers.end ()};
854856}
@@ -993,13 +995,13 @@ bool WhileCmd::force_update() {
993995 return cond_commands_.force_update () || body_commands_.force_update ();
994996}
995997
996- Command::BufferUseVector WhileCmd::buffers () const {
998+ Command::BufferUses WhileCmd::buffer_uses () const {
997999 absl::flat_hash_set<BufferUse> buffers;
9981000 buffers.emplace (BufferUse::Read (pred_, ShapeUtil::MakeShape (PRED, {})));
999- buffers.insert (cond_commands_.buffers ().begin (),
1000- cond_commands_.buffers ().end ());
1001- buffers.insert (body_commands_.buffers ().begin (),
1002- body_commands_.buffers ().end ());
1001+ buffers.insert (cond_commands_.buffer_uses ().begin (),
1002+ cond_commands_.buffer_uses ().end ());
1003+ buffers.insert (body_commands_.buffer_uses ().begin (),
1004+ body_commands_.buffer_uses ().end ());
10031005 return {buffers.begin (), buffers.end ()};
10041006}
10051007
@@ -1058,8 +1060,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> GemmCmd::Record(
10581060 });
10591061}
10601062
1061- Command::BufferUseVector GemmCmd::buffers () const {
1062- Command::BufferUseVector res{
1063+ Command::BufferUses GemmCmd::buffer_uses () const {
1064+ Command::BufferUses res{
10631065 BufferUse::Read (lhs_buffer_, config_.lhs_layout .ToShape ()),
10641066 BufferUse::Read (rhs_buffer_, config_.rhs_layout .ToShape ()),
10651067 BufferUse::Write (output_buffer_, config_.output_layout .ToShape ()),
@@ -1123,8 +1125,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CublasLtCmd::Record(
11231125 });
11241126}
11251127
1126- Command::BufferUseVector CublasLtCmd::buffers () const {
1127- BufferUseVector buffer_usage;
1128+ Command::BufferUses CublasLtCmd::buffer_uses () const {
1129+ BufferUses buffer_usage;
11281130 buffer_usage.reserve (13 );
11291131 buffer_usage.push_back (BufferUse::Read (a_.slice , a_.shape ));
11301132 buffer_usage.push_back (BufferUse::Read (b_.slice , b_.shape ));
@@ -1213,8 +1215,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CuDnnCmd::Record(
12131215 });
12141216}
12151217
1216- Command::BufferUseVector CuDnnCmd::buffers () const {
1217- Command::BufferUseVector buffer_usage;
1218+ Command::BufferUses CuDnnCmd::buffer_uses () const {
1219+ Command::BufferUses buffer_usage;
12181220 buffer_usage.reserve (args_.size ());
12191221 for (int i = 0 ; i < args_.size () - 1 ; ++i) {
12201222 buffer_usage.push_back (BufferUse::Read (args_[i].slice , args_[i].shape ));
@@ -1388,8 +1390,8 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params,
13881390 });
13891391}
13901392
1391- Command::BufferUseVector CustomCallCmd::buffers () const {
1392- Command::BufferUseVector buffer_usage;
1393+ Command::BufferUses CustomCallCmd::buffer_uses () const {
1394+ Command::BufferUses buffer_usage;
13931395 for (auto & slices : {operands_, results_}) {
13941396 for (const std::optional<ShapedSlice>& slice : slices) {
13951397 if (slice.has_value ()) {
@@ -1536,8 +1538,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllReduceCmd::Record(
15361538 });
15371539}
15381540
1539- Command::BufferUseVector AllReduceCmd::buffers () const {
1540- BufferUseVector buffer_usage;
1541+ Command::BufferUses AllReduceCmd::buffer_uses () const {
1542+ BufferUses buffer_usage;
15411543 for (auto & buffer : buffers_) {
15421544 buffer_usage.emplace_back (BufferUse::Read (buffer.source_buffer .slice ,
15431545 buffer.source_buffer .shape ));
@@ -1606,8 +1608,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> ReduceScatterCmd::Record(
16061608 });
16071609}
16081610
1609- Command::BufferUseVector ReduceScatterCmd::buffers () const {
1610- BufferUseVector buffer_usage;
1611+ Command::BufferUses ReduceScatterCmd::buffer_uses () const {
1612+ BufferUses buffer_usage;
16111613 for (auto & buffer : buffers_) {
16121614 buffer_usage.emplace_back (BufferUse::Read (buffer.source_buffer .slice ,
16131615 buffer.source_buffer .shape ));
@@ -1677,8 +1679,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllToAllCmd::Record(
16771679 });
16781680}
16791681
1680- Command::BufferUseVector AllToAllCmd::buffers () const {
1681- BufferUseVector buffer_usage;
1682+ Command::BufferUses AllToAllCmd::buffer_uses () const {
1683+ BufferUses buffer_usage;
16821684 for (auto & buffer : buffers_) {
16831685 buffer_usage.emplace_back (BufferUse::Read (buffer.source_buffer .slice ,
16841686 buffer.source_buffer .shape ));
@@ -1744,8 +1746,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllGatherCmd::Record(
17441746 });
17451747}
17461748
1747- Command::BufferUseVector AllGatherCmd::buffers () const {
1748- BufferUseVector buffer_usage;
1749+ Command::BufferUses AllGatherCmd::buffer_uses () const {
1750+ BufferUses buffer_usage;
17491751 for (auto & buffer : buffers_) {
17501752 buffer_usage.emplace_back (BufferUse::Read (buffer.source_buffer .slice ,
17511753 buffer.source_buffer .shape ));
@@ -1811,8 +1813,8 @@ CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params,
18111813 });
18121814}
18131815
1814- Command::BufferUseVector CollectiveBroadcastCmd::buffers () const {
1815- BufferUseVector buffer_usage;
1816+ Command::BufferUses CollectiveBroadcastCmd::buffer_uses () const {
1817+ BufferUses buffer_usage;
18161818 for (auto & buffer : buffers_) {
18171819 buffer_usage.emplace_back (BufferUse::Read (buffer.source_buffer .slice ,
18181820 buffer.source_buffer .shape ));
@@ -1921,8 +1923,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> RecvCmd::Record(
19211923 std::move (record_action), command_buffer, trace);
19221924}
19231925
1924- Command::BufferUseVector RecvCmd::buffers () const {
1925- BufferUseVector buffer_usage;
1926+ Command::BufferUses RecvCmd::buffer_uses () const {
1927+ BufferUses buffer_usage;
19261928 buffer_usage.emplace_back (BufferUse::Read (buffer_.source_buffer .slice ,
19271929 buffer_.source_buffer .shape ));
19281930 buffer_usage.emplace_back (BufferUse::Write (buffer_.destination_buffer .slice ,
@@ -2029,8 +2031,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> SendCmd::Record(
20292031 std::move (record_action), command_buffer, trace);
20302032}
20312033
2032- Command::BufferUseVector SendCmd::buffers () const {
2033- BufferUseVector buffer_usage;
2034+ Command::BufferUses SendCmd::buffer_uses () const {
2035+ BufferUses buffer_usage;
20342036 buffer_usage.emplace_back (BufferUse::Read (buffer_.source_buffer .slice ,
20352037 buffer_.source_buffer .shape ));
20362038 buffer_usage.emplace_back (BufferUse::Write (buffer_.destination_buffer .slice ,
@@ -2111,8 +2113,8 @@ absl::StatusOr<const se::CommandBuffer::Command*> CollectivePermuteCmd::Record(
21112113 });
21122114}
21132115
2114- Command::BufferUseVector CollectivePermuteCmd::buffers () const {
2115- BufferUseVector buffer_usage;
2116+ Command::BufferUses CollectivePermuteCmd::buffer_uses () const {
2117+ BufferUses buffer_usage;
21162118 for (const CollectiveThunk::Buffer& buffer : buffers_) {
21172119 buffer_usage.emplace_back (BufferUse::Read (buffer.source_buffer .slice ,
21182120 buffer.source_buffer .shape ));
@@ -2426,9 +2428,9 @@ absl::StatusOr<const se::CommandBuffer::Command*> DynamicSliceFusionCmd::Record(
24262428 });
24272429}
24282430
2429- Command::BufferUseVector DynamicSliceFusionCmd::buffers () const {
2430- Command::BufferUseVector buffers;
2431- auto embed_buffers = embedded_commands_.buffers ();
2431+ Command::BufferUses DynamicSliceFusionCmd::buffer_uses () const {
2432+ Command::BufferUses buffers;
2433+ auto embed_buffers = embedded_commands_.buffer_uses ();
24322434 for (const BufferUse& buffer_usage : embed_buffers) {
24332435 buffers.emplace_back (
24342436 *embeded_to_origin_slice_map_.at (buffer_usage.slice ().index ()),
@@ -2502,8 +2504,8 @@ DynamicSliceCopyFusionCmd::Record(const Thunk::ExecuteParams& execute_params,
25022504 });
25032505}
25042506
2505- Command::BufferUseVector DynamicSliceCopyFusionCmd::buffers () const {
2506- Command::BufferUseVector buffers;
2507+ Command::BufferUses DynamicSliceCopyFusionCmd::buffer_uses () const {
2508+ Command::BufferUses buffers;
25072509 buffers.emplace_back (
25082510 BufferUse::Read (source_buffer_.slice , source_buffer_.shape ));
25092511 buffers.emplace_back (
0 commit comments