 #pragma once

-#include <functional>
-#include <iostream>
-#include <memory>
 #include <optional>
 #include <regex>
-#include <sstream>
-#include <stdexcept>
-#include <string>
 #include <unordered_map>
-#include <vector>

 // WARNING: Be careful when adding new includes here. This header will be used
 // in model.so, and should not refer to any aten/c10 headers except the stable
 // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
 // applies to other files under torch/csrc/inductor/aoti_runtime/.
 #include <torch/csrc/inductor/aoti_runtime/device_utils.h>
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_runtime/utils.h>

 #define AOTI_RUNTIME_CHECK(EXPR, MSG) \
   do { \
     /* ... unchanged lines hidden in this view ... */ \
     } \
   } while (0)

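For orientation, a minimal usage sketch of this macro; the condition and message are illustrative, not taken from the file:

    AOTI_RUNTIME_CHECK(model_handle != nullptr, "model_handle must not be null");

The do/while(0) wrapper is the usual trick that lets the macro behave like a single statement, for example inside an unbraced if/else.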
-#if defined(__GNUC__) || defined(__clang__)
-#define AOTI_NOINLINE __attribute__((noinline))
-#elif _MSC_VER
-#define AOTI_NOINLINE __declspec(noinline)
-#else
-#define AOTI_NOINLINE
-#endif
-
 // At codegen time, we write out a binary file called constants.bin.
 // We then turn the raw binary to an object file that exposes this
 // symbol and link it into the final .so.
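The comment above describes embedding constants.bin into the shared library. Purely as an illustration of that general technique, and not taken from this header (the symbol and helper names below are assumptions), a blob embedded via objcopy or "ld -r -b binary" is typically exposed like this:

    #include <cstddef>
    #include <cstdint>

    // Linker-emitted symbols for the embedded blob; the
    // _binary_<name>_start/_end naming is the common convention.
    extern const uint8_t _binary_constants_bin_start[];
    extern const uint8_t _binary_constants_bin_end[];

    inline size_t embedded_constants_size() {
      return static_cast<size_t>(
          _binary_constants_bin_end - _binary_constants_bin_start);
    }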
@@ -63,146 +48,10 @@ CUDAPtr RAII_cudaMalloc(size_t num_bytes) {

 } // anonymous namespace

-AOTI_NOINLINE static void throw_exception(
-    const char* call,
-    const char* file,
-    int64_t line) {
-  std::stringstream ss;
-  ss << call << " API call failed at " << file << ", line " << line;
-  throw std::runtime_error(ss.str());
-}
-
-#define AOTI_TORCH_ERROR_CODE_CHECK(call) \
-  if ((call) != AOTI_TORCH_SUCCESS) { \
-    throw_exception(#call, __FILE__, __LINE__); \
-  }
-
-using DeleterFnPtr = void (*)(void*);
-
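These error-handling helpers are removed from model.h; they are presumably picked up from the newly included aoti_runtime/utils.h instead. As a reminder of the pattern they implement, a minimal sketch wrapping one of the shim calls used elsewhere in this header (the function name is hypothetical, and the handle is assumed valid):

    // Sketch only: every aoti_torch_* call returns an error code; the macro
    // turns a non-success code into a std::runtime_error that records the
    // failing call, file, and line.
    int64_t checked_dim0_size(AtenTensorHandle handle) {
      int64_t size = 0;
      AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle, 0, &size));
      return size;
    }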
 namespace torch {
 namespace aot_inductor {
-
-inline void noop_deleter(void*) {}
-
-inline void delete_tensor_object(void* ptr) {
-  AOTI_TORCH_ERROR_CODE_CHECK(
-      aoti_torch_delete_tensor_object(reinterpret_cast<AtenTensorHandle>(ptr)));
-}
-
-// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI
-class RAIIAtenTensorHandle {
- public:
-  RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {}
-  RAIIAtenTensorHandle(const RAIIAtenTensorHandle& other) = delete;
-  RAIIAtenTensorHandle& operator=(const RAIIAtenTensorHandle& other) = delete;
-
-  // Steal the ownership from another RAIIAtenTensorHandle using std::move
-  RAIIAtenTensorHandle(RAIIAtenTensorHandle&& other) = default;
-  RAIIAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) = default;
-
-  // Steal the ownership from raw AtenTensorHandle
-  RAIIAtenTensorHandle(AtenTensorHandle handle)
-      : handle_(handle, delete_tensor_object) {}
-
-  ~RAIIAtenTensorHandle() {
-    handle_.reset();
-  }
-
-  // Return a raw AtenTensorHandle to be used by aoti_torch functions
-  // Note: this function does NOT transfer the ownership of the handle
-  operator AtenTensorHandle() const {
-    return handle_.get();
-  }
-
-  AtenTensorHandle release() {
-    return handle_.release();
-  }
-
-  AtenTensorHandle get() const {
-    return handle_.get();
-  }
-
-  void reset() {
-    handle_.reset();
-  }
-
-  int64_t size(int64_t d) {
-    int64_t size;
-    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size));
-    return size;
-  }
-
-  int64_t stride(int64_t d) {
-    int64_t stride;
-    AOTI_TORCH_ERROR_CODE_CHECK(
-        aoti_torch_get_stride(handle_.get(), d, &stride));
-    return stride;
-  }
-
-  int64_t storage_offset() {
-    int64_t storage_offset;
-    AOTI_TORCH_ERROR_CODE_CHECK(
-        aoti_torch_get_storage_offset(handle_.get(), &storage_offset));
-    return storage_offset;
-  }
-
- private:
-  std::unique_ptr<AtenTensorOpaque, DeleterFnPtr> handle_;
-};
-
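RAIIAtenTensorHandle is deleted here rather than rewritten; its definition most likely now comes from aoti_runtime/utils.h via the include added above. For readers new to it, a small usage sketch (the enclosing function is hypothetical):

    // Sketch only: take ownership of a raw handle produced by the C ABI.
    // The tensor is freed via aoti_torch_delete_tensor_object when `tensor`
    // goes out of scope; release() would hand ownership back to the caller.
    void adopt_and_inspect(AtenTensorHandle raw) {
      RAIIAtenTensorHandle tensor(raw);
      int64_t rows = tensor.size(0);        // checked aoti_torch_get_size
      int64_t row_stride = tensor.stride(0); // checked aoti_torch_get_stride
      (void)rows;
      (void)row_stride;
    }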
 using ConstantMap = std::unordered_map<std::string, RAIIAtenTensorHandle>;

-class ConstantHandle {
- public:
-  ConstantHandle() = default;
-
-  explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) {
-    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_));
-  }
-
-  operator AtenTensorHandle() const {
-    return handle_;
-  }
-
-  AtenTensorHandle tensor() const {
-    return handle_;
-  }
-
-  void* data_ptr() const {
-    return data_;
-  }
-
- private:
-  AtenTensorHandle handle_;
-  void* data_ = nullptr;
-};
-
-inline void* get_data_ptr_wrapper(const ConstantHandle& constant) {
-  return constant.data_ptr();
-}
-
-inline const ConstantHandle& unwrap_raii_handle_if_needed(
-    const ConstantHandle& handle) {
-  return handle;
-}
-
-// Shouldn't be called.
-inline AtenTensorHandle wrap_with_raii_handle_if_needed(
-    const ConstantHandle& handle) = delete;
-
-// Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle
-inline std::vector<RAIIAtenTensorHandle> steal_from_raw_handles_to_raii_handles(
-    AtenTensorHandle* handles,
-    size_t size) {
-  std::vector<RAIIAtenTensorHandle> result;
-  result.reserve(size);
-  for (size_t i = 0; i < size; i++) {
-    result.emplace_back(handles[i]);
-    handles[i] = nullptr;
-  }
-  return result;
-}
-
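A short sketch of how the last removed helper is meant to be used (the surrounding function is hypothetical): generated code fills an array of raw output handles, and the caller converts them into owning handles in one step:

    void adopt_outputs(AtenTensorHandle* raw_outputs, size_t n) {
      std::vector<RAIIAtenTensorHandle> outputs =
          steal_from_raw_handles_to_raii_handles(raw_outputs, n);
      // Each raw_outputs[i] is now nullptr; `outputs` owns the tensors and
      // frees them when it is destroyed.
    }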
 // valid device strs are: cpu, cuda, cuda:0, cuda:1, ...
 // Update the list here if more devices are supported in the future
 inline void parse_device_str(
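Only the first line of parse_device_str is visible in this view. Purely to illustrate the accepted formats listed in the comment, and not the implementation from the file, a std::regex check for them could look like:

    // Hypothetical validator for the documented forms: "cpu", "cuda", "cuda:N".
    inline bool is_valid_device_str(const std::string& device_str) {
      static const std::regex device_regex("cpu|cuda(:\\d+)?");
      return std::regex_match(device_str, device_regex);
    }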
@@ -644,24 +493,5 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
   std::unique_ptr<AOTInductorModelKernelsBase> kernels_;
 };

-#ifdef USE_CUDA
-class AOTICudaStreamGuard {
- public:
-  AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) {
-    CUDAStreamGuardHandle ptr;
-    AOTI_TORCH_ERROR_CODE_CHECK(
-        aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr));
-    guard_ =
-        std::unique_ptr<void, std::function<void(void*)>>(ptr, [](void* ptr) {
-          AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard(
-              reinterpret_cast<CUDAStreamGuardHandle>(ptr)));
-        });
-  }
-
- private:
-  std::unique_ptr<void, std::function<void(void*)>> guard_;
-};
-#endif // USE_CUDA
-
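AOTICudaStreamGuard likewise moves out of this header. For context, a minimal sketch of the RAII pattern it implements (the enclosing function is hypothetical):

    #ifdef USE_CUDA
    // Sketch only: install `stream` on `device_index` for this scope; the
    // guard created by aoti_torch_create_cuda_stream_guard is released via
    // aoti_torch_delete_cuda_stream_guard when stream_guard is destroyed.
    void run_with_stream(cudaStream_t stream, int32_t device_index) {
      AOTICudaStreamGuard stream_guard(stream, device_index);
      // ... launch work on `stream` here ...
    }
    #endif // USE_CUDA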
 } // namespace aot_inductor
 } // namespace torch