Skip to content

Commit 9ae78a5

Browse files
jansel authored and pytorchmergebot committed
[halide-backend] Support manual schedules (#129321)
Currently using this for some by-hand hacking, but might need to implement our own scheduler later. Pull Request resolved: #129321 Approved by: https://github.com/shunting314 ghstack dependencies: #126417, #129025, #129026, #127506, #129036, #129320
1 parent a18eb65 commit 9ae78a5

File tree

4 files changed

+131
-41
lines changed

4 files changed

+131
-41
lines changed

test/inductor/test_halide.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,80 @@ def generate(g):
120120
fn(a, b, c)
121121
self.assertEqual(c, a + b)
122122

123+
def test_manual_schedule(self):
124+
fn = HalideCodeCache.generate_halide(
125+
HalideMeta(
126+
argtypes=[
127+
HalideInputSpec(
128+
ctype="float*",
129+
name="in_ptr0",
130+
shape=["1024L"],
131+
stride=["1L"],
132+
offset="0",
133+
),
134+
HalideInputSpec(
135+
ctype="float*",
136+
name="in_ptr1",
137+
shape=["1024L"],
138+
stride=["1L"],
139+
offset="0",
140+
),
141+
HalideInputSpec(
142+
ctype="float*",
143+
name="out_ptr0",
144+
shape=["1024L"],
145+
stride=["1L"],
146+
offset="0",
147+
),
148+
],
149+
target="host-no_runtime",
150+
scheduler=None,
151+
),
152+
textwrap.dedent(
153+
"""
154+
import halide as hl
155+
156+
@hl.generator(name="kernel")
157+
class Kernel:
158+
in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
159+
in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
160+
out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)
161+
162+
def generate(g):
163+
in_ptr0 = g.in_ptr0
164+
in_ptr1 = g.in_ptr1
165+
out_ptr0 = g.out_ptr0
166+
xindex = hl.Var('xindex')
167+
x0 = xindex
168+
tmp0 = hl.Func()
169+
tmp0[xindex] = in_ptr0[x0]
170+
tmp1 = hl.Func()
171+
tmp1[xindex] = in_ptr1[x0]
172+
tmp2 = hl.Func()
173+
tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
174+
out_ptr0[x0] = tmp2[xindex]
175+
176+
assert not g.using_autoscheduler()
177+
i = hl.Var()
178+
j = hl.Var()
179+
out_ptr0.compute_root()
180+
out_ptr0.split(xindex, i, j, 32)
181+
out_ptr0.parallel(i)
182+
out_ptr0.vectorize(j)
183+
tmp2.compute_at(out_ptr0, i)
184+
tmp2.store_at(out_ptr0, i)
185+
tmp1.compute_inline()
186+
187+
__name__ == '__main__' and hl.main()
188+
"""
189+
),
190+
)
191+
a = torch.randn(1024)
192+
b = torch.randn(1024)
193+
c = torch.randn(1024)
194+
fn(a, b, c)
195+
self.assertEqual(c, a + b)
196+
123197

124198
if test_torchinductor.HAS_CPU and HAS_HALIDE:
125199
SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)

