
mandroid6/Custom-CUDA-Kernel-Library


Custom CUDA Kernel Library

The aim of this repo is to build CUDA kernels, starting from the most basic ones and incrementally adding advanced features, to refresh GPU programming and kernel-writing habits.

Environment Setup

Learning Tasks

  1. GPU architecture: SMs, warps, thread blocks, memory hierarchy

    • SMs (streaming multiprocessors): think of each one as a "mini CPU core" designed for parallel work

      • A GPU has many SMs (e.g. A100 has 108, H200 has 132)
      • Each SM can run multiple warps (groups of 32 threads) concurrently
      • Each SM has its own:
        • Registers (fastest, private to each thread)
        • Shared memory (fast, shared among threads in a block)
        • L1 cache
        • Warp schedulers (4 per SM on A100 and H200)
      • When we launch a CUDA kernel, the thread blocks get assigned to SMs. The SM then schedules warps from that block onto its execution units.
    • Warp: a group of 32 threads that executes together; fundamental unit of execution on NVIDIA GPUs.

      • all 32 threads in a warp execute the same instruction at the same time (SIMT: single instruction, multiple threads)
      • you don't create warps directly: when you launch a thread block of, say, 128 threads, the GPU automatically divides it into 4 warps (128 / 32 = 4)
      • performance impact:
        • if threads within a warp take different code paths (an if/else), the paths execute one after the other with the inactive threads masked off; this is called warp divergence and it's slow
        • when threads in a warp access memory, their accesses can be combined into fewer transactions if the addresses are contiguous; this is memory coalescing
    • Thread blocks: a group of threads that can cooperate with each other

      • we define the size when launching a kernel (eg 256 threads per block)
      • threads in the same block can:
        • share data through shared memory (fast, on-chip memory)
        • synchronize with __syncthreads(), i.e. wait until all threads in the block reach that point
      • threads in different blocks cannot communicate directly, i.e. blocks are independent and may run on different SMs
      • a thread block runs entirely on one SM, it cannot get split across SMs
      • block size limits: maximum 1024 threads/block on modern GPUs
    • Memory hierarchy, ordered from fastest to slowest: registers > shared memory > L1 cache > L2 cache > global memory (HBM)

      • registers: fastest; private to each thread; up to 255 registers per thread; using too many registers reduces occupancy, i.e. fewer threads can run concurrently on an SM
      • shared memory: very fast; shared among threads in the same block
        • latency roughly 100x lower than uncached global memory
        • size: up to 228 KB per SM (H200), ~30 MB total across the chip
        • ~19+ TB/s aggregate bandwidth across all SMs, roughly 4x the ~4.8 TB/s of HBM
        • declared with __shared__; used to tile data for reuse and for inter-thread communication
        • bank conflicts: when multiple threads of a warp access different addresses in the same memory bank, the accesses are serialized
      • L1 cache: fast; per-SM, shares space with shared memory but managed by hardware (automatic caching)
      • L2 cache: medium; shared across all SMs, 40-50 MB on A100/H200, managed by hardware
      • Global memory (HBM): slowest, with access latency of hundreds of clock cycles
        • accessible by all threads; persists across kernel launches
        • H200: 141 GB capacity, ~4.8 TB/s bandwidth
        • performs best with coalesced access, i.e. adjacent threads access adjacent memory addresses
    • How it fits together

    grid (entire kernel launch)
     |__ Thread Blocks (can cooperate, share memory)
            |__ Warps (32 threads, execute in lockstep)
                    |__ Threads (individual execution units, each with private registers)
    
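To make the warp and coalescing points above concrete, here is a small host-side C++ model (it runs on the CPU, not the GPU). It assumes a warp size of 32 and models global-memory traffic as 32-byte sectors, ignoring caches and alignment; the helper names are hypothetical, not part of any CUDA API.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

// Warp size on current NVIDIA GPUs (exposed on the device as warpSize).
constexpr int kWarpSize = 32;
// Simplifying assumption: global memory is serviced in 32-byte sectors.
constexpr std::size_t kSectorBytes = 32;

// Warps formed from a 1D block; a partial last warp still takes a full slot.
int warpsPerBlock(int threadsPerBlock) {
    assert(threadsPerBlock > 0);
    return (threadsPerBlock + kWarpSize - 1) / kWarpSize;
}

// Which warp a thread belongs to, and its lane (position) inside that warp.
int warpOf(int threadIdxX) { return threadIdxX / kWarpSize; }
int laneOf(int threadIdxX) { return threadIdxX % kWarpSize; }

// Distinct 32-byte sectors touched when lane i of one warp reads the float
// at element index (base + i * strideElems). Fewer sectors = better coalescing.
std::size_t sectorsTouched(std::size_t base, std::size_t strideElems) {
    std::set<std::size_t> sectors;
    for (int lane = 0; lane < kWarpSize; ++lane) {
        std::size_t byteAddr = (base + lane * strideElems) * sizeof(float);
        sectors.insert(byteAddr / kSectorBytes);
    }
    return sectors.size();
}
```

With these helpers, a 128-thread block splits into 4 warps, a warp reading 32 consecutive floats touches 4 sectors, and the same warp reading with a stride of 32 floats touches 32 sectors, one per thread, which is the uncoalesced case.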
  2. ncu profiling (NVIDIA Nsight Compute)

    • ncu isolates kernel execution time (host code is not included) and reports each kernel launch separately in its output
    • useful ncu flags
      # profile a specific kernel by name
      ncu --kernel-name transposeKernel ./program
      # skip the first N launches (e.g. warmup runs)
      ncu --launch-skip 3 ./program

      # profile only N launches
      ncu --launch-count 1 ./program

      # combine: skip 3 warmup launches, then profile 1
      ncu --launch-skip 3 --launch-count 1 ./program
      
  3. Matrix indexing logic (for matrix A of size M x N)

    • Row times width (columns) plus column
    •   A[row][col]  →  A[row * N + col]
                      ↑
               stride = num columns
      
    •   index = row * (number of columns) + col
              = row * N + col
      
    •           Logical 2D:              Linear 1D memory:
                                
        col: 0   1   2   3     index: 0  1  2  3  4  5  6  7  8  9 10 11
        row 0: a   b   c   d     →      [a, b, c, d, e, f, g, h, i, j, k, l]
        row 1: e   f   g   h             ←─row 0─→ ←─row 1─→ ←─row 2─→
        row 2: i   j   k   l
      
    • CUDA mapping logic
        // Standard mapping (works for most cases)
        int row = blockIdx.y * blockDim.y + threadIdx.y;
        int col = blockIdx.x * blockDim.x + threadIdx.x;
      
        // Access element, guarding threads that fall outside the M x N matrix
        if (row < M && col < N) {
            float value = A[row * N + col];
        }
      

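The mapping above can be checked on the host. This C++ sketch (hypothetical helper names) computes the row-major offset and emulates every (block, thread) pair of a 2D launch, counting how many in-bounds elements of an M x N matrix receive a thread. It ignores what the kernel does with each element.

```cpp
#include <cassert>
#include <cstddef>

// Row-major offset of A[row][col] in an M x N matrix: the row stride is N.
std::size_t rowMajorIndex(int row, int col, int N) {
    return static_cast<std::size_t>(row) * N + col;
}

// Integer ceiling division, the usual way to size a grid.
int ceilDiv(int a, int b) {
    assert(b > 0);
    return (a + b - 1) / b;
}

// Emulate a 2D launch: visit every (block, thread) pair and count the
// in-bounds elements of an M x N matrix that get a thread. Each (row, col)
// comes from exactly one (block, thread) pair, so the count has no duplicates.
int elementsCovered(int M, int N, int blockX, int blockY, int gridX, int gridY) {
    int covered = 0;
    for (int by = 0; by < gridY; ++by)
        for (int bx = 0; bx < gridX; ++bx)
            for (int ty = 0; ty < blockY; ++ty)
                for (int tx = 0; tx < blockX; ++tx) {
                    int row = by * blockY + ty;  // blockIdx.y * blockDim.y + threadIdx.y
                    int col = bx * blockX + tx;  // blockIdx.x * blockDim.x + threadIdx.x
                    if (row < M && col < N) ++covered;  // bounds check
                }
    return covered;
}
```

For a 32 x 32 matrix, a single dim3(32, 32) block covers all 1024 elements, while a single dim3(64, 16) block covers only 512 (rows 16..31 are never visited); sizing the grid with ceilDiv restores full coverage.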
Open Tasks

  • Perform ncu profiling and check memory/bandwidth variations for vector addition and matrix transpose

Kernels

  1. Vector Addition (kernels/vector_addition.cu)
    • threadsPerBlock: 256 --> 0.19648 ms
    • threadsPerBlock: 512
      • fails without a bounds check in the kernel; vecAdd needs an index check so threads past the end of the vector do nothing
      • since threadsPerBlock > N (the vector size), only 1 block is needed, but many threads do no work
      • 0.191392 ms
    • threadsPerBlock: 1024 --> 0.192896 ms
    • overall, execution time barely changes: kernel launch overhead dominates, and the actual compute takes nanoseconds
    • all three configurations do the same 256 additions; the extra threads just hit the bounds check and exit

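Since the vecAdd kernel body isn't reproduced here, this is a hypothetical host-side emulation of what a bounds-checked 1D launch does. Each inner-loop iteration stands in for one GPU thread; the names are illustrative, not the repo's code.

```cpp
#include <cassert>
#include <vector>

// Emulate a 1D vecAdd launch on the CPU. Returns how many "threads"
// actually performed an addition (i.e. passed the bounds check).
int emulateVecAdd(const std::vector<float>& a, const std::vector<float>& b,
                  std::vector<float>& c, int threadsPerBlock) {
    assert(b.size() == a.size() && c.size() == a.size());
    const int n = static_cast<int>(a.size());
    const int blocks = (n + threadsPerBlock - 1) / threadsPerBlock;
    int active = 0;
    for (int block = 0; block < blocks; ++block) {     // blockIdx.x
        for (int t = 0; t < threadsPerBlock; ++t) {    // threadIdx.x
            int i = block * threadsPerBlock + t;       // global thread index
            if (i < n) {                               // surplus threads skip this
                c[i] = a[i] + b[i];
                ++active;
            }
        }
    }
    return active;
}
```

With 256 elements, launching 512 or 1024 threads per block still performs exactly 256 additions; the remaining threads fail the i < n check and do nothing.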
  2. Matrix Transpose (kernels/matrix_transpose.cu)
    • dim(A) == [M, N], dim(B) == [N, M]
    • Naive version
      • works correctly with threadsPerBlock = dim3(32, 32) on a 32 x 32 matrix, but only by luck: a single block is enough because row and col both already span 0..31
      • if the block shape differs across x and y, the indexing breaks: e.g. with threadsPerBlock = dim3(64, 16) and blocksPerGrid = dim3((32*32) / (64*16)) == 1, every 16th value is incorrect because the kernel never sees data beyond index 15 along the row dimension
        • looking at the index ranges
          • row only spans 0..15, so rows 16..31 are never visited: with a single block (gridDim = 1), blockIdx is always 0 and contributes nothing to the row/col index calculation (col, meanwhile, spans 0..63, so threads past column 31 must be masked by a bounds check)
          • we either need more coverage along the rows (more blocks in y, or a larger blockDim.y) to see all 32 rows, or a block shape that distributes threads evenly across both row and col, such as dim3(32, 32)
      • general formula for threadsPerBlock and blocksPerGrid
        dim3 threadsPerBlock(TX, TY);
        dim3 blocksPerGrid(
            (N + TX - 1) / TX,   // blocks to cover the N columns (x), integer ceil
            (M + TY - 1) / TY    // blocks to cover the M rows (y), integer ceil
        );
        
      • CUDA kernel execution time (M = N = 32): 0.1856 ms
      • Key metrics for transpose:
        • memory bandwidth
          • bytes moved = 2 * M * N * sizeof(float) // read A + write B
          • bandwidth (GB/s) = Bytes moved / (time in seconds) / 1e9
      • measured numbers (H200, peak HBM bandwidth 4.8 TB/s)
        • M=32, N=32
          • bytes moved = 2 * 32 * 32 * 4 = 8192
          • time taken = 0.1856 ms
          • bandwidth = 8192 / (0.1856 * 1e-3) / 1e9 ≈ 0.04413 GB/s (extremely low)
          • far below the H200's 4.8 TB/s peak: at this size the run is dominated by launch overhead and latency rather than limited by memory bandwidth
        • M=32, N=1024
          • Bandwidth: 1.35 GB/s
          • still far below peak bandwidth
        • M=1024, N=1024
          • CUDA kernel execution time: 0.0217536 ms
          • Bandwidth: 385.62 GB/s (8,388,608 bytes / 0.0217536 ms)
        • M = N = 1024*6 = 6144
          • CUDA kernel execution time: 0.35392 ms
          • Bandwidth: 853.27 GB/s
        • M = N = 1024*10 = 10240
          • CUDA kernel execution time: 0.619648 ms
          • Bandwidth: 1353.77 GB/s
          • still memory bound, at ~28% of the 4.8 TB/s peak
        • M = N = 1024*100 = 102400
          • execution gets stuck, then fails with "Segmentation fault (core dumped) ./a.out"
          • the segfault happens on the host side: malloc for A and B needs 2 * 102400² * 4 bytes ≈ 84 GB, more RAM than the machine has
          • need to manually kill the host process
        • M = N = 1024*20 = 20480
          • CUDA kernel execution time: 1.86445 ms
          • Bandwidth: 1799.70 GB/s
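The bandwidth formula above, as a small C++ helper for checking the numbers in this list. It assumes a float matrix, counts exactly one read of A and one write of B per element (no caching effects), and takes the kernel time in milliseconds; the function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Bytes moved by an M x N float transpose: read A once + write B once.
double transposeBytesMoved(std::size_t M, std::size_t N) {
    return 2.0 * static_cast<double>(M) * static_cast<double>(N) * sizeof(float);
}

// Effective bandwidth in GB/s for a kernel time given in milliseconds.
double effectiveBandwidthGBs(std::size_t M, std::size_t N, double timeMs) {
    assert(timeMs > 0.0);
    return transposeBytesMoved(M, N) / (timeMs * 1e-3) / 1e9;
}

// Fraction of a peak bandwidth (both in GB/s) that was actually achieved.
double fractionOfPeak(double achievedGBs, double peakGBs) {
    assert(peakGBs > 0.0);
    return achievedGBs / peakGBs;
}
```

For example, the 1024 x 1024 coalesced run further below (0.012032 ms) works out to about 697.2 GB/s, and the 10240 x 10240 run (0.619648 ms) to about 1353.8 GB/s, roughly 28% of a 4.8 TB/s peak.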
    • Coalesced memory
      • in the naive version, the reads from input A are coalesced but the writes to output B are strided by M

      • we call the reads coalesced since adjacent threads read adjacent memory

      • declare a shared-memory tile with __shared__ float tile[32][32]

        • values are loaded into the tile with the same indexing pattern used to read A, so adjacent threads access adjacent memory
        • for writing to B, compute the transposed output coordinates out_row, out_col so that adjacent threads also write adjacent addresses of B
      • CUDA kernel execution time (M = N = 1024): 0.012032 ms

      • Bandwidth: 697.19 GB/s

      • still only at 52% memory throughput; a lot of room for improvement

      • the current tile size causes bank conflicts. When we read tile[threadIdx.x][threadIdx.y] (the transposed read), consecutive threads access elements 32 floats apart:

        tile[32][32] layout in shared memory:

        tile[0][0] → bank 0
        tile[1][0] → bank 0  (offset 32, and 32 % 32 = 0)
        tile[2][0] → bank 0  (offset 64, and 64 % 32 = 0)
        ...
        tile[31][0] → bank 0
        

        All 32 threads reading column 0 hit the same bank!

      • The fix: pad the shared-memory tile with one extra column
        __shared__ float tile[TILE_DIM][TILE_DIM + 1]; // 32 x 33

        Now:

        tile[0][0] → bank 0
        tile[1][0] → bank 1  (offset 33, and 33 % 32 = 1)
        tile[2][0] → bank 2  (offset 66, and 66 % 32 = 2)
        ...
        

        Each thread hits a different bank — no conflicts!

      • with tile[32][33]

        • CUDA kernel execution time: 0.008032 ms
        • Bandwidth: 1044.40 GB/s
        • 33% less time than the unpadded tile (0.008032 ms vs 0.012032 ms), i.e. about 1.5x faster

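A host-side C++ model of the padded tiled transpose described above. It sketches the common tiled-transpose pattern rather than reproducing the repo's kernel: the block and thread loops emulate the launch, two separate loop nests stand in for the __syncthreads() barrier between the load and store phases, and maxBankConflict assumes 32 banks of 4-byte words. All names except TILE_DIM are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

constexpr int TILE_DIM = 32;
constexpr int kWarpSize = 32;
constexpr int kNumBanks = 32;  // assumption: 32 banks of 4-byte words

// Worst-case number of lanes hitting the same bank when lane i reads
// tile[i][col] from a float tile whose rows are `rowWords` floats long.
// 1 = conflict-free, 32 = fully serialized.
int maxBankConflict(int rowWords, int col) {
    int hits[kNumBanks] = {0};
    for (int lane = 0; lane < kWarpSize; ++lane)
        ++hits[(lane * rowWords + col) % kNumBanks];
    return *std::max_element(hits, hits + kNumBanks);
}

// Emulated tiled transpose of A (M x N) into B (N x M), both row-major.
void tiledTranspose(const std::vector<float>& A, std::vector<float>& B,
                    int M, int N) {
    assert(static_cast<int>(A.size()) == M * N &&
           static_cast<int>(B.size()) == N * M);
    float tile[TILE_DIM][TILE_DIM + 1] = {};  // +1 padding column
    for (int by = 0; by < (M + TILE_DIM - 1) / TILE_DIM; ++by)
        for (int bx = 0; bx < (N + TILE_DIM - 1) / TILE_DIM; ++bx) {
            // Load phase: threadIdx.x (tx) walks along a row of A (coalesced).
            for (int ty = 0; ty < TILE_DIM; ++ty)
                for (int tx = 0; tx < TILE_DIM; ++tx) {
                    int row = by * TILE_DIM + ty, col = bx * TILE_DIM + tx;
                    if (row < M && col < N) tile[ty][tx] = A[row * N + col];
                }
            // Store phase: block indices swapped; tx walks along a row of B,
            // reading the tile transposed (the column read the padding fixes).
            for (int ty = 0; ty < TILE_DIM; ++ty)
                for (int tx = 0; tx < TILE_DIM; ++tx) {
                    int outRow = bx * TILE_DIM + ty, outCol = by * TILE_DIM + tx;
                    if (outRow < N && outCol < M) B[outRow * M + outCol] = tile[tx][ty];
                }
        }
}
```

maxBankConflict(32, 0) returns 32 (every lane in bank 0) and maxBankConflict(33, 0) returns 1, matching the bank tables above; the bounds checks let the emulation handle matrices whose sides are not multiples of 32.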
  3. Matrix Multiplication (GEMM)
    • Naive GEMM A[M, K] * B[K, N] --> C[M, N]
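The naive GEMM is only named here, so this is a hedged CPU reference under the usual assumptions: row-major storage and one thread per output element C[row][col], each computing a length-K dot product. It is meant as a correctness check for a GPU version, not as the repo's kernel; naiveGemm is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// C[M,N] = A[M,K] * B[K,N], all row-major. Each (row, col) iteration
// corresponds to one GPU thread in a naive one-thread-per-output mapping.
void naiveGemm(const std::vector<float>& A, const std::vector<float>& B,
               std::vector<float>& C, int M, int K, int N) {
    assert(static_cast<int>(A.size()) == M * K);
    assert(static_cast<int>(B.size()) == K * N);
    assert(static_cast<int>(C.size()) == M * N);
    for (int row = 0; row < M; ++row)        // blockIdx.y * blockDim.y + threadIdx.y
        for (int col = 0; col < N; ++col) {  // blockIdx.x * blockDim.x + threadIdx.x
            float acc = 0.0f;
            for (int k = 0; k < K; ++k)
                acc += A[row * K + k] * B[k * N + col];  // row strides: K for A, N for B
            C[row * N + col] = acc;
        }
}
```

For A = [[1, 2, 3], [4, 5, 6]] and B = [[7, 8], [9, 10], [11, 12]], it produces C = [[58, 64], [139, 154]].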
