Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions .github/workflows/build_kernel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@ on:

jobs:
build:
name: Build kernels
name: Build kernels (${{ matrix.arch }})
strategy:
matrix:
include:
- arch: x86_64-linux
runner: aws-highmemory-32-plus-nix
- arch: aarch64-linux
runner: aws-r8g-8xl-plus-nix
runs-on:
group: aws-highmemory-32-plus-nix
group: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: DeterminateSystems/nix-installer-action@main
Expand All @@ -30,27 +37,27 @@ jobs:
run: nix-shell -p nix-info --run "nix-info -m"

- name: Build relu kernel
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
- name: Copy relu kernel
run: cp -rL builder/examples/relu/result relu-kernel

- name: Build extra-data kernel
run: ( cd builder/examples/extra-data && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
run: ( cd builder/examples/extra-data && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
- name: Copy extra-data kernel
run: cp -rL builder/examples/extra-data/result extra-data

- name: Build relu kernel (CPU)
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cpu-x86_64-linux )
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-cxx11-cpu-${{ matrix.arch }} )
- name: Copy relu kernel (CPU)
run: cp -rL builder/examples/relu/result relu-kernel-cpu

- name: Build cutlass GEMM kernel
run: ( cd builder/examples/cutlass-gemm && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
run: ( cd builder/examples/cutlass-gemm && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
- name: Copy cutlass GEMM kernel
run: cp -rL builder/examples/cutlass-gemm/result cutlass-gemm-kernel

- name: Build relu-backprop-compile kernel
run: ( cd builder/examples/relu-backprop-compile && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
run: ( cd builder/examples/relu-backprop-compile && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )
- name: Copy relu-backprop-compile kernel
run: cp -rL builder/examples/relu-backprop-compile/result relu-backprop-compile-kernel

Expand All @@ -59,10 +66,10 @@ jobs:
run: ( cd builder/examples/relu-specific-torch && nix build . )

- name: Build relu kernel (compiler flags)
run: ( cd builder/examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
run: ( cd builder/examples/relu-compiler-flags && nix build .\#redistributable.torch29-cxx11-cu126-${{ matrix.arch }} )

- name: Test that we can build a test shell (e.g. that gcc corresponds to CUDA-required)
run: ( cd builder/examples/relu && nix build .#devShells.x86_64-linux.test )
run: ( cd builder/examples/relu && nix build .#devShells.${{ matrix.arch }}.test )

- name: Build silu-and-mul kernel
run: ( cd builder/examples/silu-and-mul && nix build .\#redistributable.torch-cuda )
Expand All @@ -72,7 +79,7 @@ jobs:
- name: Upload kernel artifacts
uses: actions/upload-artifact@v6
with:
name: built-kernels
name: built-kernels-${{ matrix.arch }}
path: |
activation-kernel
cutlass-gemm-kernel
Expand All @@ -93,7 +100,7 @@ jobs:
- name: Download kernel artifacts
uses: actions/download-artifact@v7
with:
name: built-kernels
name: built-kernels-x86_64-linux
path: .

- name: Set up Docker Buildx
Expand Down
15 changes: 12 additions & 3 deletions nix/pkgs/python-modules/cuda-bindings/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ let
let
cuda_12 = {
version = "12.9.4";
hash = "sha256-Mr3Fp2kGvkxh65j1RqZ4bFdzqIHzsWZIZEm10UHko58=";
hash = {
x86_64-linux = "sha256-Mr3Fp2kGvkxh65j1RqZ4bFdzqIHzsWZIZEm10UHko58=";
aarch64-linux = "sha256-z4v67cI487EV2VfR/WVit+hDW6V/bQ4vh9DnFJzLLaU=";
};
};
in
{
Expand All @@ -27,14 +30,20 @@ let
"12.9" = cuda_12;
"13.0" = {
version = "13.0.3";
hash = "sha256-US0NgDpeR6ikLVo0zgkygCv3L+lS/bEax5hxWjXG5cs=";
hash = {
x86_64-linux = "sha256-US0NgDpeR6ikLVo0zgkygCv3L+lS/bEax5hxWjXG5cs=";
aarch64-linux = "sha256-+xan92nJxnRprdeh2fbBTdRGN/aSHLa564LLUBWzXD0=";
};
};
};

versionHash =
versionHashes.${cudaPackages.cudaMajorMinorVersion}
or (throw "Unsupported CUDA version: ${cudaPackages.cudaMajorMinorVersion}");
inherit (versionHash) hash version;
inherit (versionHash) version;
hash =
versionHash.hash.${stdenv.hostPlatform.system}
or (throw "No hash defined for system: ${stdenv.hostPlatform.system}");

format = "wheel";
pyShortVersion = "cp" + builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;
Expand Down
5 changes: 5 additions & 0 deletions nix/pkgs/python-modules/torch/binary/torch-versions-hash.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@
"hash": "sha256-vbzHAzgvlI6VHAY0SMlAa/OM5mxB3WmNnicz/PlsA3o=",
"version": "2.10.0"
},
"cu130": {
"url": "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-manylinux_2_28_aarch64.whl",
"hash": "sha256-dXgCgzCN+f7eNx7toB6WB8iGKhgDovLzGgiiwN6u00I=",
"version": "2.10.0"
},
"cpu": {
"url": "https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl",
"hash": "sha256-5RmUSSzbdu3OKdqI3jZyowIvnvD/2QNFQ2lI1Jkr4sc=",
Expand Down
2 changes: 1 addition & 1 deletion nix/pkgs/python-modules/torch/binary/torch-versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
{
"torchVersion": "2.10.0",
"cudaVersion": "13.0",
"systems": ["x86_64-linux"]
"systems": ["x86_64-linux", "aarch64-linux"]
},
{
"torchVersion": "2.10.0",
Expand Down
Loading