Skip to content

Commit 9459c62

Browse files
committed
Fix bug in identity cuda plaguing compute 5.2
* This is similar to the bug in triangle fixed in 144a2db
1 parent a1d6213 commit 9459c62

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/backend/cuda/kernel/identity.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,26 @@ namespace kernel
2323
__global__
2424
static void identity_kernel(Param<T> out, int blocks_x, int blocks_y)
2525
{
26-
unsigned idz = blockIdx.x / blocks_x;
27-
unsigned idw = blockIdx.y / blocks_y;
26+
const dim_t idz = blockIdx.x / blocks_x;
27+
const dim_t idw = blockIdx.y / blocks_y;
2828

29-
unsigned blockIdx_x = blockIdx.x - idz * blocks_x;
30-
unsigned blockIdx_y = blockIdx.y - idw * blocks_y;
29+
const dim_t blockIdx_x = blockIdx.x - idz * blocks_x;
30+
const dim_t blockIdx_y = blockIdx.y - idw * blocks_y;
3131

32-
unsigned idx = threadIdx.x + blockIdx_x * blockDim.x;
33-
unsigned idy = threadIdx.y + blockIdx_y * blockDim.y;
32+
const dim_t idx = threadIdx.x + blockIdx_x * blockDim.x;
33+
const dim_t idy = threadIdx.y + blockIdx_y * blockDim.y;
3434

3535
if(idx >= out.dims[0] ||
3636
idy >= out.dims[1] ||
3737
idz >= out.dims[2] ||
3838
idw >= out.dims[3])
3939
return;
4040

41+
const T one = scalar<T>(1);
42+
const T zero = scalar<T>(0);
43+
4144
T *ptr = out.ptr + idz * out.strides[2] + idw * out.strides[3];
42-
T val = (idx == idy) ? scalar<T>(1) : scalar<T>(0);
45+
T val = (idx == idy) ? one : zero;
4346
ptr[idx + idy * out.strides[1]] = val;
4447
}
4548

0 commit comments

Comments
 (0)