torch/_inductor/codecache.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2814,26 +2814,23 @@ def generate_halide_async(cls, meta: HalideMeta, source_code: str, submit_fn=Non
28142814
jobs = []
28152815
if need_compile:
28162816
write_atomic(genfile, source_code)
2817-
jobs.append(
2818-
functools.partial(
2819-
subprocess.check_call,
2820-
[
2821-
sys.executable,
2822-
genfile,
2823-
"-g",
2824-
"kernel",
2825-
"-o",
2826-
f"{dirpath}",
2827-
"-f",
2828-
"halide_kernel",
2829-
"-e",
2830-
"static_library,h,schedule,conceptual_stmt",
2831-
"-p",
2832-
cls.find_libautoschedule(meta.scheduler),
2833-
*meta.args(),
2834-
],
2835-
)
2836-
)
2817+
cmd = [
2818+
sys.executable,
2819+
genfile,
2820+
"-g",
2821+
"kernel",
2822+
"-o",
2823+
f"{dirpath}",
2824+
"-f",
2825+
"halide_kernel",
2826+
"-e",
2827+
"static_library,h,schedule",
2828+
]
2829+
if meta.scheduler:
2830+
cmd.extend(["-p", cls.find_libautoschedule(meta.scheduler)])
2831+
cmd.extend(meta.args())
2832+
jobs.append(functools.partial(subprocess.check_call, cmd))
2833+
28372834
binding_types = [
28382835
arg.bindings_type() for arg in meta.argtypes if arg.alias_of is None
28392836
]

torch/_inductor/codegen/halide.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,25 +1474,40 @@ def update_index(m):
14741474

14751475
code.do_unindent(2)
14761476
code.splice(
1477-
f"""
1477+
"""
14781478
if __name__ == "__main__":
14791479
hl.main()
1480-
else:
1481-
hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
1482-
target = hl.Target({meta.target!r})
1483-
autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
1484-
with hl.GeneratorContext(target, autoscheduler):
1485-
gen = Kernel()
1486-
pipeline = gen._build_pipeline()
1487-
# gen.compile_to_callable() does not run the autoscheduler
1488-
pipeline.apply_autoscheduler(target, autoscheduler)
1489-
kernel = pipeline.compile_to_callable([
1490-
gen._get_input_parameter(a.name)._to_argument()
1491-
for a in gen._get_arginfos()
1492-
if a.dir == hl.ArgInfoDirection.Input
1493-
], target)
1494-
"""
1480+
""".rstrip(),
14951481
)
1482+
if meta.scheduler:
1483+
code.splice(
1484+
f"""
1485+
else:
1486+
hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
1487+
target = hl.Target({meta.target!r})
1488+
autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
1489+
with hl.GeneratorContext(target, autoscheduler):
1490+
gen = Kernel()
1491+
pipeline = gen._build_pipeline()
1492+
# gen.compile_to_callable() does not run the autoscheduler
1493+
pipeline.apply_autoscheduler(target, autoscheduler)
1494+
kernel = pipeline.compile_to_callable([
1495+
gen._get_input_parameter(a.name)._to_argument()
1496+
for a in gen._get_arginfos()
1497+
if a.dir == hl.ArgInfoDirection.Input
1498+
], target)
1499+
""",
1500+
strip=True,
1501+
)
1502+
else:
1503+
code.splice(
1504+
f"""
1505+
else:
1506+
with hl.GeneratorContext(hl.Target({meta.target!r})):
1507+
kernel = Kernel().compile_to_callable()
1508+
""",
1509+
strip=True,
1510+
)
14961511
return code.getvalue()
14971512

14981513
@staticmethod

torch/_inductor/runtime/hints.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,19 @@ def is_buffer(self):
160160
class HalideMeta(typing.NamedTuple):
161161
argtypes: List[HalideInputSpec]
162162
target: str
163-
scheduler: str
164-
scheduler_flags: Dict[str, Union[int, str]]
163+
scheduler: Optional[str] = None
164+
scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
165165
cuda_device: Optional[int] = None
166166

167167
def args(self):
168168
"""Command line args to pass to halide generator"""
169-
args = [f"target={self.target}", f"autoscheduler={self.scheduler}"]
170-
for k, v in self.scheduler_flags.items():
171-
args.append(f"autoscheduler.{k}={v}")
169+
args = [f"target={self.target}"]
170+
if self.scheduler:
171+
args.append(f"autoscheduler={self.scheduler}")
172+
if self.scheduler_flags:
173+
assert self.scheduler
174+
for k, v in self.scheduler_flags.items():
175+
args.append(f"autoscheduler.{k}={v}")
172176
return args
173177

174178
def is_cuda(self):

0 commit comments

Comments (0)