@@ -361,8 +361,11 @@ static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex)
   return dataOffset;
 }
 
-static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
+static void THTensor_(checkLinearIndex)(int64_t linearIndex, int64_t numel) {
   THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel);
+}
+
+static int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) {
   return linearIndex < 0 ? linearIndex + numel : linearIndex;
 }
 
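This first hunk splits the old wrapLinearIndex into two helpers: checkLinearIndex only validates, wrapLinearIndex only normalizes, so callers can run the two steps at different points. A minimal standalone sketch of the same semantics (plain C; check_index and wrap_index are illustrative names, not part of the patch):

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Valid linear indices for a buffer of `numel` elements are
   [-numel, numel); negative values count back from the end. */
static void check_index(int64_t idx, int64_t numel) {
  assert(idx < numel && idx >= -numel && "index out of range");
}

static int64_t wrap_index(int64_t idx, int64_t numel) {
  return idx < 0 ? idx + numel : idx;
}

int main(void) {
  int64_t numel = 5;
  check_index(-1, numel);
  printf("%lld\n", (long long)wrap_index(-1, numel)); /* prints 4 */
  return 0;
}

Keeping the check separate is what lets take, below, validate outside its parallel loop while still wrapping inside it.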
@@ -376,25 +379,34 @@ void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index)
   ptrdiff_t srcElements = THTensor_(nElement)(src);
   real *src_data = THTensor_(data)(src);
   real *dst_data = THTensor_(data)(dst);
-
   ptrdiff_t nIndices = THLongTensor_nElement(index);
-  if (THTensor_(isContiguous)(src)) {
-    ptrdiff_t i;
-    #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
-    for (i = 0; i < nIndices; i++) {
-      int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
-      dst_data[i] = src_data[linearIndex];
-    }
-  } else {
-    ptrdiff_t i;
-    #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
-    for (i = 0; i < nIndices; i++) {
-      int64_t linearIndex = THTensor_(wrapLinearIndex)(index_data[i], srcElements);
-      int64_t dataOffset = THTensor_(dataOffset)(src, linearIndex);
-      dst_data[i] = src_data[dataOffset];
+  int isContiguous = THTensor_(isContiguous)(src);
+
+  // Exceptions must not be thrown across OpenMP parallel sections, so we
+  // record the value of the invalid index and throw the exception after the
+  // loop.
+  int64_t invalidIdx = -1;
+
+  ptrdiff_t i;
+  #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i)
+  for (i = 0; i < nIndices; i++) {
+    int64_t idx = index_data[i];
+    if (idx < srcElements && idx >= -srcElements) {
+      idx = THTensor_(wrapLinearIndex)(idx, srcElements);
+      if (isContiguous) {
+        dst_data[i] = src_data[idx];
+      } else {
+        dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)];
+      }
+    } else {
+      THAtomicCompareAndSwapLong(&invalidIdx, -1, idx);
     }
   }
 
+  if (invalidIdx >= 0) {
+    THTensor_(checkLinearIndex)(invalidIdx, srcElements);
+  }
+
   THLongTensor_free(index);
   THTensor_(freeCopyTo)(dst, r_);
 }
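The comment in the new take states the key constraint: THArgCheck can raise an exception, and exceptions must not propagate out of an OpenMP parallel region, so the loop only records the first offending index with an atomic compare-and-swap and reruns the check serially afterwards. A self-contained sketch of this deferred-error pattern, using the GCC/Clang builtin __sync_bool_compare_and_swap in place of TH's THAtomicCompareAndSwapLong (the function name, invalidPos, and the error handling here are illustrative, not the patch's code):

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Copy src[indices[i]] into dst[i]; error reporting is deferred so
   no thread aborts inside the parallel region. */
static void take(double *dst, const double *src, int64_t n,
                 const int64_t *indices, int64_t nIndices) {
  int64_t invalidPos = -1; /* position of first bad index; -1 = none */
  int64_t i;
  #pragma omp parallel for
  for (i = 0; i < nIndices; i++) {
    int64_t idx = indices[i];
    if (idx < n && idx >= -n) {
      dst[i] = src[idx < 0 ? idx + n : idx]; /* wrap negative indices */
    } else {
      /* Record only the first failure: the CAS succeeds exactly once. */
      __sync_bool_compare_and_swap(&invalidPos, -1, i);
    }
  }
  if (invalidPos >= 0) { /* report outside the parallel region */
    fprintf(stderr, "out of range: %lld out of %lld\n",
            (long long)indices[invalidPos], (long long)n);
    exit(EXIT_FAILURE);
  }
}

The sketch stores the loop position rather than the raw index value; that keeps the post-loop test sign-agnostic, since a negative out-of-range index would fail a >= 0 test on the stored value itself.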
@@ -411,6 +423,7 @@ void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int ac
   int is_contiguous = THTensor_(isContiguous)(tensor);
 
   TH_TENSOR_APPLY2(int64_t, index, real, src,
+    THTensor_(checkLinearIndex)(*index_data, numel);
     int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel);
     int64_t dataOffset = is_contiguous ? linearIndex : THTensor_(dataOffset)(tensor, linearIndex);
     if (accumulate) {
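In put, by contrast, TH_TENSOR_APPLY2 iterates serially, so the new checkLinearIndex call can raise at the offending element with no deferral. A sketch of the equivalent serial scatter loop (plain C, illustrative names; accumulate selects add-vs-overwrite as in the real function):

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Scatter src[i] into dst at indices[i]; the loop is serial, so an
   invalid index can be reported immediately. */
static void put(double *dst, int64_t n, const int64_t *indices,
                const double *src, int64_t nIndices, int accumulate) {
  for (int64_t i = 0; i < nIndices; i++) {
    int64_t idx = indices[i];
    if (idx >= n || idx < -n) { /* immediate check, no deferral needed */
      fprintf(stderr, "out of range: %lld out of %lld\n",
              (long long)idx, (long long)n);
      exit(EXIT_FAILURE);
    }
    idx = idx < 0 ? idx + n : idx; /* wrap negative indices */
    if (accumulate) {
      dst[idx] += src[i];
    } else {
      dst[idx] = src[i];
    }
  }
}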