Skip to content

compile_kernel enable pch#162972

Closed
msaroufim wants to merge 14 commits intomainfrom
msaroufim/cub
Closed

compile_kernel enable pch#162972
msaroufim wants to merge 14 commits intomainfrom
msaroufim/cub

Conversation

@msaroufim
Copy link
Member

@msaroufim msaroufim commented Sep 15, 2025

Enabling automatic pre compiled headers per https://docs.nvidia.com/cuda/nvrtc/index.html#example-automatic-pch-cuda-12-8

I'm seeing large speedups in compilation times using PCH on average but the max compilation time with PCH is worst which is why I can't enable it by default. load_inline() also supports precompiled headers and does not enable them by default

Without PCH: 270.58 ms average
With PCH:    115.27 ms average
Without PCH: Max: 337.99 ms
With PCH: Max: 383.82 ms
source) [marksaroufim@devgpu005]~/pytorch% python simple_pch_benchmark.py
============================================================
Simple PCH Compilation Benchmark
============================================================
Device: NVIDIA B200
Iterations: 100

Testing WITHOUT PCH:
------------------------------
Compiling kernel 100 times WITHOUT PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 270.58 ms6.99 ms)
Min: 264.09 ms
Max: 337.99 ms

Testing WITH PCH:
------------------------------
Compiling kernel 100 times WITH PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 115.27 ms27.32 ms)
Min: 110.65 ms
Max: 383.82 ms

Benchmarking script

#!/usr/bin/env python3
import argparse
import os
import sys
import time
from statistics import mean, stdev

import torch
from torch.cuda._utils import _nvrtc_compile


def benchmark_compilation(use_pch, iterations=100):
    """Compile the same kernel many times with or without PCH."""

    # CUB kernel that benefits from PCH
    kernel_source = """
    #include <cub/block/block_reduce.cuh>
    #include <cub/block/block_scan.cuh>
    #include <cub/warp/warp_reduce.cuh>

    extern "C"
    __global__ void test_kernel(const float* input, float* output, int n) {
        using BlockReduce = cub::BlockReduce<float, 256>;
        using BlockScan = cub::BlockScan<float, 256>;
        using WarpReduce = cub::WarpReduce<float>;

        __shared__ union {
            typename BlockReduce::TempStorage reduce;
            typename BlockScan::TempStorage scan;
            typename WarpReduce::TempStorage warp[8];
        } temp_storage;

        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        float val = (idx < n) ? input[idx] : 0.0f;

        float sum = BlockReduce(temp_storage.reduce).Sum(val);
        __syncthreads();

        float scan_result;
        BlockScan(temp_storage.scan).ExclusiveSum(val, scan_result);
        __syncthreads();

        int warp_id = threadIdx.x / 32;
        float warp_sum = WarpReduce(temp_storage.warp[warp_id]).Sum(val);

        if (threadIdx.x == 0) {
            output[blockIdx.x] = sum + scan_result + warp_sum;
        }
    }
    """

    device = torch.cuda.current_device()
    major, minor = torch.cuda.get_device_capability(device)
    compute_capability = f"{major}{minor}"

    compile_times = []

    print(
        f"Compiling kernel {iterations} times {'WITH' if use_pch else 'WITHOUT'} PCH..."
    )

    for i in range(iterations):
        # Use unique kernel name to avoid caching between iterations
        kernel_name = f"test_kernel_{i}"
        unique_source = kernel_source.replace("test_kernel", kernel_name)

        start = time.perf_counter()

        ptx, mangled_name = _nvrtc_compile(
            unique_source,
            kernel_name,
            compute_capability,
            header_code="",
            nvcc_options=["-std=c++17"],
            auto_pch=use_pch,
        )

        elapsed = time.perf_counter() - start
        compile_times.append(elapsed * 1000)  # Convert to ms

        # Progress indicator
        if (i + 1) % 10 == 0:
            print(f"  Completed {i + 1}/{iterations} compilations")

    return compile_times


