|
53 | 53 | #endif |
54 | 54 |
|
55 | 55 | #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17) |
56 | | -#define ENABLE_NCCL_RANK_CONFIG |
| 56 | +#define NCCL_HAS_COMM_CTA_CGA |
57 | 57 | #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) |
58 | | -#define ENABLE_NCCL_RANK_CONFIG |
| 58 | +#define NCCL_HAS_COMM_CTA_CGA |
59 | 59 | #endif |
60 | 60 |
|
61 | 61 | // Macro to throw on a non-successful NCCL return value. |
@@ -185,31 +185,28 @@ class NCCLComm { |
185 | 185 | int rank, |
186 | 186 | ncclUniqueId commId) { |
187 | 187 | auto comm = std::make_shared<NCCLComm>(); |
188 | | -#ifndef NCCL_HAS_COMM_NONBLOCKING |
189 | 188 | C10D_NCCL_CHECK( |
190 | 189 | ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt); |
191 | | -#else |
192 | | - ncclConfig_t config = NCCL_CONFIG_INITIALIZER; |
193 | | - if (nccl_use_nonblocking()) { |
194 | | - config.blocking = 0; |
195 | | - } |
196 | | - C10D_NCCL_CHECK_TIMEOUT( |
197 | | - ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt); |
198 | | -#endif |
199 | 190 | comm->ncclId_ = commId; |
200 | 191 | comm->rank_ = rank; |
201 | 192 | return comm; |
202 | 193 | } |
203 | 194 |
|
204 | | -#ifdef ENABLE_NCCL_RANK_CONFIG |
| 195 | +#ifdef NCCL_HAS_COMM_NONBLOCKING |
205 | 196 | static std::shared_ptr<NCCLComm> create( |
206 | 197 | int numRanks, |
207 | 198 | int rank, |
208 | 199 | ncclUniqueId commId, |
209 | 200 | ncclConfig_t& config) { |
210 | 201 | auto comm = std::make_shared<NCCLComm>(); |
211 | | - C10D_NCCL_CHECK( |
| 202 | + if (nccl_use_nonblocking()) { |
| 203 | + config.blocking = 0; |
| 204 | + C10D_NCCL_CHECK_TIMEOUT( |
| 205 | + ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt); |
| 206 | + } else { |
| 207 | + C10D_NCCL_CHECK( |
212 | 208 | ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt); |
| 209 | + } |
213 | 210 | comm->ncclId_ = commId; |
214 | 211 | comm->rank_ = rank; |
215 | 212 | return comm; |
|
0 commit comments