
Commit 635cced

Addresses review
1 parent 1f553e0 commit 635cced

File tree: 6 files changed, +14 -42 lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 0 additions & 15 deletions
@@ -1292,21 +1292,6 @@ def _test_fp16(self, gradient_as_bucket_view=False):
     def test_fp16(self):
         self._test_fp16()
 
-    @requires_nccl()
-    @requires_nccl_version((2, 17), "Need NCCL 2.17+ for configuring NCCL communicators")
-    @skip_if_lt_x_gpu(2)
-    def test_ddp_default_cga(self):
-        nccl_debug_file = tempfile.NamedTemporaryFile()
-        os.environ["NCCL_DEBUG"] = "INFO"
-        os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
-
-        self._test_fp16()
-
-        # Tests if default CGA for DDP is 2
-        nccl_debug_file_content = nccl_debug_file.read()
-        cga_cluster_size = re.search(rb'CGA cluster.*(\d+)|$', nccl_debug_file_content).group(1)
-        self.assertEqual(int(cga_cluster_size), 2)
-
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_fp16_grad_is_view(self):
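
Note: for reference, a minimal standalone sketch (not part of this commit) of the check the deleted test_ddp_default_cga performed: redirect NCCL's debug log to a file, run a collective, then parse the reported CGA cluster size. The helper name read_cga_cluster_size and the run_collectives callback are illustrative; the environment variables and regex are taken from the removed test and assume an NCCL 2.17+ build that emits a "CGA cluster" line.

import os
import re
import tempfile

def read_cga_cluster_size(run_collectives):
    # Capture NCCL's INFO-level log in a temporary file (must be set before the
    # NCCL communicator is initialized).
    nccl_debug_file = tempfile.NamedTemporaryFile()
    os.environ["NCCL_DEBUG"] = "INFO"
    os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name

    # Any collective that initializes the NCCL communicator, e.g. a DDP forward/backward.
    run_collectives()

    # Parse the CGA cluster size out of the debug log; returns None if the line is absent.
    content = nccl_debug_file.read()
    match = re.search(rb"CGA cluster.*(\d+)|$", content)
    return int(match.group(1)) if match.group(1) is not None else None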

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 10 additions & 13 deletions
@@ -53,9 +53,9 @@
 #endif
 
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17)
-#define ENABLE_NCCL_RANK_CONFIG
+#define NCCL_HAS_COMM_CTA_CGA
 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
-#define ENABLE_NCCL_RANK_CONFIG
+#define NCCL_HAS_COMM_CTA_CGA
 #endif
 
 // Macro to throw on a non-successful NCCL return value.
@@ -185,31 +185,28 @@ class NCCLComm {
       int rank,
       ncclUniqueId commId) {
     auto comm = std::make_shared<NCCLComm>();
-#ifndef NCCL_HAS_COMM_NONBLOCKING
     C10D_NCCL_CHECK(
         ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
-#else
-    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
-    if (nccl_use_nonblocking()) {
-      config.blocking = 0;
-    }
-    C10D_NCCL_CHECK_TIMEOUT(
-        ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
-#endif
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     return comm;
   }
 
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_NONBLOCKING
   static std::shared_ptr<NCCLComm> create(
       int numRanks,
       int rank,
       ncclUniqueId commId,
       ncclConfig_t& config) {
     auto comm = std::make_shared<NCCLComm>();
-    C10D_NCCL_CHECK(
+    if (nccl_use_nonblocking()) {
+      config.blocking = 0;
+      C10D_NCCL_CHECK_TIMEOUT(
+          ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
+    } else {
+      C10D_NCCL_CHECK(
         ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt);
+    }
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     return comm;
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 1 addition & 1 deletion
@@ -1150,7 +1150,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     int deviceIndex = devices[i].index();
 
     gpuGuard.set_index(deviceIndex);
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_NONBLOCKING
     ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
 #else
     ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Schedule NCCL operations on high priority CUDA streams
     bool is_high_priority_stream;
 
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_NONBLOCKING
     // Configure ranks
     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
 #endif

torch/csrc/distributed/c10d/init.cpp

Lines changed: 2 additions & 2 deletions
@@ -1977,7 +1977,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
       .def_property_readonly(
           "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);
 
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_CTA_CGA
   py::class_<ncclConfig_t>(
       processGroupNCCL,
       "NCCLConfig",
@@ -2027,7 +2027,7 @@ Example::
         >>> dist.init_process_group("nccl", pg_options=nccl_options)
       )")
       .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
-#ifdef ENABLE_NCCL_RANK_CONFIG
+#ifdef NCCL_HAS_COMM_CTA_CGA
       .def_readwrite(
           "is_high_priority_stream",
          &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
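
Note: a usage sketch (not part of this commit) of the Options/NCCLConfig bindings gated above by NCCL_HAS_COMM_CTA_CGA, assuming PyTorch is built against NCCL 2.17+ so the config attribute is exposed; the config.cga_cluster_size attribute name is taken from the code removed in torch/nn/parallel/distributed.py below.

import torch.distributed as dist

# Build NCCL options and tune the communicator config before creating the process group.
nccl_options = dist.ProcessGroupNCCL.Options()
nccl_options.is_high_priority_stream = False
nccl_options.config.cga_cluster_size = 2  # type: ignore[attr-defined]

# Pass the options through to the NCCL backend, as in the docstring example above.
dist.init_process_group("nccl", pg_options=nccl_options)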

torch/nn/parallel/distributed.py

Lines changed: 0 additions & 10 deletions
@@ -749,16 +749,6 @@ def __init__(
         else:
             self.process_group = process_group
 
-        if dist.get_backend(
-            self.process_group
-        ) == "nccl" and torch.cuda.nccl.version() >= (2, 17):
-            # Note: NVIDIA recommends using CGA Cluster Size of 2 when using DDP.
-            default_cga = dist.ProcessGroupNCCL.Options().config.cga_cluster_size  # type: ignore[attr-defined]
-            default_pg_nccl = self.process_group._get_backend(torch.device("cuda"))
-            current_cga = default_pg_nccl.options.config.cga_cluster_size
-            if current_cga == default_cga:
-                default_pg_nccl.options.config.cga_cluster_size = 2
-
         self.static_graph = False
         self.dim = dim
         self.module = module
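
Note: with the block above removed, DDP no longer bumps the CGA cluster size to 2 on NCCL 2.17+. Below is a hedged sketch (not part of this commit) of how that behavior could be applied manually on an existing process group; the _get_backend/options/config attribute chain is copied from the deleted code and relies on private APIs.

import torch
import torch.distributed as dist

def set_nccl_cga_cluster_size(process_group, cluster_size=2):
    # Only meaningful for the NCCL backend with NCCL 2.17+ (same guard as the removed code).
    if dist.get_backend(process_group) == "nccl" and torch.cuda.nccl.version() >= (2, 17):
        # NVIDIA recommends a CGA cluster size of 2 when using DDP (per the removed comment).
        pg_nccl = process_group._get_backend(torch.device("cuda"))
        pg_nccl.options.config.cga_cluster_size = cluster_size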

0 commit comments

Comments
 (0)