Skip to content

Commit 1a33bf8

Browse files
committed
[reland][inductor] Add an AOT compilation mode for Inductor CPP backend
Summary: This is a reland of #94822. ghstack-source-id: 5b36e6a. Pull Request resolved: #95985.
1 parent 4026c62 commit 1a33bf8

File tree

14 files changed

+314
-51
lines changed

14 files changed

+314
-51
lines changed

benchmarks/dynamo/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ class CI(NamedTuple):
175175
# TIMM
176176
"cait_m36_384", # Accuracy
177177
"pnasnet5large", # OOM
178+
"xcit_large_24_p8_224", # OOM https://github.com/pytorch/pytorch/issues/95984
178179
]
179180

180181
CI_SKIP[CI("inductor", training=True)] = [
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Build script for the AOT-Inductor C++ smoke test.
#
# Running test.py produces aot_inductor_output.h and aot_inductor_output.so in
# the build directory; the `test` executable includes the header and links
# against the shared library.

# The CXX_STANDARD value 17 used below is only accepted by CMake 3.8 and newer.
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(test)

# Anchor the Torch package directory to this file's location instead of
# relying on how a bare relative path happens to be resolved.
set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/share/cmake/Torch")
find_package(Torch REQUIRED)

set(aot_output_header "${CMAKE_CURRENT_BINARY_DIR}/aot_inductor_output.h")
set(aot_output_library "${CMAKE_CURRENT_BINARY_DIR}/aot_inductor_output.so")

# test.py emits both files into its working directory, so both are declared as
# outputs and the working directory is pinned to the build directory.
add_custom_command(
  OUTPUT "${aot_output_header}" "${aot_output_library}"
  COMMAND python "${CMAKE_CURRENT_SOURCE_DIR}/test.py"
  DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/test.py"
  WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
  COMMENT "Generating AOT-Inductor header and shared library"
  VERBATIM
)
add_custom_target(generate_header ALL
  DEPENDS "${aot_output_header}" "${aot_output_library}")

add_library(aot_inductor_output SHARED IMPORTED)
set_property(TARGET aot_inductor_output PROPERTY
  IMPORTED_LOCATION "${aot_output_library}")

add_executable(test test.cpp "${aot_output_header}")
# The generated files are referenced by both `test` and `generate_header`;
# ordering `test` after `generate_header` keeps the generation rule from being
# run twice concurrently in a parallel build.
add_dependencies(test generate_header)
target_link_libraries(test PRIVATE "${TORCH_LIBRARIES}" aot_inductor_output)

set_target_properties(test PROPERTIES
  CXX_STANDARD 17
  CXX_STANDARD_REQUIRED ON
)

test/inductor/aot/cpp/test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//#include <gtest/gtest.h>
2+
#include <iostream>
3+
4+
#include "build/aot_inductor_output.h"
5+
6+
/*
7+
class Net(torch.nn.Module):
8+
def __init__(self):
9+
super().__init__()
10+
self.weight = torch.ones(32, 64)
11+
12+
def forward(self, x):
13+
x = torch.relu(x + self.weight)
14+
return x
15+
*/
16+
struct Net : torch::nn::Module {
17+
Net() {
18+
weight = register_parameter("weight", torch::ones({32, 64}));
19+
}
20+
torch::Tensor forward(torch::Tensor input) {
21+
return torch::relu(input + weight);
22+
}
23+
torch::Tensor weight;
24+
};
25+
26+
int main() {
27+
torch::Tensor x = at::randn({32, 64});
28+
Net net;
29+
torch::Tensor results_ref = net.forward(x);
30+
31+
// TODO: we need to provide an API to concatenate args and weights
32+
std::vector<torch::Tensor> inputs = {x};
33+
for (const auto& pair : net.named_parameters()) {
34+
inputs.push_back(pair.value());
35+
}
36+
torch::Tensor results_opt = aot_inductor_entry(inputs);
37+
38+
assert(torch::allclose(results_ref, results_opt));
39+
printf("PASS\n");
40+
return 0;
41+
}

test/inductor/aot/cpp/test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
import torch._dynamo
3+
import torch._inductor
4+
import torch._inductor.config
5+
6+
torch._inductor.config.aot_codegen_output_prefix = "aot_inductor_output"
7+
8+
9+
class Net(torch.nn.Module):
    """Toy module: adds a fixed 32x64 tensor of ones to the input, then applies ReLU."""

    def __init__(self):
        super().__init__()
        # Stored as a plain tensor attribute on the module.
        self.weight = torch.ones(32, 64)

    def forward(self, x):
        shifted = x + self.weight
        return torch.relu(shifted)
17+
18+
19+
# Export Net with a random 32x64 CPU input, AOT-compile the exported graph,
# and print the value returned by aot_compile.
example_input = torch.randn((32, 64), device="cpu")
exported_module, _ = torch._dynamo.export(Net(), example_input)
print(torch._inductor.aot_compile(exported_module, [example_input]))

test/inductor/aot/cpp/test.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
# Configures, builds and runs the AOT-Inductor C++ test in ./build.
set -euxo pipefail

# Resolve paths against this script's directory so the script behaves the same
# no matter which directory it is invoked from.
cd "$(dirname "${BASH_SOURCE[0]}")"

mkdir -p build
cd build
cmake ..
# Drive the build through CMake so it works with whichever generator was
# selected at configure time, not only Makefiles.
cmake --build .
./test

torch/_inductor/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,24 @@ def compile(
2525
from .compile_fx import compile_fx
2626

2727
return compile_fx(gm, example_inputs, config_patches=options)
28+
29+
30+
def aot_compile(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    options: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Compile an FX graph ahead of time with TorchInductor, producing a shared library.

    Args:
        gm: The FX graph to compile.
        example_inputs: List of tensor inputs.
        options: Optional dict of config options, forwarded to ``compile_fx``
            as ``config_patches``. See `torch._inductor.config`.

    Returns:
        Path to the generated shared library
    """
    from .compile_fx import compile_fx

    # With aot_mode=True, compile_fx hands back a callable; invoking it with
    # no arguments yields the value this function returns.
    aot_entry = compile_fx(
        gm, example_inputs, config_patches=options, aot_mode=True
    )
    return aot_entry()

torch/_inductor/codecache.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,52 @@ def cpp_compile_command(
534534
).strip()
535535

536536

537+
class AotCodeCache:
    """
    Compiles AOT-mode C++ source into a shared library and writes a companion
    header that declares the wrapper entry point.

    Results are memoized for the lifetime of the process in ``cache``, keyed by
    the key that ``write`` returns for the source code.
    """

    cache = dict()
    clear = staticmethod(cache.clear)

    @classmethod
    def compile(cls, source_code):
        """
        Compile ``source_code`` into a shared library.

        Args:
            source_code: C++ source text to write to disk and compile.

        Returns:
            Path of the shared library. When ``config.aot_codegen_output_prefix``
            is set the library is ``<cwd>/<prefix>.so``; otherwise it sits next
            to the written source file with the extension replaced by ``.so``.
            A header with the same base name and a ``.h`` extension is written
            alongside it.

        Raises:
            exc.CppCompileError: if the compiler command exits with an error.
        """
        from .codegen.wrapper import CppWrapperCodeGen

        # TODO: update cpp_compile_command for different platforms
        picked_vec_isa = pick_vec_isa()
        key, input_path = write(
            source_code,
            "cpp",
            code_hash(repr(cpp_compile_command("i", "o", vec_isa=picked_vec_isa))),
        )
        if key not in cls.cache:
            from filelock import FileLock

            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                if config.aot_codegen_output_prefix:
                    output_base = os.path.join(
                        os.getcwd(), config.aot_codegen_output_prefix
                    )
                else:
                    # Drop the whole extension, dot included. Slicing off a
                    # fixed three characters left the dot behind and produced
                    # "<name>..so" / "<name>..h".
                    output_base = os.path.splitext(input_path)[0]
                output_so = f"{output_base}.so"
                output_header = f"{output_base}.h"

                # The header is rewritten on every cache miss, even when the
                # shared library already exists on disk.
                with open(output_header, "w") as header_file:
                    header_file.write("#include <torch/torch.h>\n\n")
                    header_file.write(f"{CppWrapperCodeGen.decl_str};\n")

                log.info("AOT-Inductor compiles code into: %s", output_so)
                # NOTE(review): an existing file at output_so is reused without
                # checking that it was built from this source; with a fixed
                # output prefix that may be a stale library - confirm intent.
                if not os.path.exists(output_so):
                    cmd = cpp_compile_command(
                        input=input_path, output=output_so, vec_isa=picked_vec_isa
                    ).split(" ")
                    try:
                        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
                    except subprocess.CalledProcessError as e:
                        raise exc.CppCompileError(cmd, e.output) from e

                cls.cache[key] = output_so
        return cls.cache[key]
581+
582+
537583
class CppCodeCache:
538584
cache = dict()
539585
clear = staticmethod(cache.clear)

torch/_inductor/codegen/cpp.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2040,7 +2040,12 @@ def codegen_define_and_call(self, wrapper):
20402040
)
20412041
if enable_kernel_profile:
20422042
code.writelines(["#include <ATen/record_function.h>"])
2043-
code.writelines([cpp_prefix(), "" f'extern "C" void kernel({arg_defs})'])
2043+
kernel_decl_name = kernel_name if V.graph.aot_mode else "kernel"
2044+
2045+
if not V.graph.aot_mode or self.count == 1:
2046+
code.writeline(cpp_prefix())
2047+
2048+
code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
20442049
with code.indent():
20452050
if enable_kernel_profile:
20462051
graph_id = V.graph.graph_id
@@ -2055,9 +2060,12 @@ def codegen_define_and_call(self, wrapper):
20552060
code.splice(self.loops_code)
20562061

20572062
codecache_def = IndentedBuffer()
2058-
codecache_def.writeline("async_compile.cpp('''")
2059-
codecache_def.splice(code)
2060-
codecache_def.writeline("''')")
2063+
if V.graph.aot_mode:
2064+
codecache_def.splice(code)
2065+
else:
2066+
codecache_def.writeline("async_compile.cpp('''")
2067+
codecache_def.splice(code)
2068+
codecache_def.writeline("''')")
20612069

20622070
codecache_str = codecache_def.getvalue()
20632071
# TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does

torch/_inductor/codegen/cpp_prefix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <limits>
66
#include <omp.h>
77

8+
#include <ATen/ATen.h>
89
#include <ATen/core/PhiloxRNGEngine.h>
910
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
1011
#include <ATen/cpu/vec/functional.h>

0 commit comments

Comments
 (0)