Enables configuration of NCCL communicators #97394
Changes from all commits: cd986f0, a86d4c7, 5536c35, c45e6a2, 858f81d, 23959f1, e6fa35b
```
@@ -2134,6 +2134,23 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`)
      .def_property_readonly(
          "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);

#ifdef NCCL_HAS_COMM_CTA_CGA
  py::class_<ncclConfig_t>(
      processGroupNCCL,
      "NCCLConfig",
      R"(
ncclConfig_t data type for configuring NCCL communicators.
See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
for details.
)")
      .def(py::init<>())
      .def_readwrite("blocking", &ncclConfig_t::blocking)
      .def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize)
      .def_readwrite("min_ctas", &ncclConfig_t::minCTAs)
      .def_readwrite("max_ctas", &ncclConfig_t::maxCTAs)
      .def_readwrite("net_name", &ncclConfig_t::netName);
```
Review comment (on `net_name`): There is a remaining problem in the pybind11 interface here. The type of `netName` is `const char *`, not a string; if we simply assign `nccl_options.config.net_name = "Socket"`, there will be a `UnicodeDecodeError`, because a string is being stored into a `const char *` without additional handling.
```
#endif

      intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
          processGroupNCCL,
          "Options",
```
```
@@ -2147,19 +2164,39 @@ ProcessGroup options for the NCCL backend
        to prioritize NCCL kernels when there are compute kernels waiting.
        Default is False.

Attributes:
    config (NCCLConfig): configures NCCL communicators (only available for
        builds using NCCL 2.17+). This can be used to improve
        communication-computation overlap for NCCL kernels by tuning
        available parameters in the config. See
        https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
        for details.

Example::
    >>> import torch.distributed as dist
    >>>
    >>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
    >>> # For builds using NCCL 2.17+, configure communicators
    >>> nccl_options.config.cga_cluster_size = 2
    >>> nccl_options.config.max_ctas = 4
    >>> nccl_options.config.min_ctas = 2
    >>> # initialize a nccl process group with the options just created
    >>> dist.init_process_group("nccl", pg_options=nccl_options)
)")
      .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
#ifdef NCCL_HAS_COMM_CTA_CGA
      .def_readwrite(
          "is_high_priority_stream",
          &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
      .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
#else
      .def_readwrite(
          "is_high_priority_stream",
          &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
#endif

#endif

#ifdef USE_C10D_MPI
  auto processGroupMPI =
      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(
```
Review comment: Tests involving file I/O can sometimes be flaky. I wonder if merely testing whether ProcessGroupNCCL is created successfully would suffice.
Reply: I did see usage of `tempfile.NamedTemporaryFile()` in `test_c10d_nccl.py`, and that's why I used it here. Is there a different way to do I/O that won't be flaky? I personally wasn't satisfied with testing only whether ProcessGroupNCCL is created: creation of ProcessGroupNCCL doesn't necessarily mean it was created with the config values a user may specify in the context of this PR, and AFAIK the way to check whether NCCL got those values is through the NCCL_DEBUG file. Maybe let's wait for the 2.17.1 update to be merged in; I'll rebase this PR, and we'll see whether this test is flaky.
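The verification approach described above — routing NCCL's debug log to a temporary file and scanning it for the configured values — might look roughly like the sketch below. `NCCL_DEBUG` and `NCCL_DEBUG_FILE` are real NCCL environment variables, but the log-line format checked here is an assumption for illustration, and the actual process-group initialization is elided (it requires GPUs and an NCCL build), so the sketch writes a simulated log line instead:

```python
import os
import tempfile

# Route NCCL's debug output to a temporary file so the test can inspect
# which config values the communicator actually received.
debug_file = tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False)
debug_file.close()
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_FILE"] = debug_file.name

# ... here the test would create the ProcessGroupNCCL with the configured
#     Options (elided: requires GPUs and an NCCL build) ...

# Simulate the kind of line NCCL might emit for a configured communicator;
# the exact format is an assumption, not NCCL's documented output.
with open(debug_file.name, "w") as f:
    f.write("NCCL INFO comm init: cgaClusterSize 2 minCTAs 2 maxCTAs 4\n")

# The test then scans the debug log for the values it set on the config.
with open(debug_file.name) as f:
    log = f.read()
assert "cgaClusterSize 2" in log
assert "maxCTAs 4" in log
os.unlink(debug_file.name)
```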
Reply: Okay, sounds like a plan.
Reply: Latest pipelines test with NCCL 2.17.1. The test added in this PR is not flaky.