Skip to content

Commit 92e0315

Browse files
jurahul authored and tensorflower-gardener committed
[XLA:GPU] Support for multiple infeed managers per process.
- Enable multiple infeed managers per-process, to enable running multiple replicas within the same process (as exercised by SPMD tests).
- To enable this, add the ability to attach a type-erased XLA-specific state to the GpuExecutor.
- Define the XLA-specific executor state for GPU to be the infeed manager instance.

PiperOrigin-RevId: 363681266
Change-Id: Ia04f22db51700445a885e514a075029eb9b0be4f
1 parent d65be85 commit 92e0315

File tree

11 files changed

+155
-52
lines changed

11 files changed

+155
-52
lines changed

tensorflow/compiler/xla/service/gpu/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1487,13 +1487,18 @@ cc_library(
14871487

14881488
cc_library(
14891489
name = "infeed_manager",
1490-
srcs = ["infeed_manager.cc"],
1490+
srcs = [
1491+
"infeed_manager.cc",
1492+
"xla_executor_state.h",
1493+
],
14911494
hdrs = ["infeed_manager.h"],
1495+
copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]),
14921496
deps = [
14931497
":xfeed_queue",
14941498
"//tensorflow/compiler/xla:shape_tree",
14951499
"//tensorflow/compiler/xla:types",
14961500
"//tensorflow/core/platform:stream_executor_no_cuda",
1501+
"//tensorflow/stream_executor/gpu:gpu_executor_header",
14971502
"@com_google_absl//absl/base:core_headers",
14981503
"@com_google_absl//absl/memory",
14991504
],

tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ Status GpuTransferManager::TransferLiteralToInfeed(
6868

6969
Status GpuTransferManager::EnqueueBuffersToInfeed(
7070
se::StreamExecutor* executor, ShapeTree<InfeedBuffer> buffers) {
71-
gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager();
72-
se::Stream* stream = infeed_manager->GetStream(executor);
71+
gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(executor);
72+
se::Stream* stream = infeed_manager->GetStream();
7373

7474
// TODO(b/30467474): Since this stream is shared across different
7575
// infeed requests, blocking on the stream might be
@@ -99,8 +99,8 @@ StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
9999
return InvalidArgument("Infeed shape needs 0 bytes");
100100
}
101101

102-
gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager();
103-
se::Stream* stream = infeed_manager->GetStream(executor);
102+
gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(executor);
103+
se::Stream* stream = infeed_manager->GetStream();
104104
if (stream == nullptr) {
105105
return InternalError("Failed to obtain a stream");
106106
}

tensorflow/compiler/xla/service/gpu/infeed_manager.cc

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,29 @@ limitations under the License.
1717

1818
#include "absl/memory/memory.h"
1919

20+
#if GOOGLE_CUDA
21+
#include "tensorflow/compiler/xla/service/gpu/xla_executor_state.h"
22+
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
23+
#endif // GOOGLE_CUDA
24+
2025
namespace xla {
2126
namespace gpu {
2227

23-
se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
24-
tensorflow::mutex_lock l(host_to_device_stream_mu_);
25-
if (host_to_device_executor_ == nullptr) {
26-
host_to_device_executor_ = executor;
27-
host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
28-
host_to_device_stream_->Init();
29-
}
30-
31-
if (executor != host_to_device_executor_) {
32-
// The requested executor must be the same as the one for which
33-
// the stream is cached.
34-
return nullptr;
35-
}
36-
37-
return host_to_device_stream_.get();
28+
InfeedManager::InfeedManager(se::StreamExecutor *executor)
29+
: stream_(absl::make_unique<se::Stream>(executor)) {
30+
stream_->Init();
3831
}
3932

40-
InfeedManager* GetOrCreateInfeedManager() {
41-
static InfeedManager* manager = new InfeedManager;
42-
return manager;
33+
InfeedManager *GetOrCreateInfeedManager(se::StreamExecutor *executor) {
34+
#if GOOGLE_CUDA
35+
stream_executor::gpu::GpuExecutor *gpu_executor =
36+
stream_executor::gpu::ExtractGpuExecutor(executor);
37+
auto *xla_state =
38+
gpu_executor->getOrCreateXLAState<GpuExecutorXLAState>(executor);
39+
return xla_state->getOrCreateInfeedManager(executor);
40+
#else // GOOGLE_CUDA
41+
return nullptr;
42+
#endif // GOOGLE_CUDA
4343
}
4444

4545
} // namespace gpu

tensorflow/compiler/xla/service/gpu/infeed_manager.h

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,26 +64,18 @@ class InfeedBuffer {
6464
// Client-side class used to enqueue infeed buffers.
6565
class InfeedManager : public XfeedQueue<ShapeTree<InfeedBuffer>> {
6666
public:
67-
// Returns a cached stream associated with an executor. Allocates a
68-
// new stream on the first invocation. On subsequent invocations, if
69-
// the cached executor is not the same as the requested executor,
70-
// returns null.
71-
se::Stream* GetStream(se::StreamExecutor* executor);
67+
explicit InfeedManager(se::StreamExecutor* executor);
7268

73-
private:
74-
// Mutex for serializing the creation of host_to_device_stream_.
75-
tensorflow::mutex host_to_device_stream_mu_;
76-
77-
// Cached host to device stream for queuing infeed data.
78-
std::unique_ptr<se::Stream> host_to_device_stream_
79-
ABSL_GUARDED_BY(host_to_device_stream_mu_);
69+
// Returns a stream for this infeed manager.
70+
se::Stream* GetStream() const { return stream_.get(); }
8071

81-
// Executor that the host_to_device_stream belongs to. Not owned.
82-
se::StreamExecutor* host_to_device_executor_ = nullptr;
72+
private:
73+
// Stream used to enqueue infeed device copies.
74+
std::unique_ptr<se::Stream> stream_;
8375
};
8476

85-
// Singleton creator-or-accessor: Returns the GPU infeed manager.
86-
InfeedManager* GetOrCreateInfeedManager();
77+
// Returns the GPU infeed manager for the given stream executor.
78+
InfeedManager* GetOrCreateInfeedManager(se::StreamExecutor* executor);
8779

8880
} // namespace gpu
8981
} // namespace xla

tensorflow/compiler/xla/service/gpu/infeed_thunk.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
1717

1818
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
19+
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
1920
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
2021
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
2122
#include "tensorflow/compiler/xla/shape_util.h"
@@ -31,15 +32,16 @@ InfeedThunk::InfeedThunk(ThunkInfo thunk_info,
3132
: Thunk(Kind::kInfeed, thunk_info), dest_slices_(std::move(dest_slices)) {}
3233

3334
Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
34-
auto& stream = *params.stream;
35-
auto& buffer_allocations = *params.buffer_allocations;
35+
se::Stream& stream = *params.stream;
36+
const BufferAllocations& buffer_allocations = *params.buffer_allocations;
3637

3738
VLOG(2) << "Infeeding to GPU";
3839

3940
auto op_profiler =
4041
params.profiler->MakeScopedInstructionProfiler(profile_index());
42+
4143
ShapeTree<InfeedBuffer> source_buffers =
42-
GetOrCreateInfeedManager()->BlockingGetNextDestination();
44+
GetOrCreateInfeedManager(stream.parent())->BlockingGetNextDestination();
4345

4446
size_t index = 0;
4547
for (auto& source : source_buffers.leaves()) {
tensorflow/compiler/xla/service/gpu/xla_executor_state.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_EXECUTOR_STATE_H_
17+
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_EXECUTOR_STATE_H_
18+
19+
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
20+
21+
// Defines XLA:GPU specific state that will be attached to the GpuExecutor.
22+
23+
namespace xla {
24+
namespace gpu {
25+
26+
class GpuExecutorXLAState {
27+
public:
28+
explicit GpuExecutorXLAState(stream_executor::StreamExecutor *) {}
29+
InfeedManager *getOrCreateInfeedManager(stream_executor::StreamExecutor *se) {
30+
tensorflow::mutex_lock l(mu_);
31+
if (!infeed_manager_) {
32+
infeed_manager_ = std::make_unique<InfeedManager>(se);
33+
}
34+
return infeed_manager_.get();
35+
}
36+
37+
private:
38+
tensorflow::mutex mu_;
39+
std::unique_ptr<InfeedManager> infeed_manager_ ABSL_GUARDED_BY(mu_);
40+
};
41+
42+
} // namespace gpu
43+
} // namespace xla
44+
45+
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_EXECUTOR_STATE_H_

tensorflow/stream_executor/cuda/cuda_gpu_executor.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,6 @@ GpuContext* ExtractGpuContext(GpuExecutor* cuda_exec) {
115115
return cuda_exec->gpu_context();
116116
}
117117

118-
GpuExecutor* ExtractGpuExecutor(StreamExecutor* stream_exec) {
119-
return static_cast<GpuExecutor*>(stream_exec->implementation());
120-
}
121-
122118
GpuExecutor::~GpuExecutor() {
123119
CHECK(kernel_to_gpu_binary_.empty()) << "GpuExecutor has live kernels.";
124120
CHECK(gpu_binary_to_module_.empty()) << "GpuExecutor has loaded modules.";

tensorflow/stream_executor/gpu/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ cc_library(
4545
srcs = if_gpu_is_configured(["gpu_activation.cc"]),
4646
hdrs = if_gpu_is_configured(["gpu_activation.h"]),
4747
deps = if_gpu_is_configured([
48+
":gpu_executor_header",
4849
":gpu_activation_header",
4950
":gpu_driver_header",
5051
"//tensorflow/stream_executor",
@@ -109,6 +110,7 @@ cc_library(
109110
"//tensorflow/stream_executor:event",
110111
"//tensorflow/stream_executor:platform",
111112
"//tensorflow/stream_executor:stream_executor_internal",
113+
"//tensorflow/stream_executor:stream_executor_pimpl_header",
112114
"//tensorflow/stream_executor/lib",
113115
"//tensorflow/stream_executor/platform",
114116
"@com_google_absl//absl/strings",

tensorflow/stream_executor/gpu/gpu_activation.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ limitations under the License.
1616
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
1717

1818
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
19+
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
1920
#include "tensorflow/stream_executor/stream_executor.h"
2021
#include "tensorflow/stream_executor/stream_executor_internal.h"
2122

2223
namespace stream_executor {
2324
namespace gpu {
2425

2526
GpuContext* ExtractGpuContext(GpuExecutor* gpu_exec);
26-
GpuExecutor* ExtractGpuExecutor(StreamExecutor* stream_exec);
2727

2828
ScopedActivateExecutorContext::ScopedActivateExecutorContext(
2929
GpuExecutor* gpu_exec)

tensorflow/stream_executor/gpu/gpu_executor.h

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ limitations under the License.
2222
#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_
2323
#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_
2424

25+
#include <memory>
2526
#include <set>
27+
#include <type_traits>
2628
#include <unordered_map>
2729

2830
#include "absl/strings/string_view.h"
29-
#include "absl/synchronization/mutex.h"
31+
#include "tensorflow/core/platform/mutex.h"
3032
#include "tensorflow/core/platform/thread_annotations.h"
3133
#include "tensorflow/stream_executor/event.h"
3234
#include "tensorflow/stream_executor/gpu/gpu_kernel.h"
@@ -35,13 +37,55 @@ limitations under the License.
3537
#include "tensorflow/stream_executor/platform.h"
3638
#include "tensorflow/stream_executor/platform/port.h"
3739
#include "tensorflow/stream_executor/stream_executor_internal.h"
40+
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
3841

3942
namespace stream_executor {
43+
44+
class StreamExecutor;
45+
4046
namespace gpu {
4147

48+
// Pointer-to-implementation object type with virtual destruction for any XLA
49+
// specific data hanging off of the GpuExecutor.
50+
class XLAInterface {
51+
public:
52+
// Default constructor for the abstract interface.
53+
explicit XLAInterface() {}
54+
55+
// Default destructor for the abstract interface.
56+
virtual ~XLAInterface() {}
57+
};
58+
4259
// CUDA-platform implementation of the platform-agnostic
4360
// StreamExecutorInterface.
4461
class GpuExecutor : public internal::StreamExecutorInterface {
62+
// Helper classes to attach a type erased state to the GpuExecutor. Currently,
63+
// we just need to support some XLA specific state.
64+
class Object {
65+
struct Concept {
66+
virtual ~Concept() {}
67+
};
68+
template <typename T>
69+
struct Model : Concept {
70+
explicit Model(StreamExecutor* se) : object(se) {}
71+
T object;
72+
};
73+
74+
public:
75+
template <typename T>
76+
T* getOrCreate(StreamExecutor* se) {
77+
tensorflow::mutex_lock l(mu_);
78+
if (!object_) {
79+
object_ = std::make_unique<Model<T>>(se);
80+
}
81+
return &(dynamic_cast<Model<T>*>(object_.get())->object);
82+
}
83+
84+
private:
85+
tensorflow::mutex mu_;
86+
std::unique_ptr<Concept> object_ ABSL_GUARDED_BY(mu_);
87+
};
88+
4589
public:
4690
// sub_platform indicates the subplatform used in this executor; it must
4791
// be a CUDA type.
@@ -233,6 +277,20 @@ class GpuExecutor : public internal::StreamExecutorInterface {
233277

234278
GpuContext* gpu_context();
235279

280+
// Provide a type-erased way of attaching arbitrary XLA specific state to the
281+
// GpuExecutor. XLA based execution will use this method to attach per-stream
282+
// executor XLA specific objects (like the Infeed and Outfeed managers) to the
283+
// stream executor, so that their lifetimes can be tied to the lifetime of the
284+
// stream executor for which that object is allocated for. This simplifies
285+
// memory management as compared to having these objects reside on the side
286+
// and then either leaking or having to implement callbacks that the SE
287+
// destructors call to deallocate any side state that is associated with that
288+
// SE object.
289+
template <typename T>
290+
T* getOrCreateXLAState(StreamExecutor* se) {
291+
return xla_state_.getOrCreate<T>(se);
292+
}
293+
236294
private:
237295
// Attempts to find a more specific version of the file indicated by
238296
// filename by looking for compute-capability-specific suffixed versions; i.e.
@@ -337,9 +395,16 @@ class GpuExecutor : public internal::StreamExecutorInterface {
337395
// The plugin configuration associated with this instance.
338396
PluginConfig plugin_config_;
339397

398+
// Type erased XLA specific state attached to GpuExecutor.
399+
Object xla_state_;
400+
340401
SE_DISALLOW_COPY_AND_ASSIGN(GpuExecutor);
341402
};
342403

404+
inline GpuExecutor* ExtractGpuExecutor(StreamExecutor* stream_exec) {
405+
return static_cast<GpuExecutor*>(stream_exec->implementation());
406+
}
407+
343408
} // namespace gpu
344409
} // namespace stream_executor
345410

0 commit comments

Comments (0)