Commit 85ba4ea

[WIP][JIT] OpInfo tests for nvfuser
These tests verify that, for the same inputs, the eager version of an op and a traced, fused version of that op return the same output. Currently the tests do not check whether fusion actually occurred.

ghstack-source-id: 54fc4c5
Pull Request resolved: #71299
1 parent 9477f66 commit 85ba4ea
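In outline, the new test compares, for every OpInfo sample, the eager op against a traced copy of it. A minimal standalone sketch of that pattern (illustrative only; the op, shape, and device below are placeholder choices, not from the patch):

import torch

def op(x):
    return torch.sin(x) * torch.cos(x)

x = torch.randn(8, device="cuda")
ref = op(x)                           # eager reference output
traced = torch.jit.trace(op, (x,))    # record the op as a TorchScript graph
traced(x)                             # warm-up: the profiling executor needs a run before fusing
val = traced(x)                       # this call can execute the nvfuser-fused graph
torch.testing.assert_close(ref, val)  # same inputs must give the same output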

File tree

2 files changed (+80, -9 lines)
test/test_jit_cuda_fuser.py

Lines changed: 68 additions & 9 deletions
@@ -11,10 +11,14 @@
 import torch
 from torch.nn import functional

-from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, TEST_WITH_ROCM, IS_WINDOWS
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
-from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
+from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
+from torch.testing._internal.common_jit import JitCommonTestCase
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, TEST_WITH_ROCM, IS_WINDOWS
+from torch.testing._internal.jit_utils import clone_inputs, get_traced_sample_variant_pairs, JitTestCase, RUN_CUDA
+from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
 from torch.testing import FileCheck

 from jit.test_fuser_common import TestFuserCommon # noqa: F401
@@ -73,6 +77,28 @@ def is_pre_volta():

 TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported()

+class CudaFuserTestOptions():
+    def __init__(self):
+        self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
+        self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+        self.old_value = torch._C._jit_set_autocast_mode(True)
+
+        if(RUN_CUDA):
+            self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
+
+    def restore(self):
+        if(RUN_CUDA):
+            torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
+        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
+        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
+        torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
+        torch._C._debug_set_autodiff_subgraph_inlining(True)
+        torch._C._jit_set_autocast_mode(self.old_value)
+
 class TestCudaFuser(JitTestCase):
     def _getSubgraphInFusion(self, graph):
         num_node = 0
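The class above packages the global JIT fusion flags into a single save/flip/restore object so each test routes fusion to nvfuser without leaking state. A rough usage sketch (hypothetical standalone use; run_workload is a placeholder, and the test classes in this file do the equivalent in setUp/tearDown):

opts = CudaFuserTestOptions()   # __init__ saves the old flag values and routes fusion to nvfuser
try:
    run_workload()              # placeholder for the fused code under test
finally:
    opts.restore()              # every JIT flag goes back to its saved value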
@@ -131,15 +157,11 @@ def setUp(self):

         if(RUN_NVFUSER):
             self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
+            self.cuda_fuser_options = CudaFuserTestOptions()

     def tearDown(self):
         if(RUN_NVFUSER):
-            torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
-            torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
-            torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
-            torch._C._jit_set_nvfuser_guard_mode(self.old_guard)
-            torch._C._debug_set_autodiff_subgraph_inlining(True)
-            torch._C._jit_set_autocast_mode(self.old_value)
+            self.cuda_fuser_options.restore()
         super(TestCudaFuser, self).tearDown()

     def _run_helper(self, jit_op, op, *args):
@@ -4408,5 +4430,42 @@ def test_register_fuser(self):
         self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
         self.assertFalse(torch._C._jit_nvfuser_enabled())

+
+class TestCudaFuserOpInfo(JitCommonTestCase):
+    def setUp(self):
+        if RUN_NVFUSER:
+            self.cuda_fuser_options = CudaFuserTestOptions()
+            self.nvfuser_single_node_mode = torch._C._jit_set_nvfuser_single_node_mode(True)
+
+    def tearDown(self):
+        if RUN_NVFUSER:
+            self.cuda_fuser_options.restore()
+            torch._C._jit_set_nvfuser_single_node_mode(self.nvfuser_single_node_mode)
+
+    @slowTest
+    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+    @ops(op_db, dtypes=OpDTypes.supported)
+    def test_nvfuser_correctness(self, device, dtype, op):
+        variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
+
+        for variant, sample in variant_sample_pairs:
+            trace = create_traced_fn(self, variant)
+            ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
+
+            trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
+
+            val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
+
+            self.assertEqual(ref, val)
+
+            # https://github.com/pytorch/pytorch/issues/35600
+            # each torch.jit.trace adds state to the _python_cu compilation unit
+            # since this test traces a lot of functions, out-of-memory can occur
+            # if the CU is not cleared.
+            torch.jit._state._python_cu.drop_all_functions()
+
+instantiate_device_type_tests(TestCudaFuserOpInfo, globals(), only_for=("cuda"))
+
+
 if __name__ == '__main__':
     run_tests()
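Two details of test_nvfuser_correctness above are easy to miss: the traced function is called twice because, under the profiling graph executor, the first call records profiling information and fusion only applies on a later call; and every call goes through clone_inputs because an op may mutate its inputs, so reusing the same tensors would corrupt the comparison. A small illustration of the cloning point (illustrative only, using an in-place op):

import torch

x = torch.randn(4)
ref = torch.relu_(x.clone())   # eager run on a fresh copy of the input
val = torch.relu_(x.clone())   # the traced run likewise gets a fresh copy
assert torch.equal(ref, val)   # fair comparison: neither run saw mutated data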

torch/testing/_internal/common_methods_invocations.py

Lines changed: 12 additions & 0 deletions
@@ -8648,6 +8648,7 @@ def ref_pairwise_distance(input1, input2):
                # https://github.com/pytorch/pytorch/issues/71784
                DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
                             device_type='cpu', dtypes=(torch.float16,)),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness', dtypes=(torch.float16,)),
            )),
     OpInfo('addmv',
            dtypes=all_types_and_complex_and(torch.bfloat16),
@@ -8917,6 +8918,7 @@ def ref_pairwise_distance(input1, input2):
            skips=(
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
            ),
            supports_out=False),
     OpInfo('broadcast_to',
@@ -9189,6 +9191,8 @@ def ref_pairwise_distance(input1, input2):
                DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
                # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)),
+               # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness', dtypes=(torch.half,)),
            )),
     BinaryUfuncInfo('complex',
                     dtypes=floating_types_and(torch.half),
@@ -9967,6 +9971,7 @@ def ref_pairwise_distance(input1, input2):
                # Arguments for call are not valid.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
            ),
            supports_inplace_autograd=False,
            sample_inputs_func=sample_inputs_gradient,
@@ -11572,6 +11577,7 @@ def ref_pairwise_distance(input1, input2):
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'),
                # RuntimeError: "max_pool1d_impl" not implemented for 'BFloat16'
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness', dtypes=(torch.bfloat16,)),
            ),
            sample_inputs_func=sample_inputs_max_pool),
     OpInfo('nn.functional.max_pool2d',
@@ -13677,6 +13683,7 @@ def ref_pairwise_distance(input1, input2):
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
            )),
     OpInfo('bfloat16',
            op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
@@ -13690,6 +13697,7 @@ def ref_pairwise_distance(input1, input2):
                # RuntimeError: attribute lookup is not defined on builtin
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
            )),
     OpInfo('bool',
            op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
@@ -13908,6 +13916,8 @@ def ref_pairwise_distance(input1, input2):
                DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
                # Empty tensor data is garbage so it's hard to make comparisons with it.
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
                # Can't find schemas for this operator for some reason
                DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            )),
@@ -14016,6 +14026,8 @@ def ref_pairwise_distance(input1, input2):
                DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
                # Empty tensor data is garbage so it's hard to make comparisons with it.
                DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
                # Can't find schemas for this operator for some reason
                DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            ),
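Each of the added skip entries is a DecorateInfo, which applies a decorator (here unittest.skip) to one named test, optionally narrowed by device type and dtypes. A minimal sketch of constructing one, assuming DecorateInfo is importable from the module patched above (the values mirror the additions; the variable name is illustrative):

import unittest
import torch
from torch.testing._internal.common_methods_invocations import DecorateInfo

skip_fp16 = DecorateInfo(
    unittest.skip("Skipped!"),   # decorator applied to the matching test
    'TestCudaFuserOpInfo',       # test class the entry targets
    'test_nvfuser_correctness',  # test name within that class
    dtypes=(torch.float16,),     # optional: restrict the skip to these dtypes
)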
