Commit 29b51ee

[AOTI] Update cpp wrapper codegen to use v2 C shim
Summary: To use the torchgen-ed v2 C shim interface, the cpp wrapper codegen needs updated rules for generating the right parameters and function calls. Because changing the emitted code would cause a forward-compatibility (FC) breakage, we add a config flag to control the behavior.

ghstack-source-id: 0bcec4e
Pull Request resolved: #120714
1 parent 10c096b commit 29b51ee
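
The core of the change is the naming rule used when emitting extern C shim calls: v1 keeps the existing device-agnostic aoti_torch_* names, while v2 selects the torchgen-ed, per-device aoti_torch_{device}_* names (see the cpp_wrapper_cpu.py hunks below). A minimal sketch of that rule, not the actual codegen helper, with an illustrative kernel suffix:

def shim_fn_name(kernel_suffix: str, device: str, c_shim_version: str) -> str:
    # v1: device-agnostic, hand-written shim functions
    if c_shim_version == "1":
        return f"aoti_torch_{kernel_suffix}"
    # v2: torchgen-ed, per-device shim functions
    return f"aoti_torch_{device}_{kernel_suffix}"

# illustrative suffix only; real suffixes come from the fallback kernel name
assert shim_fn_name("convolution", "cpu", "1") == "aoti_torch_convolution"
assert shim_fn_name("convolution", "cuda", "2") == "aoti_torch_cuda_convolution"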

File tree

4 files changed: +38 -8 lines changed

torch/_inductor/codegen/cpp_wrapper_cpu.py
Lines changed: 21 additions & 6 deletions

@@ -24,8 +24,9 @@ class CppWrapperCpu(WrapperCodeGen):
     """
 
     def __init__(self):
+        if not hasattr(self, "device"):
+            self.device = "cpu"
         super().__init__()
-
         self.declare = "auto "
         self.declare_maybe_reference = "decltype(auto) "
         self.ending = ";"

@@ -148,7 +149,12 @@ def write_header(self):
         )
 
         if config.abi_compatible:
-            self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
+            if config.c_shim_version == "1":
+                self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
+            else:
+                self.header.splice(
+                    f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>"
+                )
         else:
             if not V.graph.aot_mode:
                 self.header.splice("#include <pybind11/pybind11.h>")

@@ -915,7 +921,11 @@ def generate_c_shim_extern_kernel_call(self, kernel, args):
         kernel_suffix = kernel_tokens[-1]
         if kernel_suffix == "call":
             kernel_suffix = kernel_tokens[-2]
-        shim_fn = f"aoti_torch_{kernel_suffix}"
+        if config.c_shim_version == "1":
+            shim_fn = f"aoti_torch_{kernel_suffix}"
+        else:
+            shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}"
+
         # HACK: val_to_arg_str jams multiple arguments together using a comma. If that
         # ever breaks, it needs to be reworked to be able to return multiple arguments,
         # and the split-on-comma code here needs to be removed.

@@ -1664,12 +1674,17 @@ def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
         ):
             if val is None:
                 return "0"  # nullptr is not available in C
-            if isinstance(val, (bool, int, str, float)):
+            if not isinstance(type_.getElementType(), torch.TensorType):
                 var_name = f"var_{next(self.arg_var_id)}"
                 self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
                 return f"&{var_name}"
-            if not isinstance(type_.getElementType(), torch.TensorType):
-                return f"&{self.val_to_arg_str(val)}"
+            elif config.c_shim_version == "2":
+                # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
+                var_name = f"var_{next(self.arg_var_id)}"
+                self.writeline(
+                    f"AtenTensorHandle {var_name} = {self.val_to_arg_str(val)}.get();"
+                )
+                return f"&{var_name}"
 
         return self.val_to_arg_str(val)

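For optional arguments, the v2 shim passes everything by pointer, including optional tensors. A rough sketch of the lowering decision in the last hunk above, with an assumed helper name and simplified signature rather than the real val_to_cpp_arg_str:

def lower_optional_arg(val_expr: str, is_tensor: bool, c_shim_version: str, var_id: int):
    """Return (declaration lines, argument expression) for one optional arg."""
    var_name = f"var_{var_id}"
    if not is_tensor:
        # non-tensor optionals: materialize a local and pass its address
        return [f"auto {var_name} = {val_expr};"], f"&{var_name}"
    if c_shim_version == "2":
        # v2: an optional tensor is likewise passed as a pointer to an AtenTensorHandle
        return [f"AtenTensorHandle {var_name} = {val_expr}.get();"], f"&{var_name}"
    # v1: optional tensors fall through to the generic handling elsewhere in the codegen
    return [], val_expr

lines, arg = lower_optional_arg("buf2", is_tensor=True, c_shim_version="2", var_id=0)
# lines == ["AtenTensorHandle var_0 = buf2.get();"], arg == "&var_0"
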
torch/_inductor/codegen/cpp_wrapper_cuda.py
Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ class CppWrapperCuda(CppWrapperCpu):
     """
 
     def __init__(self):
+        self.device = "cuda"
         super().__init__()
        self.grid_id = count()
        self.cuda = True

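The two __init__ changes above form a simple device-defaulting pattern: the CUDA wrapper picks its device before delegating to the base constructor, and the base constructor only falls back to "cpu" when no device was chosen. A minimal, self-contained sketch of the pattern, with illustrative class names rather than the real wrapper classes:

class BaseWrapper:
    def __init__(self):
        # only default to "cpu" when a subclass has not already chosen a device
        if not hasattr(self, "device"):
            self.device = "cpu"

class CudaWrapper(BaseWrapper):
    def __init__(self):
        self.device = "cuda"  # set before super().__init__() so the default is skipped
        super().__init__()

assert BaseWrapper().device == "cpu"
assert CudaWrapper().device == "cuda"
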
torch/_inductor/config.py
Lines changed: 4 additions & 0 deletions

@@ -32,6 +32,10 @@ def is_fbcode():
     os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
 )
 
+c_shim_version = os.environ.get(
+    "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2"
+)
+
 # dead code elimination
 dce = False

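The new config.c_shim_version knob defaults to the v1 shim inside fbcode and to the v2 shim elsewhere, and can be overridden through the TORCHINDUCTOR_C_SHIM_VERSION environment variable. One way to pin the old behavior, assuming the variable is set before torch._inductor.config is first imported (the default is read at import time):

import os

# keep emitting v1 C shim calls (e.g. to avoid the FC breakage mentioned in the summary)
os.environ["TORCHINDUCTOR_C_SHIM_VERSION"] = "1"

from torch._inductor import config
assert config.c_shim_version == "1"
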
torch/_inductor/ir.py
Lines changed: 12 additions & 2 deletions

@@ -4812,7 +4812,10 @@ def is_not_write(arg):
         self.init_args_default_value(kernel._schema)
 
     def is_legacy_abi_kernel(self):
-        return "_scaled_dot_product_flash_attention" in str(self.python_kernel_name)
+        return (
+            config.c_shim_version == "1"
+            and "_scaled_dot_product_flash_attention" in str(self.python_kernel_name)
+        )
 
     def init_args_default_value(self, schema):
         self.args_default_value = [

@@ -4865,6 +4868,7 @@ def __repr__(self):
         self.abi_compatible_kernel = (
             f"{self.cpp_kernel_name}_v2"
             if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"}
+            and config.c_shim_version == "1"
             else self.cpp_kernel_name
         )
 

@@ -5022,7 +5026,13 @@ def codegen(self, wrapper):
             # Aten Fallback Ops
             assert isinstance(kernel, torch._ops.OpOverload)
             if V.graph.cpp_wrapper:
-                if config.is_fbcode() and kernel not in has_c_shim:
+                if (
+                    config.is_fbcode()
+                    and kernel not in has_c_shim
+                    # C shim v2 is torchgen-ed, which should cover all aten ops.
+                    # If you do hit a missed op, please update gen_aoti_c_shim.py.
+                    and config.c_shim_version == "1"
+                ):
                     log.warning(
                         "%s is missing a c-shim implementation, using proxy executor as fallback",
                         kernel,