Skip to content

Commit ad8fae2

Browse files
desertfirepytorchmergebot
authored andcommitted
[AOTI] Support test_open_device_registration in ABI-compatible (#136906)
Summary: Add a device type C shim interface to support test_open_device_registration in the ABI-compatible mode. Pull Request resolved: #136906 Approved by: https://github.com/chenyang78
1 parent 8dddd45 commit ad8fae2

File tree

4 files changed

+16
-9
lines changed

4 files changed

+16
-9
lines changed

.ci/pytorch/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ test_inductor_cpp_wrapper_abi_compatible() {
376376

377377
echo "Testing Inductor cpp wrapper mode with TORCHINDUCTOR_ABI_COMPATIBLE=1"
378378
PYTORCH_TESTING_DEVICE_ONLY_FOR="" python test/run_test.py --include inductor/test_cpu_cpp_wrapper
379-
python test/run_test.py --include inductor/test_cuda_cpp_wrapper inductor/test_cpu_repro
379+
python test/run_test.py --include inductor/test_cuda_cpp_wrapper inductor/test_cpu_repro inductor/test_extension_backend
380380

381381
TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \
382382
--training --inductor --disable-cudagraphs --only vit_base_patch16_224 \

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,8 +1482,12 @@ def generate_inf_and_nan_checker(self, nodes):
14821482

14831483
def codegen_device(self, device):
14841484
if config.abi_compatible:
1485-
self.used_cached_devices.add(device.type)
1486-
return f"cached_torch_device_type_{device.type}, {device.index if device.index else 0}"
1485+
assert device.type in DEVICE_TO_ATEN, (
1486+
device.type + " not found in DEVICE_TO_ATEN"
1487+
)
1488+
device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k"
1489+
self.used_cached_devices.add(device_str)
1490+
return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}"
14871491
else:
14881492
return (
14891493
f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})"

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ using AOTITorchError = int32_t;
9696
// desired for perf reasons.)
9797
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
9898
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
99+
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1();
99100

100101
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
101102
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,15 @@ static c10::Device c10_device(int32_t device_type, int32_t device_index) {
104104

105105
const int AOTI_TORCH_MAX_NUMEL_TO_PRINT = 64;
106106

107-
int32_t aoti_torch_device_type_cpu() {
108-
return (int32_t)c10::DeviceType::CPU;
109-
}
107+
#define AOTI_TORCH_DEVICE_TYPE_IMPL(device_str, device_type) \
108+
int32_t aoti_torch_device_type_##device_str() { \
109+
return (int32_t)c10::DeviceType::device_type; \
110+
}
110111

111-
int32_t aoti_torch_device_type_cuda() {
112-
return (int32_t)c10::DeviceType::CUDA;
113-
}
112+
AOTI_TORCH_DEVICE_TYPE_IMPL(cpu, CPU)
113+
AOTI_TORCH_DEVICE_TYPE_IMPL(cuda, CUDA)
114+
AOTI_TORCH_DEVICE_TYPE_IMPL(privateuse1, PrivateUse1)
115+
#undef AOTI_TORCH_DEVICE_TYPE_IMPL
114116

115117
#define AOTI_TORCH_DTYPE_IMPL(dtype, stype) \
116118
int32_t aoti_torch_dtype_##dtype() { \

0 commit comments

Comments
 (0)