@@ -46,7 +46,7 @@ multinomialAliasDrawKernel(int size, int64_t *output, int64_t *J, T *q, int64_t
4646 T bern_uniform = bernoulli[idx];
4747 int _mask = (int ) THCNumerics<T>::lt (bern_uniform, q[rand_ind]);
4848 output[idx] = J[rand_ind]*(1 -_mask) + (rand_ind+1L ) * _mask;
49- }
49+ }
5050}
5151
5252template <typename T>
@@ -94,15 +94,19 @@ template <typename T>
9494__global__ void renormRowsL1 (T* dist, long rows, long cols) {
9595 extern __shared__ unsigned char my_smem[];
9696 T *smem = reinterpret_cast <T *>(my_smem);
97-
97+ T zero = ScalarConvert<int , T>::to (0 );
98+ T val;
9899 for (int64_t row = blockIdx .x ; row < rows; row += gridDim .x ) {
99100 T sum = ScalarConvert<int , T>::to (0 );
100101 for (int64_t col = threadIdx .x ; col < cols; col += blockDim .x ) {
101- sum = THCNumerics<T>::add (sum, dist[row * cols + col]);
102+ val = dist[row * cols + col];
103+ assert (THCNumerics<T>::ge (val, zero));
104+ sum = THCNumerics<T>::add (sum, val);
102105 }
103106
104- sum = reduceBlock (smem, blockDim .x , sum, ReduceAdd<T, T>(), ScalarConvert< int , T>:: to ( 0 ) );
107+ sum = reduceBlock (smem, blockDim .x , sum, ReduceAdd<T, T>(), zero );
105108 if (threadIdx .x == 0 ) {
109+ assert (THCNumerics<T>::gt (sum, zero));
106110 smem[0 ] = sum;
107111 }
108112 __syncthreads ();
@@ -169,10 +173,11 @@ sampleMultinomialOnce(int64_t* dest,
169173 // Each block handles one distribution
170174 // First pass, find the total sum of the distribution
171175 AccT sum = accZero;
176+ T val;
172177 for (int cat = threadIdx .x ; cat < categories; cat += blockDim .x ) {
173- sum = THCNumerics<AccT>:: add (
174- sum,
175- ScalarConvert<T, AccT>::to (dist[curDist * categories + cat] ));
178+ val = dist[curDist * categories + cat];
179+ assert (THCNumerics<T>:: ge (val, zero));
180+ sum = THCNumerics<AccT>:: add (sum, ScalarConvert<T, AccT>::to (val ));
176181 }
177182
178183 // threadIdx.x == 0 has the sum value from this
@@ -182,6 +187,7 @@ sampleMultinomialOnce(int64_t* dest,
182187 if (threadIdx .x == 0 ) {
183188 // Make sure the sum of our distribution didn't overflow
184189 assert (!isinf (sum));
190+ assert (THCNumerics<AccT>::gt (sum, accZero));
185191
186192 asmem[0 ] = sum;
187193 smem[0 ] = sampled[curDist];
0 commit comments