Skip to content

Commit 86fdc89

Browse files
colesburysoumith
authored andcommitted
Don't throw exceptions inside OpenMP parallel blocks (#4857)
Fixes undefined behavior: exceptions are not allowed to be thrown across OpenMP constructs.
1 parent 548596b commit 86fdc89

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

torch/lib/TH/generic/THTensorMath.c

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,11 @@ static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex)
361361
return dataOffset;
362362
}
363363

364-
static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
364+
static void THTensor_(checkLinearIndex)(int64_t linearIndex, int64_t numel) {
365365
THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel);
366+
}
367+
368+
static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
366369
return linearIndex < 0 ? linearIndex + numel : linearIndex;
367370
}
368371

@@ -376,25 +379,34 @@ void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index)
376379
ptrdiff_t srcElements = THTensor_(nElement)(src);
377380
real* src_data = THTensor_(data)(src);
378381
real* dst_data = THTensor_(data)(dst);
379-
380382
ptrdiff_t nIndices = THLongTensor_nElement(index);
381-
if (THTensor_(isContiguous)(src)) {
382-
ptrdiff_t i;
383-
#pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
384-
for (i = 0; i < nIndices; i++) {
385-
int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
386-
dst_data[i] = src_data[linearIndex];
387-
}
388-
} else {
389-
ptrdiff_t i;
390-
#pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
391-
for (i = 0; i < nIndices; i++) {
392-
int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
393-
int64_t dataOffset = THTensor_(dataOffset)(src, linearIndex);
394-
dst_data[i] = src_data[dataOffset];
383+
int isContiguous = THTensor_(isContiguous)(src);
384+
385+
// Exceptions must not be thrown across OpenMP parallel sections, so we
386+
// record the value of the invalid index and throw the exception after the
387+
// loop.
388+
int64_t invalidIdx = -1;
389+
390+
ptrdiff_t i;
391+
#pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
392+
for (i = 0; i < nIndices; i++) {
393+
int64_t idx = index_data[i];
394+
if (idx < srcElements && idx >= -srcElements) {
395+
idx = THTensor_(wrapLinearIndex)(idx, srcElements);
396+
if (isContiguous) {
397+
dst_data[i] = src_data[idx];
398+
} else {
399+
dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)];
400+
}
401+
} else {
402+
THAtomicCompareAndSwapLong(&invalidIdx, -1, idx);
395403
}
396404
}
397405

406+
if (invalidIdx >= 0) {
407+
THTensor_(checkLinearIndex)(invalidIdx, srcElements);
408+
}
409+
398410
THLongTensor_free(index);
399411
THTensor_(freeCopyTo)(dst, r_);
400412
}
@@ -411,6 +423,7 @@ void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int ac
411423
int is_contiguous = THTensor_(isContiguous)(tensor);
412424

413425
TH_TENSOR_APPLY2(int64_t, index, real, src,
426+
THTensor_(checkLinearIndex)(*index_data, numel);
414427
int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel);
415428
int64_t dataOffset = is_contiguous ? linearIndex : THTensor_(dataOffset)(tensor, linearIndex);
416429
if (accumulate) {

0 commit comments

Comments
 (0)