Skip to content

Commit bf3655a

Browse files
authored
make torch.set_num_threads also set MKL threads (take 2) (#5002)
* torch.set_num_threads sets the MKL option too
* fix to use the C prototype instead of the Fortran one
1 parent 86fd5fd commit bf3655a

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

aten/src/TH/THGeneral.c

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121
#include <malloc/malloc.h>
2222
#endif
2323

24+
#ifdef TH_BLAS_MKL
25+
// MKL_Set_Num_Threads is the C prototype, while mkl_set_num_threads is the Fortran prototype
26+
extern void MKL_Set_Num_Threads(int);
27+
// MKL_Get_Max_Threads is the C prototype, while mkl_get_max_threads is the Fortran prototype
28+
extern int MKL_Get_Max_Threads(void);
29+
#endif
30+
2431
/* Torch Error Handling */
2532
static void defaultErrorHandlerFunction(const char *msg, void *data)
2633
{
@@ -302,6 +309,10 @@ void THSetNumThreads(int num_threads)
302309
#ifdef _OPENMP
303310
omp_set_num_threads(num_threads);
304311
#endif
312+
#ifdef TH_BLAS_MKL
313+
MKL_Set_Num_Threads(num_threads);
314+
#endif
315+
305316
}
306317

307318
int THGetNumThreads(void)
@@ -322,18 +333,14 @@ int THGetNumCores(void)
322333
#endif
323334
}
324335

325-
#ifdef TH_BLAS_MKL
326-
extern int mkl_get_max_threads(void);
327-
#endif
328-
329336
TH_API void THInferNumThreads(void)
330337
{
331338
#if defined(_OPENMP) && defined(TH_BLAS_MKL)
332339
// If we are using MKL and OpenMP, make sure the number of threads matches.
333340
// Otherwise, MKL and our OpenMP-enabled functions will keep changing the
334341
// size of the OpenMP thread pool, resulting in worse performance (and memory
335342
// leaks in GCC 5.4)
336-
omp_set_num_threads(mkl_get_max_threads());
343+
omp_set_num_threads(MKL_Get_Max_Threads());
337344
#endif
338345
}
339346

0 commit comments

Comments
 (0)