Skip to content

fix(metal): support argsort for arrays >1024 elements#3308

Draft
DumbTechLion wants to merge 2 commits intohuggingface:mainfrom
DumbTechLion:fix/metal-argsort-large-arrays
Draft

fix(metal): support argsort for arrays >1024 elements#3308
DumbTechLion wants to merge 2 commits intohuggingface:mainfrom
DumbTechLion:fix/metal-argsort-large-arrays

Conversation

@DumbTechLion
Copy link

@DumbTechLion DumbTechLion commented Jan 16, 2026

Here is an attempt for #2570. Our ML research project requires it to work. Don't hesitate to comment or contribute in any way.

Summary

  • Route large arrays (>1024 elements) to MLX multi-block merge sort instead of bitonic sort, which was limited by Metal's 1024 thread
    per threadgroup limit
  • Add GreaterThan comparator for descending sort support in MLX kernels
  • Add CUDA shared memory validation with clear error message when exceeding ~12K element limit
  • Enable asort_big test on Metal, add vocabulary-size tests (2048, 4096, 32000 elements)

Fixes #2570

Problem

The Metal argsort kernel used bitonic sort with ncols_pad as threadgroup size. Since Metal limits threadgroups to 1024 threads,
arrays with >1024 elements failed silently.

Solution

For arrays >1024 elements, use the existing call_mlx_arg_sort function which implements multi-block merge sort from MLX. This
handles arbitrary array sizes efficiently.

Supported types for large arrays: BF16, F16, F32, U8, U32, I64

Notes:

  • Small arrays (≤1024) still use bitonic sort for performance
  • F64, I16, I32, F8E4M3 not yet supported for large arrays (will error with clear message)

Test plan

  • asort_cpu - existing test passes
  • asort_metal - existing test passes
  • asort_big_cpu - tests 2000 elements
  • asort_big_metal - now passes (was skipped)
  • asort_vocabulary_cpu - tests 2048, 4096, 32000 elements
  • asort_vocabulary_metal - tests vocabulary sizes used in LLMs
cargo test -p candle-core --test tensor_tests --features metal -- asort

- Route large arrays (>1024) to MLX multi-block merge sort
- Add GreaterThan comparator for descending sort support
- Add CUDA shared memory validation with clear error message
- Enable asort_big test on Metal, add vocabulary-size test

Fixes huggingface#2570

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@DumbTechLion DumbTechLion marked this pull request as draft January 16, 2026 15:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[TRACKING] argsort metal kernel yields incorrect output with > 1024 elements

1 participant