Commit 6ddf5cf

desertfire authored and pytorchmergebot committed
[AOTI] Update cpp wrapper codegen to use v2 C shim (#120714)
Summary: To use the torchgen-ed v2 C shim interface, cpp wrapper codegen needs to update its rules for generating the right parameters and function calls. Because changing the emitted code would cause an FC breakage, we add a flag to control the behavior.

Differential Revision: [D54258086](https://our.internmc.facebook.com/intern/diff/D54258086)

Pull Request resolved: #120714
Approved by: https://github.com/chenyang78
ghstack dependencies: #120513
1 parent bd19d6d commit 6ddf5cf
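In effect, the new flag switches the wrapper codegen between two shim naming schemes. A minimal sketch of the resulting rule (not the actual Inductor code; the kernel string, op name, and devices are placeholder examples):

def shim_fn_for(kernel: str, device: str, c_shim_version: str) -> str:
    # mirror of generate_c_shim_extern_kernel_call: take the op suffix from the
    # kernel name, skipping a trailing "call" token
    tokens = kernel.split("::")
    suffix = tokens[-1]
    if suffix == "call":
        suffix = tokens[-2]
    if c_shim_version == "1":
        return f"aoti_torch_{suffix}"        # v1: hand-written shim.h entry points
    return f"aoti_torch_{device}_{suffix}"   # v2: torchgen-ed, per-device entry points

print(shim_fn_for("at::_ops::foo::call", "cpu", "1"))   # aoti_torch_foo
print(shim_fn_for("at::_ops::foo::call", "cuda", "2"))  # aoti_torch_cuda_foo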

File tree (4 files changed: 45 additions, 8 deletions)

torch/_inductor/codegen/cpp_wrapper_cpu.py
torch/_inductor/codegen/cpp_wrapper_cuda.py
torch/_inductor/config.py
torch/_inductor/ir.py


torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 28 additions & 6 deletions
@@ -24,8 +24,9 @@ class CppWrapperCpu(WrapperCodeGen):
     """
 
     def __init__(self):
+        if not hasattr(self, "device"):
+            self.device = "cpu"
         super().__init__()
-
         self.declare = "auto "
         self.declare_maybe_reference = "decltype(auto) "
         self.ending = ";"
@@ -149,7 +150,12 @@ def write_header(self):
         )
 
         if config.abi_compatible:
-            self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
+            if config.c_shim_version == "1":
+                self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
+            else:
+                self.header.splice(
+                    f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>"
+                )
         else:
             if not V.graph.aot_mode:
                 self.header.splice("#include <pybind11/pybind11.h>")
@@ -924,7 +930,11 @@ def generate_c_shim_extern_kernel_call(self, kernel, args):
         kernel_suffix = kernel_tokens[-1]
         if kernel_suffix == "call":
             kernel_suffix = kernel_tokens[-2]
-        shim_fn = f"aoti_torch_{kernel_suffix}"
+        if config.c_shim_version == "1":
+            shim_fn = f"aoti_torch_{kernel_suffix}"
+        else:
+            shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}"
+
         # HACK: val_to_arg_str jams multiple arguments together using a comma. If that
         # ever breaks, it needs to be reworked to be able to return multiple arguments,
         # and the split-on-comma code here needs to be removed.
@@ -1676,12 +1686,24 @@ def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
         ):
             if val is None:
                 return "0"  # nullptr is not available in C
-            if isinstance(val, (bool, int, str, float)):
+            if not isinstance(type_.getElementType(), torch.TensorType):
                 var_name = f"var_{next(self.arg_var_id)}"
                 self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
                 return f"&{var_name}"
-            if not isinstance(type_.getElementType(), torch.TensorType):
-                return f"&{self.val_to_arg_str(val)}"
+            elif config.c_shim_version == "2":
+                # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
+                base_handle = self.val_to_arg_str(val)
+                if "wrap_with_raii_handle_if_needed" in base_handle:
+                    # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
+                    # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
+                    tmp_var_name = f"var_{next(self.arg_var_id)}"
+                    self.writeline(
+                        f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};"
+                    )
+                    base_handle = tmp_var_name
+                var_name = f"var_{next(self.arg_var_id)}"
+                self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();")
+                return f"&{var_name}"
 
         return self.val_to_arg_str(val)

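The val_to_cpp_arg_str change above is the subtlest part: under the v2 shim, an optional tensor argument is passed as a pointer to an AtenTensorHandle, so the wrapper must materialize a named handle first. A small self-contained sketch of that branch (the counter-based variable naming and the "buf0" input are illustrative, not the real class state):

from itertools import count

_arg_var_id = count()

def _new_var() -> str:
    return f"var_{next(_arg_var_id)}"

def optional_tensor_arg(base_handle: str) -> tuple[list[str], str]:
    """Return (lines to emit, argument string) for an optional tensor under the v2 shim."""
    lines = []
    if "wrap_with_raii_handle_if_needed" in base_handle:
        # the temporary RAIIAtenTensorHandle must be stored in a named variable,
        # otherwise it is destroyed before the fallback kernel call
        tmp = _new_var()
        lines.append(f"RAIIAtenTensorHandle {tmp} = {base_handle};")
        base_handle = tmp
    var = _new_var()
    lines.append(f"AtenTensorHandle {var} = {base_handle}.get();")
    return lines, f"&{var}"

lines, arg = optional_tensor_arg("wrap_with_raii_handle_if_needed(buf0)")
# lines -> ["RAIIAtenTensorHandle var_0 = wrap_with_raii_handle_if_needed(buf0);",
#           "AtenTensorHandle var_1 = var_0.get();"]
# arg   -> "&var_1"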
torch/_inductor/codegen/cpp_wrapper_cuda.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class CppWrapperCuda(CppWrapperCpu):
     """
 
     def __init__(self):
+        self.device = "cuda"
         super().__init__()
         self.grid_id = count()
         self.cuda = True

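This one-line change works together with the hasattr guard added to CppWrapperCpu.__init__: the CUDA wrapper sets self.device before calling the base constructor, and the base constructor only defaults to "cpu" when no device was set. Reduced to the relevant lines (a sketch, not the full classes):

class CppWrapperCpu:
    def __init__(self):
        if not hasattr(self, "device"):
            self.device = "cpu"  # default when the CPU wrapper is used directly

class CppWrapperCuda(CppWrapperCpu):
    def __init__(self):
        self.device = "cuda"     # set before the base __init__ runs
        super().__init__()

assert CppWrapperCpu().device == "cpu"
assert CppWrapperCuda().device == "cuda"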
torch/_inductor/config.py

Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,10 @@ def enable_autotune_remote_cache():
     os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
 )
 
+c_shim_version = os.environ.get(
+    "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2"
+)
+
 # dead code elimination
 dce = False

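The default keeps fbcode on v1 (avoiding the FC breakage mentioned in the summary) while OSS builds get v2. The knob can be flipped explicitly, for example (a usage sketch; the environment variable must be set before torch._inductor.config is imported, or the module attribute can be assigned afterwards):

import os
os.environ["TORCHINDUCTOR_C_SHIM_VERSION"] = "1"  # force the v1 C shim

import torch._inductor.config as inductor_config
assert inductor_config.c_shim_version == "1"

# alternatively, override the module-level setting after import
inductor_config.c_shim_version = "2"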
torch/_inductor/ir.py

Lines changed: 12 additions & 2 deletions
@@ -4855,7 +4855,10 @@ def is_not_write(arg):
         self.init_args_default_value(kernel._schema)
 
     def is_legacy_abi_kernel(self):
-        return "_scaled_dot_product_flash_attention" in str(self.python_kernel_name)
+        return (
+            config.c_shim_version == "1"
+            and "_scaled_dot_product_flash_attention" in str(self.python_kernel_name)
+        )
 
     def init_args_default_value(self, schema):
         self.args_default_value = [
@@ -4908,6 +4911,7 @@ def __repr__(self):
         self.abi_compatible_kernel = (
             f"{self.cpp_kernel_name}_v2"
             if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"}
+            and config.c_shim_version == "1"
             else self.cpp_kernel_name
         )
 
@@ -5065,7 +5069,13 @@ def codegen(self, wrapper):
         # Aten Fallback Ops
         assert isinstance(kernel, torch._ops.OpOverload)
         if V.graph.cpp_wrapper:
-            if config.is_fbcode() and kernel not in has_c_shim:
+            if (
+                config.is_fbcode()
+                and kernel not in has_c_shim
+                # C shim v2 is torchgen-ed, which should cover all aten ops.
+                # If you do hit a missed op, please update gen_aoti_c_shim.py.
+                and config.c_shim_version == "1"
+            ):
                 log.warning(
                     "%s is missing a c-shim implementation, using proxy executor as fallback",
                     kernel,

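Taken together, the ir.py edits make the v1-only special cases conditional on the new flag: the _scaled_dot_product_flash_attention legacy-ABI path, the _v2 kernel-name rewrite, and the fbcode proxy-executor fallback for ops without a C shim apply only when c_shim_version == "1", since the v2 shim is torchgen-ed and is expected to cover all aten ops. A condensed sketch of that gating (free functions standing in for the real methods):

def is_legacy_abi_kernel(python_kernel_name: str, c_shim_version: str) -> bool:
    # the legacy-ABI path for scaled_dot_product_flash_attention is only
    # relevant for the v1 shim
    return (
        c_shim_version == "1"
        and "_scaled_dot_product_flash_attention" in python_kernel_name
    )

def needs_proxy_executor_fallback(is_fbcode: bool, has_shim: bool, c_shim_version: str) -> bool:
    # with the torchgen-ed v2 shim, a missing entry is a bug in gen_aoti_c_shim.py
    # rather than a reason to fall back
    return is_fbcode and not has_shim and c_shim_version == "1"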