Skip to content

Commit 40ec155

Browse files
desertfire authored and pytorchmergebot committed
[AOTI][refactor] Split common aoti_runtime utils into a separate header (#119066)
Summary: Split common utils from aoti_runtime/model.h into a separate header file, because when turning on ABI-compatible mode for JIT Inductor we won't need AOTInductorModel, but we do need some common utils, e.g. RAIIAtenTensorHandle. Differential Revision: [D53478809](https://our.internmc.facebook.com/intern/diff/D53478809) Pull Request resolved: #119066 Approved by: https://github.com/khabinov
1 parent 059994d commit 40ec155

File tree

8 files changed

+225
-205
lines changed

8 files changed

+225
-205
lines changed

torch/_inductor/codegen/aoti_runtime/interface.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,4 @@ AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
341341
})
342342
}
343343

344-
#define CACHE_TORCH_DTYPE(typename) static auto cached_torch_dtype_##typename = aoti_torch_dtype_##typename()
345-
346-
static auto cached_torch_device_type_cpu = aoti_torch_device_type_cpu();
347-
static auto cached_torch_device_type_cuda = aoti_torch_device_type_cuda();
348344
} // extern "C"

torch/_inductor/codegen/wrapper.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1421,6 +1421,7 @@ def __init__(self):
14211421
self.declared_int_array_vars = set()
14221422
self.tmp_tensor_id = count() # for tmp tensor local variable declarations
14231423
self.arg_var_id = count()
1424+
self.used_cached_devices = set()
14241425
self.used_cached_dtypes = set()
14251426
self.cached_output_id = count()
14261427
self.scalar_to_tensor_id = count()
@@ -2047,6 +2048,8 @@ def finalize_prefix(self):
20472048
if config.abi_compatible:
20482049
for dtype in self.used_cached_dtypes:
20492050
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});")
2051+
for device in self.used_cached_devices:
2052+
cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});")
20502053
cached_dtypes_buffer.splice(self.prefix)
20512054
self.prefix = cached_dtypes_buffer
20522055

@@ -2521,6 +2524,7 @@ def generate_inf_and_nan_checker(self, nodes):
25212524

25222525
def codegen_device(self, device):
25232526
if config.abi_compatible:
2527+
self.used_cached_devices.add(device.type)
25242528
return f"cached_torch_device_type_{device.type},{device.index if device.index else 0}"
25252529
else:
25262530
from .cpp import DEVICE_TO_ATEN
@@ -3078,7 +3082,11 @@ def write_header(self):
30783082
super().write_header()
30793083

