Skip to content

Commit 2aaeec0

Browse files
authored
torch.set_num_threads sets MKL option too (#4949)
1 parent 3ac412e commit 2aaeec0

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

aten/src/TH/THGeneral.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
#include <malloc/malloc.h>
2222
#endif
2323

24+
#ifdef TH_BLAS_MKL
25+
extern void mkl_set_num_threads(int);
26+
extern int mkl_get_max_threads(void);
27+
#endif
28+
2429
/* Torch Error Handling */
2530
static void defaultErrorHandlerFunction(const char *msg, void *data)
2631
{
@@ -302,6 +307,10 @@ void THSetNumThreads(int num_threads)
302307
#ifdef _OPENMP
303308
omp_set_num_threads(num_threads);
304309
#endif
310+
#ifdef TH_BLAS_MKL
311+
mkl_set_num_threads(num_threads);
312+
#endif
313+
305314
}
306315

307316
int THGetNumThreads(void)
@@ -322,10 +331,6 @@ int THGetNumCores(void)
322331
#endif
323332
}
324333

325-
#ifdef TH_BLAS_MKL
326-
extern int mkl_get_max_threads(void);
327-
#endif
328-
329334
TH_API void THInferNumThreads(void)
330335
{
331336
#if defined(_OPENMP) && defined(TH_BLAS_MKL)

0 commit comments

Comments
 (0)