@@ -22,11 +22,13 @@ limitations under the License.
2222#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_
2323#define TENSORFLOW_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_
2424
25+ #include < memory>
2526#include < set>
27+ #include < type_traits>
2628#include < unordered_map>
2729
2830#include " absl/strings/string_view.h"
29- #include " absl/synchronization /mutex.h"
31+ #include " tensorflow/core/platform /mutex.h"
3032#include " tensorflow/core/platform/thread_annotations.h"
3133#include " tensorflow/stream_executor/event.h"
3234#include " tensorflow/stream_executor/gpu/gpu_kernel.h"
@@ -35,13 +37,55 @@ limitations under the License.
3537#include " tensorflow/stream_executor/platform.h"
3638#include " tensorflow/stream_executor/platform/port.h"
3739#include " tensorflow/stream_executor/stream_executor_internal.h"
40+ #include " tensorflow/stream_executor/stream_executor_pimpl.h"
3841
3942namespace stream_executor {
43+
44+ class StreamExecutor ;
45+
4046namespace gpu {
4147
// Pointer-to-implementation object type with virtual destruction for any XLA
// specific data hanging off of the GpuExecutor.
class XLAInterface {
 public:
  // Defaulted constructor: the interface carries no state of its own.
  // (`explicit` is meaningless on a zero-argument constructor, and an empty
  // user-provided body defeats triviality — prefer `= default`.)
  XLAInterface() = default;

  // Virtual destructor so derived XLA state objects are destroyed correctly
  // when deleted through an XLAInterface pointer.
  virtual ~XLAInterface() = default;
};
58+
4259// CUDA-platform implementation of the platform-agnostic
4360// StreamExecutorInterface.
4461class GpuExecutor : public internal ::StreamExecutorInterface {
62+ // Helper classes to attach a type erased state to the GpuExecutor. Currently,
63+ // we just need to support some XLA specific state.
64+ class Object {
65+ struct Concept {
66+ virtual ~Concept () {}
67+ };
68+ template <typename T>
69+ struct Model : Concept {
70+ explicit Model (StreamExecutor* se) : object(se) {}
71+ T object;
72+ };
73+
74+ public:
75+ template <typename T>
76+ T* getOrCreate (StreamExecutor* se) {
77+ tensorflow::mutex_lock l (mu_);
78+ if (!object_) {
79+ object_ = std::make_unique<Model<T>>(se);
80+ }
81+ return &(dynamic_cast <Model<T>*>(object_.get ())->object );
82+ }
83+
84+ private:
85+ tensorflow::mutex mu_;
86+ std::unique_ptr<Concept> object_ ABSL_GUARDED_BY (mu_);
87+ };
88+
4589 public:
4690 // sub_platform indicates the subplatform used in this executor; it must
4791 // be a CUDA type.
@@ -233,6 +277,20 @@ class GpuExecutor : public internal::StreamExecutorInterface {
233277
234278 GpuContext* gpu_context ();
235279
280+ // Provide a type-erased way of attaching arbitrary XLA specific state to the
281+ // GpuExecutor. XLA based execution will use this method to attach per-stream
282+ // executor XLA specific objects (like the Infeed and Outfeed managers) to the
283+ // stream executor, so that their lifetimes can be tied to the lifetime of the
284+ // stream executor for which that object is allocated for. This simplifies
285+ // memory management as compared to having these objects reside on the side
286+ // and then either leaking or having to implement callbacks that the SE
287+ // destructors call to deallocate any side state that is associated with that
288+ // SE object.
289+ template <typename T>
290+ T* getOrCreateXLAState (StreamExecutor* se) {
291+ return xla_state_.getOrCreate <T>(se);
292+ }
293+
236294 private:
237295 // Attempts to find a more specific version of the file indicated by
238296 // filename by looking for compute-capability-specific suffixed versions; i.e.
@@ -337,9 +395,16 @@ class GpuExecutor : public internal::StreamExecutorInterface {
337395 // The plugin configuration associated with this instance.
338396 PluginConfig plugin_config_;
339397
398+ // Type erased XLA specific state attached to GpuExecutor.
399+ Object xla_state_;
400+
340401 SE_DISALLOW_COPY_AND_ASSIGN (GpuExecutor);
341402};
342403
404+ inline GpuExecutor* ExtractGpuExecutor (StreamExecutor* stream_exec) {
405+ return static_cast <GpuExecutor*>(stream_exec->implementation ());
406+ }
407+
343408} // namespace gpu
344409} // namespace stream_executor
345410
0 commit comments