30803084
self.header.splice("#include <filesystem>")
3081-
if not config.abi_compatible:
3085+
if config.abi_compatible:
3086+
self.header.splice(
3087+
"#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"
3088+
)
3089+
else:
30823090
self.header.splice(
30833091
"""
30843092
#include <c10/cuda/CUDAGuard.h>

torch/csrc/inductor/aoti_runtime/arrayref_tensor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

3-
#include <torch/csrc/inductor/aoti_runtime/model.h>
4-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
3+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
54

65
#include <assert.h>
76
#include <cstdint>

torch/csrc/inductor/aoti_runtime/interface.h

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,10 @@
11
#pragma once
22

3-
#include <stddef.h>
4-
#include <stdint.h>
5-
63
// WARNING: Be careful when adding new includes here. This header will be used
74
// in model.so, and should not refer to any aten/c10 headers except the stable
85
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
96
// applies to other files under torch/csrc/inductor/aoti_runtime/.
10-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
11-
12-
#ifdef __GNUC__
13-
#define AOT_INDUCTOR_EXPORT __attribute__((__visibility__("default")))
14-
#else // !__GNUC__
15-
#ifdef _WIN32
16-
#define AOT_INDUCTOR_EXPORT __declspec(dllexport)
17-
#else // !_WIN32
18-
#define AOT_INDUCTOR_EXPORT
19-
#endif // _WIN32
20-
#endif // __GNUC__
21-
22-
using AOTIRuntimeError = int32_t;
23-
#define AOTI_RUNTIME_SUCCESS 0
24-
#define AOTI_RUNTIME_FAILURE 1
25-
26-
#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \
27-
if ((call) != AOTI_RUNTIME_SUCCESS) { \
28-
throw std::runtime_error( \
29-
std::string(#call " API call failed at ") + __FILE__ + ", line " + \
30-
std::to_string(__LINE__)); \
31-
}
7+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
328

339
extern "C" {
3410
struct AOTInductorModelOpaque;

torch/csrc/inductor/aoti_runtime/model.h

Lines changed: 1 addition & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,15 @@
11
#pragma once
22

3-
#include <functional>
4-
#include <iostream>
5-
#include <memory>
63
#include <optional>
74
#include <regex>
8-
#include <sstream>
9-
#include <stdexcept>
10-
#include <string>
115
#include <unordered_map>
12-
#include <vector>
136

147
// WARNING: Be careful when adding new includes here. This header will be used
158
// in model.so, and should not refer to any aten/c10 headers except the stable
169
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
1710
// applies to other files under torch/csrc/inductor/aoti_runtime/.
1811
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
19-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
12+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
2013

2114
#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
2215
do { \
@@ -26,14 +19,6 @@
2619
} \
2720
} while (0)
2821

29-
#if defined(__GNUC__) || defined(__clang__)
30-
#define AOTI_NOINLINE __attribute__((noinline))
31-
#elif _MSC_VER
32-
#define AOTI_NOINLINE __declspec(noinline)
33-
#else
34-
#define AOTI_NOINLINE
35-
#endif
36-
3722
// At codegen time, we write out a binary file called constants.bin.
3823
// We then turn the raw binary to an object file that exposes this
3924
// symbol and link it into the final .so.
@@ -63,146 +48,10 @@ CUDAPtr RAII_cudaMalloc(size_t num_bytes) {
6348

6449
} // anonymous namespace
6550

66-
AOTI_NOINLINE static void throw_exception(
67-
const char* call,
68-
const char* file,
69-
int64_t line) {
70-
std::stringstream ss;
71-
ss << call << " API call failed at " << file << ", line " << line;
72-
throw std::runtime_error(ss.str());
73-
}
74-
75-
#define AOTI_TORCH_ERROR_CODE_CHECK(call) \
76-
if ((call) != AOTI_TORCH_SUCCESS) { \
77-
throw_exception(#call, __FILE__, __LINE__); \
78-
}
79-
80-
using DeleterFnPtr = void (*)(void*);
81-
8251
namespace torch {
8352
namespace aot_inductor {
84-
85-
inline void noop_deleter(void*) {}
86-
87-
inline void delete_tensor_object(void* ptr) {
88-
AOTI_TORCH_ERROR_CODE_CHECK(
89-
aoti_torch_delete_tensor_object(reinterpret_cast<AtenTensorHandle>(ptr)));
90-
}
91-
92-
// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI
93-
class RAIIAtenTensorHandle {
94-
public:
95-
RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {}
96-
RAIIAtenTensorHandle(const RAIIAtenTensorHandle& other) = delete;
97-
RAIIAtenTensorHandle& operator=(const RAIIAtenTensorHandle& other) = delete;
98-
99-
// Steal the ownership from another RAIIAtenTensorHandle using std::move
100-
RAIIAtenTensorHandle(RAIIAtenTensorHandle&& other) = default;
101-
RAIIAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) = default;
102-
103-
// Steal the ownership from raw AtenTensorHandle
104-
RAIIAtenTensorHandle(AtenTensorHandle handle)
105-
: handle_(handle, delete_tensor_object) {}
106-
107-
~RAIIAtenTensorHandle() {
108-
handle_.reset();
109-
}
110-
111-
// Return a raw AtenTensorHandle to be used by aoti_torch functions
112-
// Note: this function does NOT transfer the ownership of the handle
113-
operator AtenTensorHandle() const {
114-
return handle_.get();
115-
}
116-
117-
AtenTensorHandle release() {
118-
return handle_.release();
119-
}
120-
121-
AtenTensorHandle get() const {
122-
return handle_.get();
123-
}
124-
125-
void reset() {
126-
handle_.reset();
127-
}
128-
129-
int64_t size(int64_t d) {
130-
int64_t size;
131-
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size));
132-
return size;
133-
}
134-
135-
int64_t stride(int64_t d) {
136-
int64_t stride;
137-
AOTI_TORCH_ERROR_CODE_CHECK(
138-
aoti_torch_get_stride(handle_.get(), d, &stride));
139-
return stride;
140-
}
141-
142-
int64_t storage_offset() {
143-
int64_t storage_offset;
144-
AOTI_TORCH_ERROR_CODE_CHECK(
145-
aoti_torch_get_storage_offset(handle_.get(), &storage_offset));
146-
return storage_offset;
147-
}
148-
149-
private:
150-
std::unique_ptr<AtenTensorOpaque, DeleterFnPtr> handle_;
151-
};
152-
15353
using ConstantMap = std::unordered_map<std::string, RAIIAtenTensorHandle>;
15454

155-
class ConstantHandle {
156-
public:
157-
ConstantHandle() = default;
158-
159-
explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) {
160-
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_));
161-
}
162-
163-
operator AtenTensorHandle() const {
164-
return handle_;
165-
}
166-
167-
AtenTensorHandle tensor() const {
168-
return handle_;
169-
}
170-
171-
void* data_ptr() const {
172-
return data_;
173-
}
174-
175-
private:
176-
AtenTensorHandle handle_;
177-
void* data_ = nullptr;
178-
};
179-
180-
inline void* get_data_ptr_wrapper(const ConstantHandle& constant) {
181-
return constant.data_ptr();
182-
}
183-
184-
inline const ConstantHandle& unwrap_raii_handle_if_needed(
185-
const ConstantHandle& handle) {
186-
return handle;
187-
}
188-
189-
// Shouldn't be called.
190-
inline AtenTensorHandle wrap_with_raii_handle_if_needed(
191-
const ConstantHandle& handle) = delete;
192-
193-
// Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle
194-
inline std::vector<RAIIAtenTensorHandle> steal_from_raw_handles_to_raii_handles(
195-
AtenTensorHandle* handles,
196-
size_t size) {
197-
std::vector<RAIIAtenTensorHandle> result;
198-
result.reserve(size);
199-
for (size_t i = 0; i < size; i++) {
200-
result.emplace_back(handles[i]);
201-
handles[i] = nullptr;
202-
}
203-
return result;
204-
}
205-
20655
// valid device strs are: cpu, cuda, cuda:0, cuda:1, ...
20756
// Update the list here if more devices are supported in the future
20857
inline void parse_device_str(
@@ -644,24 +493,5 @@ class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
644493
std::unique_ptr<AOTInductorModelKernelsBase> kernels_;
645494
};
646495

647-
#ifdef USE_CUDA
648-
class AOTICudaStreamGuard {
649-
public:
650-
AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) {
651-
CUDAStreamGuardHandle ptr;
652-
AOTI_TORCH_ERROR_CODE_CHECK(
653-
aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr));
654-
guard_ =
655-
std::unique_ptr<void, std::function<void(void*)>>(ptr, [](void* ptr) {
656-
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard(
657-
reinterpret_cast<CUDAStreamGuardHandle>(ptr)));
658-
});
659-
}
660-
661-
private:
662-
std::unique_ptr<void, std::function<void(void*)>> guard_;
663-
};
664-
#endif // USE_CUDA
665-
666496
} // namespace aot_inductor
667497
} // namespace torch

torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

3-
#include <torch/csrc/inductor/aoti_runtime/model.h>
4-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
3+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
54

65
namespace torch {
76
namespace aot_inductor {

0 commit comments

Comments (0)