39 changes: 33 additions & 6 deletions test/distributed/test_c10d_nccl.py
@@ -4,6 +4,7 @@
import math
import os
import random
import re
import signal
import sys
import tempfile
@@ -2713,12 +2714,7 @@ def test_sequence_num_set_nccl_new_group(self):
torch.cuda.set_device(self.rank)
self._test_sequence_num_set_new_group(backend="nccl")

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_pass_nccl_options_high_priority_stream(self):
pg_opts = c10d.ProcessGroupNCCL.Options()
pg_opts.is_high_priority_stream = True

def _test_pass_nccl_options(self, pg_opts):
store = c10d.FileStore(self.file_name, self.world_size)
# Test init_process_group accepts options
dist.init_process_group(
@@ -2737,6 +2733,37 @@ def test_pass_nccl_options_high_priority_stream(self):
expected_tensor = torch.tensor([3] * 10).cuda(self.rank)
self.assertEqual(expected_tensor, t)

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_pass_nccl_options_high_priority_stream(self):
pg_opts = c10d.ProcessGroupNCCL.Options()
pg_opts.is_high_priority_stream = True
self._test_pass_nccl_options(pg_opts)

@requires_nccl()
@requires_nccl_version((2, 17), "Need NCCL 2.17+ for configuring NCCL communicators")
@skip_if_lt_x_gpu(2)
def test_pass_nccl_options_config(self):
pg_opts = c10d.ProcessGroupNCCL.Options()
pg_opts.config.max_ctas = 4
pg_opts.config.min_ctas = 2
pg_opts.config.cga_cluster_size = 2
nccl_debug_file = tempfile.NamedTemporaryFile()
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name

# Tests functionality when passing nccl config
self._test_pass_nccl_options(pg_opts)

# Tests if comms were configured
nccl_debug_file_content = nccl_debug_file.read()
Collaborator:
Tests involving file I/O can sometimes be flaky.
I wonder if merely testing whether ProcessGroupNCCL is created successfully would suffice.

Collaborator (Author):
I did see usage of tempfile.NamedTemporaryFile() in test_c10d_nccl.py, and that's why I used it here. Is there a different way to do I/O that won't be flaky? I personally wasn't satisfied with testing only whether ProcessGroupNCCL is created: creation of ProcessGroupNCCL doesn't necessarily mean it was created with the config values a user may specify in the context of this PR, and AFAIK the way to find out whether NCCL got those values is through the NCCL_DEBUG file. Maybe let's wait for the 2.17.1 update to be merged in; I'll rebase this PR and see whether this test is flaky.

Collaborator:
Okay, sounds like a plan.

Collaborator (Author):
Latest pipelines test with NCCL 2.17.1. The test added in this PR is not flaky.

max_ctas = re.search(rb'Max CTAs.*(\d+)|$', nccl_debug_file_content).group(1)
min_ctas = re.search(rb'Min CTAs.*(\d+)|$', nccl_debug_file_content).group(1)
cga_cluster_size = re.search(rb'CGA cluster.*(\d+)|$', nccl_debug_file_content).group(1)
self.assertEqual(pg_opts.config.max_ctas, int(max_ctas))
self.assertEqual(pg_opts.config.min_ctas, int(min_ctas))
self.assertEqual(pg_opts.config.cga_cluster_size, int(cga_cluster_size))

@requires_nccl()
@skip_if_lt_x_gpu(4)
def test_nccl_barrier(self):
36 changes: 27 additions & 9 deletions torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -52,6 +52,12 @@
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif

#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17)
#define NCCL_HAS_COMM_CTA_CGA
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define NCCL_HAS_COMM_CTA_CGA
#endif

// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
do { \
@@ -179,22 +185,34 @@ class NCCLComm {
int rank,
ncclUniqueId commId) {
auto comm = std::make_shared<NCCLComm>();
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
#else
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
if (nccl_use_nonblocking()) {
config.blocking = 0;
}
C10D_NCCL_CHECK_TIMEOUT(
ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
#endif
comm->ncclId_ = commId;
comm->rank_ = rank;
return comm;
}

#ifdef NCCL_HAS_COMM_NONBLOCKING
static std::shared_ptr<NCCLComm> create(
int numRanks,
int rank,
ncclUniqueId commId,
ncclConfig_t& config) {
auto comm = std::make_shared<NCCLComm>();
if (nccl_use_nonblocking()) {
config.blocking = 0;
C10D_NCCL_CHECK_TIMEOUT(
ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
} else {
C10D_NCCL_CHECK(
ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt);
}
comm->ncclId_ = commId;
comm->rank_ = rank;
return comm;
}
#endif

ncclUniqueId getNcclId() {
return ncclId_;
}
4 changes: 4 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1156,7 +1156,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
int deviceIndex = devices[i].index();

gpuGuard.set_index(deviceIndex);
#ifdef NCCL_HAS_COMM_NONBLOCKING
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
#else
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
#endif

// Creates the NCCL streams
streamVal.push_back(
5 changes: 5 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -279,6 +279,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {

// Schedule NCCL operations on high priority CUDA streams
bool is_high_priority_stream;

#ifdef NCCL_HAS_COMM_NONBLOCKING
// Communicator configuration passed to ncclCommInitRankConfig
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
#endif
};

// If you wish to create multiple process groups, each with a potentially
37 changes: 37 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -2134,6 +2134,23 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def_property_readonly(
"is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);

#ifdef NCCL_HAS_COMM_CTA_CGA
py::class_<ncclConfig_t>(
processGroupNCCL,
"NCCLConfig",
R"(
ncclConfig_t data type for configuring NCCL communicators.
See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
for details.
)")
.def(py::init<>())
.def_readwrite("blocking", &ncclConfig_t::blocking)
.def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize)
.def_readwrite("min_ctas", &ncclConfig_t::minCTAs)
.def_readwrite("max_ctas", &ncclConfig_t::maxCTAs)
.def_readwrite("net_name", &ncclConfig_t::netName);
@ehuaa (Jun 28, 2023):
There's a problem remaining in the pybind11 interface here. The type of netName is const char *, not string; if we simply assign net_name with nccl_options.config.net_name = "Socket", there will be a UnicodeDecodeError, because a string is stored into a const char * without an additional workaround.
As in pybind/pybind11#2337, to assign net_name correctly here we should copy the string value into the const char * with a new function; I want to open a pull request to fix this. @syed-ahmed @kwen2501
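
A minimal sketch of the workaround described in this comment (hypothetical, not part of this PR): bind net_name through def_property instead of def_readwrite, so the setter copies the Python string into storage that outlives the call. The strdup call and the None-on-unset getter are assumptions of this sketch, not the actual fix.

// Hypothetical replacement for the .def_readwrite("net_name", ...) binding above.
.def_property(
    "net_name",
    [](const ncclConfig_t& self) -> py::object {
      // Return None when the field is unset rather than decoding a null pointer.
      return self.netName ? py::cast(std::string(self.netName))
                          : py::object(py::none());
    },
    [](ncclConfig_t& self, const std::string& name) {
      // Copy the Python-owned string into a buffer with process lifetime;
      // the small leak is tolerable for a value set once before communicator creation.
      self.netName = strdup(name.c_str());
    })

An alternative would be to keep a std::string alongside the config and point netName at it, which avoids the leak but requires keeping the two in sync.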

#endif

intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
processGroupNCCL,
"Options",
@@ -2147,19 +2164,39 @@ ProcessGroup options for the NCCL backend
to prioritize NCCL kernels when there are compute kernels waiting.
Default is False.

Attributes:
config (NCCLConfig): configures NCCL communicators (only available for
builds using NCCL 2.17+). This can be used to improve
communication-computation overlap for NCCL kernels by tuning
available parameters in the config. See
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
for details.

Example::
>>> import torch.distributed as dist
>>>
>>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
>>> # For builds using NCCL 2.17+, configure communicators
>>> nccl_options.config.cga_cluster_size = 2
>>> nccl_options.config.max_ctas = 4
>>> nccl_options.config.min_ctas = 2
>>> # initialize a nccl process group with the options just created
>>> dist.init_process_group("nccl", pg_options=nccl_options)
)")
.def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
#ifdef NCCL_HAS_COMM_CTA_CGA
.def_readwrite(
"is_high_priority_stream",
&::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
.def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
#else
.def_readwrite(
"is_high_priority_stream",
&::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
#endif

#endif

#ifdef USE_C10D_MPI
auto processGroupMPI =
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(