
Commit ebb2dcd

Update on "[ContextParallel] add process-time based Round-Robin load-balance to CP"
**Summary**

The load-balancing problem can be modeled as the [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for implementing load balancing; in this PR we start by adding a Round-Robin solution as an example, along with a verification. This can easily be adapted to other solutions such as Shortest-processing-time-first / Longest-processing-time-first, with extra padding added for collectives.

- Added a new `_LoadBalancer` implementation, `_PTRRLoadBalancer`, designed for `flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and performs Round-Robin assignment (unlike traditional Round-Robin, which goes in circular order, we go in zig-zag order; a minimal illustration follows below).
- Made `_context_parallel_buffers` and `context_parallel_unshard` handle a batched load-balance index (previously they could only handle a non-batched load-balance index), as in `create_cp_block_mask`.

**Test**

`pytest test/distributed/tensor/test_attention.py`

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci

[ghstack-poisoned]
2 parents d63a34c + 8e3f28a commit ebb2dcd
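
To make the zig-zag ordering in the summary concrete, here is a minimal, self-contained sketch (hypothetical helper names; this is not the `_PTRRLoadBalancer` code added by the PR). Chunks are walked across ranks left-to-right, then right-to-left, so the cheap and expensive ends of a causal `BlockMask` tend to land on the same rank:

```python
# Illustrative only: zig-zag Round-Robin assignment of sequence chunks to
# context-parallel ranks. `chunk_costs` stands in for a per-chunk processing
# time estimate (e.g. the number of non-masked blocks in the BlockMask).
from typing import List


def zigzag_round_robin(chunk_costs: List[int], world_size: int) -> List[List[int]]:
    """Return, for each rank, the list of chunk indices assigned to it."""
    assignment: List[List[int]] = [[] for _ in range(world_size)]
    for i, _cost in enumerate(chunk_costs):
        round_idx, pos = divmod(i, world_size)
        # Even rounds go 0, 1, ..., world_size-1; odd rounds reverse direction.
        rank = pos if round_idx % 2 == 0 else world_size - 1 - pos
        assignment[rank].append(i)
    return assignment


if __name__ == "__main__":
    # With a causal mask, later chunks attend to more keys, so costs grow.
    costs = [1, 2, 3, 4, 5, 6, 7, 8]
    print(zigzag_round_robin(costs, world_size=4))
    # [[0, 7], [1, 6], [2, 5], [3, 4]] -- each rank's total cost is 9.
```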

File tree

269 files changed (+5589 / -1987 lines)


.ci/manywheel/build_cuda.sh

Lines changed: 5 additions & 2 deletions

@@ -187,19 +187,22 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
         export USE_CUFILE=0
     else
         DEPS_LIST+=(
-            "/usr/local/cuda/lib64/libnvToolsExt.so.1"
             "/usr/local/cuda/lib64/libcublas.so.12"
             "/usr/local/cuda/lib64/libcublasLt.so.12"
             "/usr/local/cuda/lib64/libcudart.so.12"
             "/usr/local/cuda/lib64/libnvrtc.so.12"
             "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
         DEPS_SONAME+=(
-            "libnvToolsExt.so.1"
             "libcublas.so.12"
             "libcublasLt.so.12"
             "libcudart.so.12"
             "libnvrtc.so.12"
             "libcupti.so.12")
+
+        if [[ $CUDA_VERSION != 12.9* ]]; then
+            DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
+            DEPS_SONAME+=("libnvToolsExt.so.1")
+        fi
     fi
 else
     echo "Using nvidia libs from pypi."

.github/ci_commit_pins/audio.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-8ad2aa5d354d1bf432339113860185d5a5d1abbd
+1b013f5b5a87a1882eb143c26d79d091150d6a37

.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-f5c6c2ec6490455e86f67b2a25c10390d60a27f7
+faffd5cf673615583da6517275e361cb3dbc77e6

aten/src/ATen/CMakeLists.txt

Lines changed: 52 additions & 51 deletions

@@ -256,6 +256,7 @@ endif()
 IF(USE_FBGEMM_GENAI)
   set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
   set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
+
   if(USE_CUDA)
     # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
     # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@@ -292,58 +293,64 @@ IF(USE_FBGEMM_GENAI)
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
     )

-    target_include_directories(fbgemm_genai PUBLIC
+    target_include_directories(fbgemm_genai PRIVATE
       ${FBGEMM_THIRD_PARTY}/cutlass/include
       ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
       ${fbgemm_genai_mx8mx8bf16_grouped}
       ${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
       ${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
     )
-  else()
-    if(USE_ROCM)
-      # Only include the kernels we want to build to avoid increasing binary size.
-      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
-        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
-        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
-      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
-
-      # Add additional HIPCC compiler flags for performance
-      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-        -mllvm
-        -amdgpu-coerce-illegal-types=1
-        -mllvm
-        -enable-post-misched=0
-        -mllvm
-        -greedy-reverse-local-assignment=1
-        -fhip-new-launch-api)
-
-      # Only compile for gfx942 for now.
-      # This is rather hacky, I could not figure out a clean solution :(
-      set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
-      string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
-      if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
-        list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
-      endif()
-      set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
-
-      hip_add_library(
-        fbgemm_genai STATIC
-        ${fbgemm_genai_native_rocm_hip}
-        HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
-      set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
-      set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
-      target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
-
-      target_include_directories(fbgemm_genai PUBLIC
-        # FBGEMM version of Composable Kernel is used due to some customizations
-        ${FBGEMM_THIRD_PARTY}/composable_kernel/include
-        ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
-        ${FBGEMM_THIRD_PARTY}/cutlass/include
-        ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
-        ${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
-        ${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
-      )
+
+    # Add FBGEMM_GENAI include directories for torch_ops.h
+    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
+    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
+  elseif(USE_ROCM)
+    # Only include the kernels we want to build to avoid increasing binary size.
+    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
+      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
+      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
+    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
+
+    # Add additional HIPCC compiler flags for performance
+    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
+      -mllvm
+      -amdgpu-coerce-illegal-types=1
+      -mllvm
+      -enable-post-misched=0
+      -mllvm
+      -greedy-reverse-local-assignment=1
+      -fhip-new-launch-api)
+
+    # Only compile for gfx942 for now.
+    # This is rather hacky, I could not figure out a clean solution :(
+    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
+    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
+    if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
+      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
     endif()
+    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
+
+    hip_add_library(
+      fbgemm_genai STATIC
+      ${fbgemm_genai_native_rocm_hip}
+      HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
+    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
+    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
+
+    target_include_directories(fbgemm_genai PRIVATE
+      # FBGEMM version of Composable Kernel is used due to some customizations
+      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
+      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
+      ${FBGEMM_THIRD_PARTY}/cutlass/include
+      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
+      ${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
+      ${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
+    )
+
+    # Add FBGEMM_GENAI include directories for torch_ops.h
+    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
+    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
   endif()
 endif()

@@ -692,12 +699,6 @@ if(USE_CUDA AND NOT USE_ROCM)
   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)

-  # Add FBGEMM_GENAI include directories for torch_ops.h
-  if(USE_FBGEMM_GENAI)
-    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
-    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
-  endif()
-
   if($ENV{ATEN_STATIC_CUDA})
     if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
       list(APPEND ATen_CUDA_DEPENDENCY_LIBS

aten/src/ATen/cuda/tunable/GemmCommon.h

Lines changed: 27 additions & 25 deletions

@@ -13,6 +13,7 @@
 #include <c10/core/ScalarType.h>

 #include <ATen/cuda/tunable/TunableOp.h>
+#include <ATen/cuda/tunable/Tunable.h>
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
 #include <c10/util/StringUtil.h>
@@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
     BLASType = "unknown";
   }
   return BLASType;
+
 }

 // Similar to Compute Type in GemmRocblas.h
@@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio

 namespace detail {

-static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
+static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
+
+  if (!config.enabled) {
+    return true; // skip when disabled
+  }
+
   auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
-  // comparison done as 1D tensor
   at::Tensor ref = at::from_blob(c, {size}, options);
   at::Tensor oth = at::from_blob(other_c, {size}, options);
   at::Tensor ref_float = ref.to(at::kFloat);
   at::Tensor oth_float = oth.to(at::kFloat);
-  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
-  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
-  double last_succeed_atol = 1;
-  double last_succeed_rtol = 1;
-  for (auto& atol : atols) {
-    for (auto& rtol : rtols) {
-      if (at::allclose(ref_float, oth_float, rtol, atol)) {
-        last_succeed_atol = atol;
-        last_succeed_rtol = rtol;
-      }
-    }
-  }
-  if (last_succeed_atol == 1) {
-    return false;
-  }
-  else {
-    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
-  }

-  return true;
+  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
+  if (ok) {
+    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
+  } else {
+    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
+  }
+  return ok;
 }

 }
@@ -355,8 +349,10 @@ struct GemmParams : OpParams {
   }

   TuningStatus NumericalCheck(GemmParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<T>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
@@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
   }

   TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<T>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
@@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
   }

   TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
@@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
   }

   TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
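
The net effect of the GemmCommon.h change: instead of sweeping a grid of atol/rtol values and reporting the tightest pair that passed, NumericalCheck now performs a single allclose comparison at the configured tolerances, and is skipped entirely when checking is disabled. A rough Python analogue, for illustration only (the actual implementation is the C++ shown in the diff above):

```python
# Rough Python analogue of the new detail::NumericalCheck, assuming a simple
# (enabled, atol, rtol) configuration; not the code in this commit.
import torch


def numerical_check(ref: torch.Tensor, other: torch.Tensor,
                    enabled: bool = True,
                    atol: float = 1e-5, rtol: float = 1e-5) -> bool:
    if not enabled:
        return True  # skip when disabled, mirroring config.enabled
    # Comparison is done in float32 over the flattened output buffers.
    ref_f = ref.flatten().to(torch.float32)
    oth_f = other.flatten().to(torch.float32)
    return torch.allclose(ref_f, oth_f, rtol=rtol, atol=atol)
```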

aten/src/ATen/cuda/tunable/README.md

Lines changed: 2 additions & 3 deletions

@@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
 | PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
 | PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
 | PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
-| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
+| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
 | PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
 | PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
 | PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@@ -173,10 +173,9 @@ All python APIs exist in the `torch.cuda.tunable` module.
 | get_max_tuning_iterations() -> int | |
 | set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
 | get_filename() -> str | |
+| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5.
 | get_results() -> Tuple[str, str, str, float] | |
 | get_validators() -> Tuple[str, str] | |
-| write_file_on_exit(val: bool) -> None | Default is True. |
-| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
 | mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
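
Putting the two README rows together, numerical checking is now configured with an explicit tolerance pair rather than a 0/1 switch. A minimal usage sketch, assuming only the signatures documented above (the exact defaults and call sites are not part of this diff):

```python
# Sketch: enabling TunableOp numerical checking with explicit tolerances.
import os

# Option 1: environment variable, encoded as "<atol>_<rtol>".
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-5_1e-5"

import torch

# Option 2: Python API added in the table above (enable, atol, rtol).
torch.cuda.tunable.set_numerical_check_tolerances(True, 1e-4, 1e-4)
```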
