
Commit aa988ca

Update on "[AOTI] Refine the C shim autogen mechanism"
Summary: Based on the discussions in #120513, instead of auto-generating C shim fallback ops for thousands of ops, we maintain a list of fallback ops based on torch/_inductor/lowering.py and only generate C shim functions for those ops. At torchgen time, we re-generate the C shim files and compare the header file contents against the existing C shim headers. If there is any change, the compilation fails with a prompt on how to proceed. This makes sure the ABI-compatible C shim layer stays small enough to maintain in the long run.

[ghstack-poisoned]
1 parent 49b81b2 commit aa988ca
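
The guard described in the summary amounts to regenerating the shim header in memory and diffing it against the checked-in copy, only overwriting it when the update flag is passed. Below is a minimal Python sketch of that idea using a hypothetical helper; the real logic lives in torchgen/gen.py, shown in the diff further down.

    # Minimal sketch of the compare-or-update guard (hypothetical helper,
    # not the actual torchgen code; see the torchgen/gen.py diff below).
    import os


    def check_or_update_c_shim(
        new_header: str,        # freshly generated header contents
        checked_in_dir: str,    # e.g. "torch/csrc/inductor/aoti_torch/generated"
        header_file_name: str,  # e.g. "c_shim_cpu.h"
        update: bool,           # corresponds to the --update-aoti-c-shim flag
    ) -> None:
        path = os.path.join(checked_in_dir, header_file_name)
        if update:
            # Explicit opt-in: overwrite the checked-in header with the new one.
            with open(path, "w") as f:
                f.write(new_header)
            return
        # Default path: fail if the generated header drifted from the
        # checked-in copy, and tell the developer how to proceed.
        with open(path) as f:
            old_header = f.read()
        assert old_header == new_header, (
            "Generated C shim header differs from the checked-in copy. "
            "If the change to the fallback op list is intentional, rerun "
            "torchgen with --update-aoti-c-shim and commit the result."
        )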

File tree

3 files changed (+33, −52 lines)


build_variables.bzl

Lines changed: 2 additions & 0 deletions

@@ -468,6 +468,7 @@ lazy_tensor_core_python_sources = [
 inductor_core_resources = [
     "torch/csrc/inductor/aoti_runner/model_container_runner.cpp",
     "torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp",
+    "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp",
     "torch/csrc/inductor/aoti_torch/shim_common.cpp",
     "torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
     "torch/csrc/inductor/inductor_ops.cpp",
@@ -656,6 +657,7 @@ libtorch_cuda_core_sources = [
     "torch/csrc/cuda/comm.cpp",
     "torch/csrc/cuda/memory_snapshot.cpp",
     "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp",
+    "torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp",
     "torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
     "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
     "torch/csrc/profiler/stubs/cuda.cpp",

caffe2/CMakeLists.txt

Lines changed: 2 additions & 8 deletions

@@ -352,7 +352,6 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
     "${TORCH_SRC_DIR}/csrc/autograd/generated/TraceType_4.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_0.cpp"
     "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_1.cpp"
-    "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp"
   )
   if(BUILD_LAZY_TS_BACKEND)
     list(APPEND GENERATED_CXX_TORCH
@@ -407,17 +406,12 @@ set(GENERATED_TESTING_PYTHON
   "${TORCH_SRC_DIR}/testing/_internal/generated/annotated_fn_args.py"
 )
 
-set(GENERATED_CXX_TORCH_CUDA
-  "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp"
-)
-
 set(TORCH_GENERATED_CODE
   ${GENERATED_CXX_TORCH}
   ${GENERATED_H_TORCH}
   ${GENERATED_CXX_PYTHON}
   ${GENERATED_H_PYTHON}
   ${GENERATED_TESTING_PYTHON}
-  ${GENERATED_CXX_TORCH_CUDA}
 )
 
 set(GEN_PER_OPERATOR_FLAG)
@@ -966,7 +960,7 @@ endif()
 # Compile exposed libraries.
 if(USE_ROCM)
   set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
-  list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA})
+  list(APPEND Caffe2_HIP_SRCS)
   hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
   if(USE_FLASH_ATTENTION)
     target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
@@ -986,7 +980,7 @@ if(USE_ROCM)
   endif()
 elseif(USE_CUDA)
   set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
-  list(APPEND Caffe2_GPU_SRCS ${GENERATED_CXX_TORCH_CUDA})
+  list(APPEND Caffe2_GPU_SRCS)
   if(CUDA_SEPARABLE_COMPILATION)
     # Separate compilation fails when kernels using `thrust::sort_by_key`
     # are linked with the rest of CUDA code. Workaround by linking them separately.

torchgen/gen.py

Lines changed: 29 additions & 44 deletions

@@ -3,7 +3,6 @@
 import json
 import os
 import pathlib
-import shutil
 
 from collections import defaultdict, namedtuple, OrderedDict
 from dataclasses import dataclass, field
@@ -2401,43 +2400,35 @@ def headers_for_aoti() -> str:
         existing_c_shim_path = "torch/csrc/inductor/aoti_torch/generated"
         header_file_name = f"c_shim_{dispatch_key.lower()}.h"
         cpp_file_name = f"c_shim_{dispatch_key.lower()}.cpp"
-        aoti_fm.write(
-            header_file_name,
-            lambda: gen_aoti_c_shim(
-                fallback_native_functions,
-                dispatch_key,
-                backend_indices,
-                header=True,
-                includes="",
-            ),
+        new_header = gen_aoti_c_shim(
+            fallback_native_functions,
+            dispatch_key,
+            backend_indices,
+            header=True,
+            includes="",
         )
-        aoti_fm.write(
-            cpp_file_name,
-            lambda: gen_aoti_c_shim(
-                fallback_native_functions,
-                dispatch_key,
-                backend_indices,
-                header=False,
-                includes=headers_for_aoti() + "\n" + extra_headers,
-            ),
+        new_cpp = gen_aoti_c_shim(
+            fallback_native_functions,
+            dispatch_key,
+            backend_indices,
+            header=False,
+            includes=headers_for_aoti() + "\n" + extra_headers,
         )
+
         if update_aoti_c_shim:
-            shutil.copy2(
-                os.path.join(aoti_fm.install_dir, header_file_name),
-                os.path.join(existing_c_shim_path, header_file_name),
+            aoti_fm.write(
+                header_file_name,
+                lambda: new_header,
             )
-            shutil.copy2(
-                os.path.join(aoti_fm.install_dir, cpp_file_name),
-                os.path.join(existing_c_shim_path, cpp_file_name),
+            aoti_fm.write(
+                cpp_file_name,
+                lambda: new_cpp,
             )
         else:
             with open(
                 os.path.join(existing_c_shim_path, header_file_name)
-            ) as old_file, open(
-                os.path.join(aoti_fm.install_dir, header_file_name)
-            ) as new_file:
+            ) as old_file:
                 old_header = old_file.read()
-                new_header = new_file.read()
                 assert (
                     old_header == new_header
                 ), """
@@ -2765,18 +2756,6 @@ def main() -> None:
         help="output directory",
         default="build/aten/src/ATen",
     )
-    parser.add_argument(
-        "--aoti-install-dir",
-        "--aoti_install_dir",
-        help="output directory for AOTInductor shim",
-        default="build/aoti/generated",
-    )
-    parser.add_argument(
-        "--update-aoti-c-shim",
-        action="store_true",
-        help="Update AOTInductor C shim after changing torchgen/aoti/fallback_ops.py. "
-        "WARNING: Do not use this unless you are sure what you are doing!!!",
-    )
     parser.add_argument(
         "--rocm",
         action="store_true",
@@ -2841,6 +2820,12 @@ def main() -> None:
         default=["headers", "sources", "declarations_yaml"],
         help="Generate only a subset of files",
     )
+    parser.add_argument(
+        "--update-aoti-c-shim",
+        action="store_true",
+        help="Update AOTInductor C shim after changing torchgen/aoti/fallback_ops.py. "
+        "WARNING: Do not use this unless you are sure what you are doing!!!",
+    )
 
     options = parser.parse_args()
 
@@ -2897,15 +2882,15 @@ def main() -> None:
     pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
     ops_install_dir = f"{options.install_dir}/ops"
     pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
-    aoti_install_dir = f"{options.aoti_install_dir}"
-    pathlib.Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
 
     core_fm = make_file_manager(options=options, install_dir=core_install_dir)
     cpu_fm = make_file_manager(options=options)
     cpu_vec_fm = make_file_manager(options=options)
     cuda_fm = make_file_manager(options=options)
     ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
-    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
+    aoti_fm = make_file_manager(
+        options=options, install_dir="torch/csrc/inductor/aoti_torch/generated"
+    )
 
     # Only a limited set of dispatch keys get CPUFunctions.h headers generated
     # for them; this is the set
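
For the other half of the mechanism described in the summary, only a curated list of fallback ops (kept in torchgen/aoti/fallback_ops.py, per the new --update-aoti-c-shim help text) receives generated shim functions. A rough sketch of that filtering step follows; the names (select_fallback_functions, op_name, fallback_op_names) are hypothetical, since that code is not part of this diff.

    # Hypothetical sketch (illustrative names only): keep just the native
    # functions whose operator name appears in the curated fallback-op list,
    # and generate C shim functions for that subset instead of for every op.
    from typing import Callable, Iterable, List, Set


    def select_fallback_functions(
        native_functions: Iterable[object],  # stand-in for torchgen NativeFunction objects
        op_name: Callable[[object], str],    # stand-in for extracting an op's qualified name
        fallback_op_names: Set[str],         # curated list, kept in sync with
                                             # torch/_inductor/lowering.py per the summary
    ) -> List[object]:
        return [f for f in native_functions if op_name(f) in fallback_op_names]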