def main():
    parser = argparse.ArgumentParser(description="Simple PCH Compilation Benchmark")
    parser.add_argument("--pch", action="store_true", help="Test with PCH only")
    parser.add_argument("--no-pch", action="store_true", help="Test without PCH only")
    parser.add_argument(
        "--iterations", type=int, default=100, help="Number of compilations"
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Simple PCH Compilation Benchmark")
    print("=" * 60)
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Iterations: {args.iterations}")
    print()

    # Determine what to test
    test_both = not args.pch and not args.no_pch

    results = {}

    # Test without PCH
    if args.no_pch or test_both:
        print("Testing WITHOUT PCH:")
        print("-" * 30)
        times_no_pch = benchmark_compilation(use_pch=False, iterations=args.iterations)

        if times_no_pch:
            avg_no_pch = mean(times_no_pch)
            std_no_pch = stdev(times_no_pch) if len(times_no_pch) > 1 else 0
            print(f"Average: {avg_no_pch:.2f} ms (±{std_no_pch:.2f} ms)")
            print(f"Min: {min(times_no_pch):.2f} ms")
            print(f"Max: {max(times_no_pch):.2f} ms")
            results["no_pch"] = avg_no_pch
        print()

    # Test with PCH
    if args.pch or test_both:
        print("Testing WITH PCH:")
        print("-" * 30)
        times_with_pch = benchmark_compilation(
            use_pch=True, iterations=args.iterations
        )

        if times_with_pch:
            avg_with_pch = mean(times_with_pch)
            std_with_pch = stdev(times_with_pch) if len(times_with_pch) > 1 else 0
            print(f"Average: {avg_with_pch:.2f} ms (±{std_with_pch:.2f} ms)")
            print(f"Min: {min(times_with_pch):.2f} ms")
            print(f"Max: {max(times_with_pch):.2f} ms")
            results["pch"] = avg_with_pch
        print()

if __name__ == "__main__":
    main()

cc @gau-nernst

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162972

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cdff0ba with merge base 5dc4e78 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@msaroufim msaroufim added the release notes: cuda release notes category label Sep 15, 2025
@msaroufim msaroufim requested a review from malfet September 15, 2025 18:10
@msaroufim msaroufim marked this pull request as ready for review September 15, 2025 18:10
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

@msaroufim
Copy link
Member Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 15, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@msaroufim
Copy link
Member Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@msaroufim
Copy link
Member Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Enabling automatic pre compiled headers per https://docs.nvidia.com/cuda/nvrtc/index.html#example-automatic-pch-cuda-12-8

I'm seeing large speedups in compilation times using PCH on average but the max compilation time with PCH is worst which is why I can't enable it by default. `load_inline()` also supports precompiled headers and does not enable them by default

```
Without PCH: 270.58 ms average
With PCH:    115.27 ms average
```

```
Without PCH: Max: 337.99 ms
With PCH: Max: 383.82 ms
```

```python
source) [marksaroufim@devgpu005]~/pytorch% python simple_pch_benchmark.py
============================================================
Simple PCH Compilation Benchmark
============================================================
Device: NVIDIA B200
Iterations: 100

Testing WITHOUT PCH:
------------------------------
Compiling kernel 100 times WITHOUT PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 270.58 ms (±6.99 ms)
Min: 264.09 ms
Max: 337.99 ms

Testing WITH PCH:
------------------------------
Compiling kernel 100 times WITH PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 115.27 ms (±27.32 ms)
Min: 110.65 ms
Max: 383.82 ms

```

## Benchmarking script

```python
#!/usr/bin/env python3
import argparse
import os
import sys
import time
from statistics import mean, stdev

import torch
from torch.cuda._utils import _nvrtc_compile

def benchmark_compilation(use_pch, iterations=100):
    """Compile the same kernel many times with or without PCH."""

    # CUB kernel that benefits from PCH
    kernel_source = """
    #include <cub/block/block_reduce.cuh>
    #include <cub/block/block_scan.cuh>
    #include <cub/warp/warp_reduce.cuh>

    extern "C"
    __global__ void test_kernel(const float* input, float* output, int n) {
        using BlockReduce = cub::BlockReduce<float, 256>;
        using BlockScan = cub::BlockScan<float, 256>;
        using WarpReduce = cub::WarpReduce<float>;

        __shared__ union {
            typename BlockReduce::TempStorage reduce;
            typename BlockScan::TempStorage scan;
            typename WarpReduce::TempStorage warp[8];
        } temp_storage;

        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        float val = (idx < n) ? input[idx] : 0.0f;

        float sum = BlockReduce(temp_storage.reduce).Sum(val);
        __syncthreads();

        float scan_result;
        BlockScan(temp_storage.scan).ExclusiveSum(val, scan_result);
        __syncthreads();

        int warp_id = threadIdx.x / 32;
        float warp_sum = WarpReduce(temp_storage.warp[warp_id]).Sum(val);

        if (threadIdx.x == 0) {
            output[blockIdx.x] = sum + scan_result + warp_sum;
        }
    }
    """

    device = torch.cuda.current_device()
    major, minor = torch.cuda.get_device_capability(device)
    compute_capability = f"{major}{minor}"

    compile_times = []

    print(
        f"Compiling kernel {iterations} times {'WITH' if use_pch else 'WITHOUT'} PCH..."
    )

    for i in range(iterations):
        # Use unique kernel name to avoid caching between iterations
        kernel_name = f"test_kernel_{i}"
        unique_source = kernel_source.replace("test_kernel", kernel_name)

        start = time.perf_counter()

        ptx, mangled_name = _nvrtc_compile(
            unique_source,
            kernel_name,
            compute_capability,
            header_code="",
            nvcc_options=["-std=c++17"],
            auto_pch=use_pch,
        )

        elapsed = time.perf_counter() - start
        compile_times.append(elapsed * 1000)  # Convert to ms

        # Progress indicator
        if (i + 1) % 10 == 0:
            print(f"  Completed {i + 1}/{iterations} compilations")

    return compile_times

def main():
    parser = argparse.ArgumentParser(description="Simple PCH Compilation Benchmark")
    parser.add_argument("--pch", action="store_true", help="Test with PCH only")
    parser.add_argument("--no-pch", action="store_true", help="Test without PCH only")
    parser.add_argument(
        "--iterations", type=int, default=100, help="Number of compilations"
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Simple PCH Compilation Benchmark")
    print("=" * 60)
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Iterations: {args.iterations}")
    print()

    # Determine what to test
    test_both = not args.pch and not args.no_pch

    results = {}

    # Test without PCH
    if args.no_pch or test_both:
        print("Testing WITHOUT PCH:")
        print("-" * 30)
        times_no_pch = benchmark_compilation(use_pch=False, iterations=args.iterations)

        if times_no_pch:
            avg_no_pch = mean(times_no_pch)
            std_no_pch = stdev(times_no_pch) if len(times_no_pch) > 1 else 0
            print(f"Average: {avg_no_pch:.2f} ms (±{std_no_pch:.2f} ms)")
            print(f"Min: {min(times_no_pch):.2f} ms")
            print(f"Max: {max(times_no_pch):.2f} ms")
            results["no_pch"] = avg_no_pch
        print()

    # Test with PCH
    if args.pch or test_both:
        print("Testing WITH PCH:")
        print("-" * 30)
        times_with_pch = benchmark_compilation(
            use_pch=True, iterations=args.iterations
        )

        if times_with_pch:
            avg_with_pch = mean(times_with_pch)
            std_with_pch = stdev(times_with_pch) if len(times_with_pch) > 1 else 0
            print(f"Average: {avg_with_pch:.2f} ms (±{std_with_pch:.2f} ms)")
            print(f"Min: {min(times_with_pch):.2f} ms")
            print(f"Max: {max(times_with_pch):.2f} ms")
            results["pch"] = avg_with_pch
        print()

if __name__ == "__main__":
    main()

```

Pull Request resolved: pytorch#162972
Approved by: https://github.com/albanD, https://github.com/janeyx99
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
Enabling automatic pre compiled headers per https://docs.nvidia.com/cuda/nvrtc/index.html#example-automatic-pch-cuda-12-8

I'm seeing large speedups in compilation times using PCH on average but the max compilation time with PCH is worst which is why I can't enable it by default. `load_inline()` also supports precompiled headers and does not enable them by default

```
Without PCH: 270.58 ms average
With PCH:    115.27 ms average
```

```
Without PCH: Max: 337.99 ms
With PCH: Max: 383.82 ms
```

```python
source) [marksaroufim@devgpu005]~/pytorch% python simple_pch_benchmark.py
============================================================
Simple PCH Compilation Benchmark
============================================================
Device: NVIDIA B200
Iterations: 100

Testing WITHOUT PCH:
------------------------------
Compiling kernel 100 times WITHOUT PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 270.58 ms (±6.99 ms)
Min: 264.09 ms
Max: 337.99 ms

Testing WITH PCH:
------------------------------
Compiling kernel 100 times WITH PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 115.27 ms (±27.32 ms)
Min: 110.65 ms
Max: 383.82 ms

```

## Benchmarking script

```python
#!/usr/bin/env python3
import argparse
import os
import sys
import time
from statistics import mean, stdev

import torch
from torch.cuda._utils import _nvrtc_compile

def benchmark_compilation(use_pch, iterations=100):
    """Compile the same kernel many times with or without PCH."""

    # CUB kernel that benefits from PCH
    kernel_source = """
    #include <cub/block/block_reduce.cuh>
    #include <cub/block/block_scan.cuh>
    #include <cub/warp/warp_reduce.cuh>

    extern "C"
    __global__ void test_kernel(const float* input, float* output, int n) {
        using BlockReduce = cub::BlockReduce<float, 256>;
        using BlockScan = cub::BlockScan<float, 256>;
        using WarpReduce = cub::WarpReduce<float>;

        __shared__ union {
            typename BlockReduce::TempStorage reduce;
            typename BlockScan::TempStorage scan;
            typename WarpReduce::TempStorage warp[8];
        } temp_storage;

        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        float val = (idx < n) ? input[idx] : 0.0f;

        float sum = BlockReduce(temp_storage.reduce).Sum(val);
        __syncthreads();

        float scan_result;
        BlockScan(temp_storage.scan).ExclusiveSum(val, scan_result);
        __syncthreads();

        int warp_id = threadIdx.x / 32;
        float warp_sum = WarpReduce(temp_storage.warp[warp_id]).Sum(val);

        if (threadIdx.x == 0) {
            output[blockIdx.x] = sum + scan_result + warp_sum;
        }
    }
    """

    device = torch.cuda.current_device()
    major, minor = torch.cuda.get_device_capability(device)
    compute_capability = f"{major}{minor}"

    compile_times = []

    print(
        f"Compiling kernel {iterations} times {'WITH' if use_pch else 'WITHOUT'} PCH..."
    )

    for i in range(iterations):
        # Use unique kernel name to avoid caching between iterations
        kernel_name = f"test_kernel_{i}"
        unique_source = kernel_source.replace("test_kernel", kernel_name)

        start = time.perf_counter()

        ptx, mangled_name = _nvrtc_compile(
            unique_source,
            kernel_name,
            compute_capability,
            header_code="",
            nvcc_options=["-std=c++17"],
            auto_pch=use_pch,
        )

        elapsed = time.perf_counter() - start
        compile_times.append(elapsed * 1000)  # Convert to ms

        # Progress indicator
        if (i + 1) % 10 == 0:
            print(f"  Completed {i + 1}/{iterations} compilations")

    return compile_times

def main():
    parser = argparse.ArgumentParser(description="Simple PCH Compilation Benchmark")
    parser.add_argument("--pch", action="store_true", help="Test with PCH only")
    parser.add_argument("--no-pch", action="store_true", help="Test without PCH only")
    parser.add_argument(
        "--iterations", type=int, default=100, help="Number of compilations"
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Simple PCH Compilation Benchmark")
    print("=" * 60)
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Iterations: {args.iterations}")
    print()

    # Determine what to test
    test_both = not args.pch and not args.no_pch

    results = {}

    # Test without PCH
    if args.no_pch or test_both:
        print("Testing WITHOUT PCH:")
        print("-" * 30)
        times_no_pch = benchmark_compilation(use_pch=False, iterations=args.iterations)

        if times_no_pch:
            avg_no_pch = mean(times_no_pch)
            std_no_pch = stdev(times_no_pch) if len(times_no_pch) > 1 else 0
            print(f"Average: {avg_no_pch:.2f} ms (±{std_no_pch:.2f} ms)")
            print(f"Min: {min(times_no_pch):.2f} ms")
            print(f"Max: {max(times_no_pch):.2f} ms")
            results["no_pch"] = avg_no_pch
        print()

    # Test with PCH
    if args.pch or test_both:
        print("Testing WITH PCH:")
        print("-" * 30)
        times_with_pch = benchmark_compilation(
            use_pch=True, iterations=args.iterations
        )

        if times_with_pch:
            avg_with_pch = mean(times_with_pch)
            std_with_pch = stdev(times_with_pch) if len(times_with_pch) > 1 else 0
            print(f"Average: {avg_with_pch:.2f} ms (±{std_with_pch:.2f} ms)")
            print(f"Min: {min(times_with_pch):.2f} ms")
            print(f"Max: {max(times_with_pch):.2f} ms")
            results["pch"] = avg_with_pch
        print()

if __name__ == "__main__":
    main()

```

Pull Request resolved: pytorch#162972
Approved by: https://github.com/albanD, https://github.com/janeyx99
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
Enabling automatic pre compiled headers per https://docs.nvidia.com/cuda/nvrtc/index.html#example-automatic-pch-cuda-12-8

I'm seeing large speedups in compilation times using PCH on average but the max compilation time with PCH is worst which is why I can't enable it by default. `load_inline()` also supports precompiled headers and does not enable them by default

```
Without PCH: 270.58 ms average
With PCH:    115.27 ms average
```

```
Without PCH: Max: 337.99 ms
With PCH: Max: 383.82 ms
```

```python
source) [marksaroufim@devgpu005]~/pytorch% python simple_pch_benchmark.py
============================================================
Simple PCH Compilation Benchmark
============================================================
Device: NVIDIA B200
Iterations: 100

Testing WITHOUT PCH:
------------------------------
Compiling kernel 100 times WITHOUT PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 270.58 ms (±6.99 ms)
Min: 264.09 ms
Max: 337.99 ms

Testing WITH PCH:
------------------------------
Compiling kernel 100 times WITH PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 115.27 ms (±27.32 ms)
Min: 110.65 ms
Max: 383.82 ms

```

## Benchmarking script

```python
#!/usr/bin/env python3
import argparse
import os
import sys
import time
from statistics import mean, stdev

import torch
from torch.cuda._utils import _nvrtc_compile

def benchmark_compilation(use_pch, iterations=100):
    """Compile the same kernel many times with or without PCH."""

    # CUB kernel that benefits from PCH
    kernel_source = """
    #include <cub/block/block_reduce.cuh>
    #include <cub/block/block_scan.cuh>
    #include <cub/warp/warp_reduce.cuh>

    extern "C"
    __global__ void test_kernel(const float* input, float* output, int n) {
        using BlockReduce = cub::BlockReduce<float, 256>;
        using BlockScan = cub::BlockScan<float, 256>;
        using WarpReduce = cub::WarpReduce<float>;

        __shared__ union {
            typename BlockReduce::TempStorage reduce;
            typename BlockScan::TempStorage scan;
            typename WarpReduce::TempStorage warp[8];
        } temp_storage;

        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        float val = (idx < n) ? input[idx] : 0.0f;

        float sum = BlockReduce(temp_storage.reduce).Sum(val);
        __syncthreads();

        float scan_result;
        BlockScan(temp_storage.scan).ExclusiveSum(val, scan_result);
        __syncthreads();

        int warp_id = threadIdx.x / 32;
        float warp_sum = WarpReduce(temp_storage.warp[warp_id]).Sum(val);

        if (threadIdx.x == 0) {
            output[blockIdx.x] = sum + scan_result + warp_sum;
        }
    }
    """

    device = torch.cuda.current_device()
    major, minor = torch.cuda.get_device_capability(device)
    compute_capability = f"{major}{minor}"

    compile_times = []

    print(
        f"Compiling kernel {iterations} times {'WITH' if use_pch else 'WITHOUT'} PCH..."
    )

    for i in range(iterations):
        # Use unique kernel name to avoid caching between iterations
        kernel_name = f"test_kernel_{i}"
        unique_source = kernel_source.replace("test_kernel", kernel_name)

        start = time.perf_counter()

        ptx, mangled_name = _nvrtc_compile(
            unique_source,
            kernel_name,
            compute_capability,
            header_code="",
            nvcc_options=["-std=c++17"],
            auto_pch=use_pch,
        )

        elapsed = time.perf_counter() - start
        compile_times.append(elapsed * 1000)  # Convert to ms

        # Progress indicator
        if (i + 1) % 10 == 0:
            print(f"  Completed {i + 1}/{iterations} compilations")

    return compile_times

def main():
    parser = argparse.ArgumentParser(description="Simple PCH Compilation Benchmark")
    parser.add_argument("--pch", action="store_true", help="Test with PCH only")
    parser.add_argument("--no-pch", action="store_true", help="Test without PCH only")
    parser.add_argument(
        "--iterations", type=int, default=100, help="Number of compilations"
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Simple PCH Compilation Benchmark")
    print("=" * 60)
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Iterations: {args.iterations}")
    print()

    # Determine what to test
    test_both = not args.pch and not args.no_pch

    results = {}

    # Test without PCH
    if args.no_pch or test_both:
        print("Testing WITHOUT PCH:")
        print("-" * 30)
        times_no_pch = benchmark_compilation(use_pch=False, iterations=args.iterations)

        if times_no_pch:
            avg_no_pch = mean(times_no_pch)
            std_no_pch = stdev(times_no_pch) if len(times_no_pch) > 1 else 0
            print(f"Average: {avg_no_pch:.2f} ms (±{std_no_pch:.2f} ms)")
            print(f"Min: {min(times_no_pch):.2f} ms")
            print(f"Max: {max(times_no_pch):.2f} ms")
            results["no_pch"] = avg_no_pch
        print()

    # Test with PCH
    if args.pch or test_both:
        print("Testing WITH PCH:")
        print("-" * 30)
        times_with_pch = benchmark_compilation(
            use_pch=True, iterations=args.iterations
        )

        if times_with_pch:
            avg_with_pch = mean(times_with_pch)
            std_with_pch = stdev(times_with_pch) if len(times_with_pch) > 1 else 0
            print(f"Average: {avg_with_pch:.2f} ms (±{std_with_pch:.2f} ms)")
            print(f"Min: {min(times_with_pch):.2f} ms")
            print(f"Max: {max(times_with_pch):.2f} ms")
            results["pch"] = avg_with_pch
        print()

if __name__ == "__main__":
    main()

```

Pull Request resolved: pytorch#162972
Approved by: https://github.com/albanD, https://github.com/janeyx99
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
Enabling automatic pre compiled headers per https://docs.nvidia.com/cuda/nvrtc/index.html#example-automatic-pch-cuda-12-8

I'm seeing large speedups in compilation times using PCH on average but the max compilation time with PCH is worst which is why I can't enable it by default. `load_inline()` also supports precompiled headers and does not enable them by default

```
Without PCH: 270.58 ms average
With PCH:    115.27 ms average
```

```
Without PCH: Max: 337.99 ms
With PCH: Max: 383.82 ms
```

```python
source) [marksaroufim@devgpu005]~/pytorch% python simple_pch_benchmark.py
============================================================
Simple PCH Compilation Benchmark
============================================================
Device: NVIDIA B200
Iterations: 100

Testing WITHOUT PCH:
------------------------------
Compiling kernel 100 times WITHOUT PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 270.58 ms (±6.99 ms)
Min: 264.09 ms
Max: 337.99 ms

Testing WITH PCH:
------------------------------
Compiling kernel 100 times WITH PCH...
  Completed 10/100 compilations
  Completed 20/100 compilations
  Completed 30/100 compilations
  Completed 40/100 compilations
  Completed 50/100 compilations
  Completed 60/100 compilations
  Completed 70/100 compilations
  Completed 80/100 compilations
  Completed 90/100 compilations
  Completed 100/100 compilations
Average: 115.27 ms (±27.32 ms)
Min: 110.65 ms
Max: 383.82 ms

```

## Benchmarking script

```python
#!/usr/bin/env python3
import argparse
import os
import sys
import time
from statistics import mean, stdev

import torch
from torch.cuda._utils import _nvrtc_compile

def benchmark_compilation(use_pch, iterations=100):
    """Compile the same kernel many times with or without PCH."""

    # CUB kernel that benefits from PCH
    kernel_source = """
    #include <cub/block/block_reduce.cuh>
    #include <cub/block/block_scan.cuh>
    #include <cub/warp/warp_reduce.cuh>

    extern "C"
    __global__ void test_kernel(const float* input, float* output, int n) {
        using BlockReduce = cub::BlockReduce<float, 256>;
        using BlockScan = cub::BlockScan<float, 256>;
        using WarpReduce = cub::WarpReduce<float>;

        __shared__ union {
            typename BlockReduce::TempStorage reduce;
            typename BlockScan::TempStorage scan;
            typename WarpReduce::TempStorage warp[8];
        } temp_storage;

        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        float val = (idx < n) ? input[idx] : 0.0f;

        float sum = BlockReduce(temp_storage.reduce).Sum(val);
        __syncthreads();

        float scan_result;
        BlockScan(temp_storage.scan).ExclusiveSum(val, scan_result);
        __syncthreads();

        int warp_id = threadIdx.x / 32;
        float warp_sum = WarpReduce(temp_storage.warp[warp_id]).Sum(val);

        if (threadIdx.x == 0) {
            output[blockIdx.x] = sum + scan_result + warp_sum;
        }
    }
    """

    device = torch.cuda.current_device()
    major, minor = torch.cuda.get_device_capability(device)
    compute_capability = f"{major}{minor}"

    compile_times = []

    print(
        f"Compiling kernel {iterations} times {'WITH' if use_pch else 'WITHOUT'} PCH..."
    )

    for i in range(iterations):
        # Use unique kernel name to avoid caching between iterations
        kernel_name = f"test_kernel_{i}"
        unique_source = kernel_source.replace("test_kernel", kernel_name)

        start = time.perf_counter()

        ptx, mangled_name = _nvrtc_compile(
            unique_source,
            kernel_name,
            compute_capability,
            header_code="",
            nvcc_options=["-std=c++17"],
            auto_pch=use_pch,
        )

        elapsed = time.perf_counter() - start
        compile_times.append(elapsed * 1000)  # Convert to ms

        # Progress indicator
        if (i + 1) % 10 == 0:
            print(f"  Completed {i + 1}/{iterations} compilations")

    return compile_times

def main():
    parser = argparse.ArgumentParser(description="Simple PCH Compilation Benchmark")
    parser.add_argument("--pch", action="store_true", help="Test with PCH only")
    parser.add_argument("--no-pch", action="store_true", help="Test without PCH only")
    parser.add_argument(
        "--iterations", type=int, default=100, help="Number of compilations"
    )
    args = parser.parse_args()

    print("=" * 60)
    print("Simple PCH Compilation Benchmark")
    print("=" * 60)
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Iterations: {args.iterations}")
    print()

    # Determine what to test
    test_both = not args.pch and not args.no_pch

    results = {}

    # Test without PCH
    if args.no_pch or test_both:
        print("Testing WITHOUT PCH:")
        print("-" * 30)
        times_no_pch = benchmark_compilation(use_pch=False, iterations=args.iterations)

        if times_no_pch:
            avg_no_pch = mean(times_no_pch)
            std_no_pch = stdev(times_no_pch) if len(times_no_pch) > 1 else 0
            print(f"Average: {avg_no_pch:.2f} ms (±{std_no_pch:.2f} ms)")
            print(f"Min: {min(times_no_pch):.2f} ms")
            print(f"Max: {max(times_no_pch):.2f} ms")
            results["no_pch"] = avg_no_pch
        print()

    # Test with PCH
    if args.pch or test_both:
        print("Testing WITH PCH:")
        print("-" * 30)
        times_with_pch = benchmark_compilation(
            use_pch=True, iterations=args.iterations
        )

        if times_with_pch:
            avg_with_pch = mean(times_with_pch)
            std_with_pch = stdev(times_with_pch) if len(times_with_pch) > 1 else 0
            print(f"Average: {avg_with_pch:.2f} ms (±{std_with_pch:.2f} ms)")
            print(f"Min: {min(times_with_pch):.2f} ms")
            print(f"Max: {max(times_with_pch):.2f} ms")
            results["pch"] = avg_with_pch
        print()

if __name__ == "__main__":
    main()

```

Pull Request resolved: pytorch#162972
Approved by: https://github.com/albanD, https://github.com/janeyx99
@github-actions github-actions bot deleted the msaroufim/cub branch October 16, 2025 02:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: cuda release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants