
[Inductor] Enable Custom op Autotune Decompositions and Parameter Tuning#164212

Closed
tianrengao wants to merge 30 commits into main from tianren/customOp_autotune

Conversation

@tianrengao (Contributor) commented Sep 30, 2025

This PR introduces CustomOp autotuning. It allows users to provide a CustomOpConfig:
(1) to register (optionally) multiple decomposition implementations for a custom operation, and
(2) to register the parameter tuning knobs and values they want to tune for those decompositions,
so that Inductor automatically selects the best-performing variant through its autotuning benchmarks.

Example:

```python
register_custom_op_autotuning(
    custom_op=my_attention_op,
    configs=[
        CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
        CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
        CustomOpConfig(head_dim=128),  # no decomposition; uses the op's default implementation
    ],
    input_gen_fns={
        "query": lambda fake: torch.randn_like(fake, device='cuda'),
        "key": lambda fake: torch.randn_like(fake, device='cuda'),
        "value": lambda fake: torch.randn_like(fake, device='cuda'),
    },
)
```

CustomOpConfig: Each CustomOpConfig defines exactly one autotuning variant, with specific parameter values and an optional decomposition implemented with PyTorch aten ops. Users can register their own tuning knobs and optional decomposition functions for the same custom operation, and the system benchmarks all variants to select the best-performing one. If a config provides no decomposition, the custom op's default implementation is used.

Custom Input Generation: Users can provide custom input generators via the optional input_gen_fns argument to control how synthetic inputs are created during benchmarking. This enables more realistic performance testing by generating inputs that match the expected data distribution and characteristics of each tensor argument.
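
For illustration, each input_gen_fns entry is a callable that maps the fake tensor Inductor sees at compile time to a real tensor used for benchmarking. A minimal sketch (hypothetical generator, not part of the PR) that mimics a normalized-embedding distribution:

```python
import torch

# Hypothetical generator: materialize a real tensor with the fake tensor's
# shape/dtype, then normalize rows to mimic pre-normalized embeddings.
def make_normalized_query(fake: torch.Tensor) -> torch.Tensor:
    q = torch.randn(fake.shape, dtype=fake.dtype, device="cuda")
    return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-6)

input_gen_fns = {"query": make_normalized_query}
```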

More Examples with autotune logs:

  1. Allow users to register custom op decompositions with tuning parameters for autotuning. Example usage:
```python
import torch
from torch._inductor.kernel.custom_op import CustomOpConfig, register_custom_op_autotuning

def decompose_k_implementation(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    """Matrix multiply with k-way decomposition."""
    # Implementation...with k_splits
    # (elided; a hedged sketch of one possible body appears after the autotune log below)

@torch.library.custom_op("my_lib::decompose_k", mutates_args=())
def test_decompose_k_op(a: torch.Tensor, b: torch.Tensor, k_splits: int) -> torch.Tensor:
    return decompose_k_implementation(a, b, k_splits)

# Register autotuning with different k_splits values
register_custom_op_autotuning(
    custom_op=test_decompose_k_op,
    configs=[
        CustomOpConfig(decompose_k_implementation, k_splits=2),
        CustomOpConfig(decompose_k_implementation, k_splits=32),
        CustomOpConfig(decompose_k_implementation, k_splits=64),
        CustomOpConfig(k_splits=128),  # decomposition is optional; falls back to the default impl test_decompose_k_op
        CustomOpConfig(k_splits=256),
    ],
    input_gen_fns={
        "a": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
        "b": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
    },
)
```

Example result:

{"num_choices": 6, "num_triton_choices": 0, "best_kernel": "test_decompose_k_autotuned_fallback_default", "best_time": 0.09980800002813339}
AUTOTUNE test_decompose_k_autotuned(256x65536, 65536x1024)
strides: [65536, 1], [1024, 1]
dtypes: torch.float16, torch.float16
  test_decompose_k_autotuned_fallback_default 0.0998 ms 100.0% 
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_2_0 0.1096 ms 91.0% CustomOp decompose_k_implementation_k_splits_2
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_32_1 0.1277 ms 78.2% CustomOp decompose_k_implementation_k_splits_32
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_64_2 0.1454 ms 68.6% CustomOp decompose_k_implementation_k_splits_64
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_128_3 0.1536 ms 65.0% CustomOp decompose_k_implementation_k_splits_128
  test_decompose_k_autotuned_decompose_k_implementation_k_splits_256_4 0.2084 ms 47.9% CustomOp decompose_k_implementation_k_splits_256
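
The body of decompose_k_implementation is elided above. For context, a hedged sketch of one way such a k-split matmul could be written (illustrative only, not the PR's implementation):

```python
import torch

def decompose_k_sketch(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    # Split the reduction (k) dimension into k_splits chunks, batch-multiply, then sum.
    m, k = a.shape
    _, n = b.shape
    assert k % k_splits == 0
    a_chunks = a.reshape(m, k_splits, k // k_splits).permute(1, 0, 2)  # (k_splits, m, k/k_splits)
    b_chunks = b.reshape(k_splits, k // k_splits, n)                   # (k_splits, k/k_splits, n)
    return torch.bmm(a_chunks, b_chunks).sum(dim=0)                    # (m, n)
```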
  2. Allow users to tune a parameter knob by passing the parameter and its values in the CustomOpConfig. Example:
```python
def mlp_variants(input_tensor, gate_weight, up_weight, down_weight, method):
    """MLP implementation with different computational approaches."""
    if method == 0:
        ...  # Standard separate matmuls
    elif method == 1:
        ...  # Batched approach with torch.mm
    elif method == 2:
        ...  # Fused weights approach (see the sketch after the autotune log below)

@torch.library.custom_op("my_lib::mlp_op", mutates_args=())
def mlp_op(
    input_tensor: torch.Tensor,
    gate_weight: torch.Tensor,
    up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    method: int,
) -> torch.Tensor:
    return mlp_variants(
        input_tensor, gate_weight, up_weight, down_weight, method=method
    )

register_custom_op_autotuning(
    custom_op=mlp_op,
    configs=[
        CustomOpConfig(method=0),
        CustomOpConfig(method=1),
        CustomOpConfig(method=2),
        # method=0 is the default fallback in the original op
    ],
    input_gen_fns={
        "input_tensor": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
        "gate_weight": lambda fake: torch.randn_like(fake, device='cuda') * 0.05,
        # ... other input generators
    },
)
```

Example result:

```
AUTOTUNE test_mlp_autotuned(4x32x512, 512x1024, 512x1024, 1024x256)
  test_mlp_autotuned_mlp_variants_method_2 0.0181 ms 100.0% CustomOp mlp_variants_method_2
  test_mlp_autotuned_mlp_variants_method_1 0.0185 ms 97.8% CustomOp mlp_variants_method_1
  test_mlp_autotuned_mlp_default_fallback_method_0 0.0198 ms 91.4% CustomOp fallback
```
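
As an illustration of the "fused weights" idea behind method=2, here is a sketch under the assumption of a SwiGLU-style gate/up/down MLP (the actual variant bodies are elided above, so this is not the PR's test implementation):

```python
import torch
import torch.nn.functional as F

def mlp_fused_weights_sketch(x, gate_weight, up_weight, down_weight):
    # Concatenate gate and up projections so both are computed in one matmul.
    fused = torch.cat([gate_weight, up_weight], dim=1)  # (hidden, 2 * intermediate)
    gate_up = x @ fused                                  # single matmul instead of two
    gate, up = gate_up.chunk(2, dim=-1)
    return (F.silu(gate) * up) @ down_weight
```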

Test Suite (test/inductor/test_custom_op_autotune.py)

  • RMSNorm autotuning: Tests different RMSNorm implementations with dynamic input shapes (the sketch after this list shows the basic compile-and-compare pattern)
  • MLP autotuning: Tests different MLP decompositions and the "method" tuning parameter
  • DecomposeK: Tests different k_splits values for matrix multiplication decomposed along the k dimension
  • Multi-parameter tuning: Tests configs with multiple tuning parameters (scale_mode, chunk_size)
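
A minimal sketch of the compile-and-compare pattern these tests follow (hypothetical op name my_lib::rmsnorm and tolerances; not the actual test code):

```python
import torch

def rmsnorm_reference(x, w, eps=1e-6):
    # eager reference used as ground truth
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

@torch.compile
def run(x, w):
    # assumes my_lib::rmsnorm was registered for custom op autotuning
    return torch.ops.my_lib.rmsnorm(x, w)

x = torch.randn(8, 1024, 4096, device="cuda", dtype=torch.float16)
w = torch.randn(4096, device="cuda", dtype=torch.float16)
torch.testing.assert_close(run(x, w), rmsnorm_reference(x, w), rtol=1e-2, atol=1e-2)
```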

Next Step:

  • Enable max-autotune with a user-provided max-autotune config: https://github.com/pytorch/pytorch/pull/165526/files
  • Support inline epilogue fusion of the selected best custom op decomposition with surrounding elementwise ops: https://github.com/pytorch/pytorch/pull/165952/files
  • Support custom op autotuning that considers fusion with MultiTemplateBuffer (WIP)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

pytorch-bot bot commented Sep 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164212

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 87fb258 with merge base 27302a4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@tianrengao tianrengao force-pushed the tianren/customOp_autotune branch from d68cab4 to ffc1707 on September 30, 2025 22:55
@tianrengao tianrengao changed the title from "initial tests and modified custom op kernel template" to "[inductor] Enable Custom op Autotune Decompositions" on Oct 1, 2025
@tianrengao tianrengao requested a review from eellison October 2, 2025 17:44
@tianrengao tianrengao changed the title from "[inductor] Enable Custom op Autotune Decompositions" to "[Inductor] Enable Custom op Autotune Decompositions" on Oct 2, 2025
@tianrengao tianrengao marked this pull request as ready for review October 3, 2025 17:59
@eellison (Contributor) left a comment:

good first start !

Would be great to consolidate a bit of the logic with existing subgraph choice.
Still needs to be handled:

  • stride output logic validation
  • register fn for input generation

@tianrengao (Contributor, Author) replied:

good first start !

Would be great to consolidate a bit of the logic with existing subgraph choice. Still needs to be handled:

  • stride output logic validation
  • register fn for input generation

Thanks for the review! @eellison
I added the stride output logic validation and the register fn for input generation, and adjusted the code structure a bit to improve readability. I also resolved comments and updated tests. Now the default implementation of the custom op is also used as a benchmarking choice.

@eellison eellison requested a review from BoyuanFeng October 7, 2025 14:51
meta-codesync bot commented Oct 7, 2025:

@tianrengao has imported this pull request. If you are a Meta employee, you can view this in D84080471.

@BoyuanFeng BoyuanFeng requested a review from zou3519 October 8, 2025 04:52
@zou3519 (Contributor) commented Oct 9, 2025:

@youkaichao thoughts on the API?

Comment on lines 207 to 215
```python
@register_custom_op_autotuning(op_object.default)
def _(input_tensor, weight, eps: float = 1e-8, default_impl=None):
    return autotune_custom_op(
        name="test_rmsnorm_autotuned",
        decompositions=decompositions,
        inputs=[input_tensor, weight],
        kwargs={"eps": eps},
        default_impl=default_impl,
    )
```
Contributor

In your PR body, the decorator is named register_custom_op_lowering. That made sense to me: we are specifying in the lowering that inductor has an option to do autotuning.

If the decorator is named register_custom_op_autotuning, I don't see why we also have to call a function called autotune_custom_op. What else can one write in this function?

Contributor Author

Thanks, I agree that autotune_custom_op seems redundant as a user-facing API if we already have a decorator. The main purpose was to expose a decomposition-list argument to users. Let me think about whether I can unify them into a one-stop decorator, either taking a list of decompositions or allowing register_custom_op_lowering to be called multiple times for multiple implementations, and hide autotune_custom_op.

@zou3519 (Contributor) commented Oct 20, 2025:

@tianrengao did you do anything in response to this comment? I see the PR description hasn't been updated

Contributor Author

Yes, updated the PR description.

@youkaichao (Collaborator) left a comment:

High-level questions:

Users can register several decomposition functions for the same custom operation

by "decomposition functions", do you mean the decomposition functions will still be traced by Dynamo/Inductor? what variants do you use to benchmark? what variants are targeted here? are they simple python code that just run as-is, or still go through Dynamo/Inductor optimizations (and at which level?)

Do you think we should give users the control over how to benchmark each variants? e.g. for some communication ops, we might need to do a sync before measuring anything.

Does PyTorch provide autotuning cache? And what are the cache keys and cache values?

@tianrengao (Contributor, Author) commented Oct 13, 2025:

@youkaichao Great questions! Let me address each one with concrete details:

by "decomposition functions", do you mean the decomposition functions will still be traced by Dynamo/Inductor? what variants do you use to benchmark? what variants are targeted here? are they simple python code that just run as-is, or still go through Dynamo/Inductor optimizations (and at which level?)

Yes, decomposition functions are fully traced by Dynamo/Inductor, but the optimizations are limited. Here's the exact flow:

  • Flow: Python functions → make_fx → FX graph → Inductor compilation (Triton kernels, no fusion with surrounding ops) → benchmark competition. A minimal tracing sketch follows this list.
  • Decomposition functions are pure Python code that users write (e.g., rmsnorm_decomposition1, rmsnorm_decomposition2), numerically equivalent but implemented with different aten ops or orderings. During autotuning, each decomposition variant is traced by make_fx into an FX graph and compiled by Inductor as one benchmarking choice. The compiled variants then compete in benchmarking to determine the fastest.
  • Limitations:
    • No fusion with surrounding ops yet (planned for next PR)
    • We allow users to provide tuning knobs (method, k_split, etc.). If the user does not provide tuning parameters, each variant runs as is.
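
For reference, a minimal sketch of the tracing step described above (illustrative; not the PR's internal call sites):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def rmsnorm_decomposition1(x, weight, eps=1e-6):
    # plain aten ops; one of several numerically equivalent variants
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

# make_fx turns the Python function into an FX graph of aten ops;
# Inductor then compiles each such graph as one autotuning choice.
gm = make_fx(rmsnorm_decomposition1)(torch.randn(8, 4096), torch.randn(4096))
print(gm.graph)
```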

Example:
Users can register variants as follows:

```python
def rmsnorm_decomposition1(x, weight): ...  # Naive approach
def rmsnorm_decomposition2(x, weight): ...  # vLLM-style computation
def rmsnorm_decomposition3(x, weight): ...  # Some other fancy approach

register_custom_op_autotuning(
    decompositions=[rmsnorm_decomposition1, rmsnorm_decomposition2, rmsnorm_decomposition3]
)
```

Inductor then performs autotuning and selects decomposition2:

```
AUTOTUNE demo_rmsnorm_autotuned(8x1024x4096, 4096)
strides: [4194304, 4096, 1], [1]
dtypes: torch.float16, torch.float16
  demo_rmsnorm_autotuned_rmsnorm_decomposition2_1 0.124ms 100.0% ← Winner
  demo_rmsnorm_autotuned_rmsnorm_decomposition1_0 0.125ms 99.9%
  demo_rmsnorm_autotuned_rmsnorm_decomposition3_2 0.126ms 98.4%
  demo_rmsnorm_autotuned_fallback_default        0.140ms 88.6%
SingleProcess AUTOTUNE benchmarking takes 0.594 seconds for 4 choices
```

What variants are targeted:

  • Mathematical equivalents with different computational patterns (e.g., rmsnorm with different implementations; matmul with different k_split values)
  • Algorithmic variants via optional tuning_knob argument. Examples:
    • Algorithm choices: method=[0,1,2,3] for different mathematical approaches.
    • Precision control: use_fp32=[True,False] for mixed precision strategies

Example:

```python
def attention_variants(q, k, v, method=0, use_fp32=False):
    if method == 0:
        ...  # Standard attention
    elif method == 1:
        ...  # Flash Attention
    elif method == 2:
        ...  # Flex Attention

register_custom_op_autotuning(
    custom_op=torch.ops.mylib.attention.default,
    decompositions=[attention_variants],
    tuning_knob={"method": [0, 1, 2], "use_fp32": [True, False]}
)
```

The parametric API automatically generates specialized variants (attention_variants_method_0, attention_variants_method_1, attention_variants_method_2, each crossed with use_fp32=True/False) that compete in autotuning, enabling systematic exploration of the performance space.
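
Purely as an illustration of that expansion (not the actual implementation), a tuning_knob dict can be crossed into one concrete parameter assignment per variant:

```python
from itertools import product

tuning_knob = {"method": [0, 1, 2], "use_fp32": [True, False]}
keys = list(tuning_knob)
variants = [dict(zip(keys, combo)) for combo in product(*tuning_knob.values())]
# 6 variants: [{'method': 0, 'use_fp32': True}, {'method': 0, 'use_fp32': False}, ...]
```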

Do you think we should give users the control over how to benchmark each variants? e.g. for some communication ops, we might need to do a sync before measuring anything.

Currently limited, but your sync example is a good point:

Supported: custom input generators via input_gen_fns, and standard GPU/CPU timing
Missing: pre/post-benchmark hooks

Does PyTorch provide autotuning cache? And what are the cache keys and cache values?

Yes, with persistent caching:

Cache Keys: Input shapes/strides/dtypes + decomposition code hashes + device info
Cache Values: Best config + timing measurements

Example from the RMSNorm test in the test suite:

```
# First run:  2.150 seconds (includes autotuning)
# Second run: 0.180 seconds (cache)
```

This summarizes what we have in this PR right now.
cc: @zou3519

@tianrengao tianrengao changed the title from "[Inductor] Enable Custom op Autotune Decompositions" to "[Inductor] Enable Custom op Autotune Decompositions and Parameter Tuning" on Oct 14, 2025
```python
op_object = getattr(getattr(torch.ops, lib_name), op_name)

# Use parameter tuning to test different k_splits values
register_custom_op_autotuning(
```
Contributor

Would it make sense to have a knob for whether or not max-autotune is enabled? Potentially there are some cases where we'd want to opt in even if it weren't enabled globally, or vice versa. cc @zou3519

Contributor Author

It makes sense to me to add a max-autotune knob, for example in this MM split-k case.
Should the user provide the max-autotune configurations themselves?

Contributor

Is the suggestion that a user should be able to specify configs to autotune when we're under max-autotune? If so, then the implication for this PR is that the custom op autotuning runs by default (without max-autotune)?

Contributor Author

@zou3519
Custom op autotuning runs by default for the parameters registered via CustomOpConfig, even if max-autotune is not enabled globally.

To enable max-autotune, users can specify a "max_autotune_config" parameter with detailed values in CustomOpConfig. I'm considering implementing this in the next PR.

Does this approach sound reasonable to you?

@eellison (Contributor) left a comment:

Another round of comments - cc @zou3519, @BoyuanFeng please take a look

@tianrengao tianrengao force-pushed the tianren/customOp_autotune branch from c992543 to 5dcc977 on October 27, 2025 02:28
@tianrengao tianrengao requested a review from zou3519 October 27, 2025 03:33
@mlazos mlazos self-requested a review October 27, 2025 20:29
@zou3519 (Contributor) left a comment:

Thank you, I like the new explicit API.

```python
        CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
        CustomOpConfig(fallback_impl),  # No params
    ],
    input_gen_fns={
```
Contributor

It's not clear to me what these are for - add a comment? Do we use this to generate one or multiple example inputs? If we generate multiple example inputs, what exactly is auto tuning benchmarking? (Best config with mean performance?)

@tianrengao (Contributor, Author) commented Oct 28, 2025:

This is mentioned in the PR description.
input_gen_fns is used to simulate the user's real data distribution during benchmarking. At compile time we benchmark with random data by default; this API hands the user the fake tensor for each argument and lets them generate a real tensor with the desired distribution. It does not generate multiple example inputs. Its purpose is to reduce the gap between real data and benchmark data.

```python
        self.assertEqual(
            compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
        )
        torch.testing.assert_close(
```
Contributor

Why not self.assertEqual? That's how we test the other stuff in PyTorch

Contributor Author

Sure. Will change to assertEqual later

```python
        if name != reference_name:
            rtol = 1e-1 if "Approximated" in name else 1e-2
            atol = 1e-1 if "Approximated" in name else 1e-2
            torch.testing.assert_close(
```
Contributor

Ditto

@facebook-github-bot (Contributor) commented:

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 31, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
…ing (#164212)

Pull Request resolved: #164212
Approved by: https://github.com/zou3519
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
…ing (pytorch#164212)

@github-actions github-actions bot deleted the tianren/customOp_autotune branch December 1, 2025 02:21