File tree Expand file tree Collapse file tree 1 file changed +10
-7
lines changed
Expand file tree Collapse file tree 1 file changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -23,23 +23,26 @@ namespace kernel
2323 __global__
2424 static void identity_kernel (Param<T> out, int blocks_x, int blocks_y)
2525 {
26- unsigned idz = blockIdx.x / blocks_x;
27- unsigned idw = blockIdx.y / blocks_y;
26+ const dim_t idz = blockIdx.x / blocks_x;
27+ const dim_t idw = blockIdx.y / blocks_y;
2828
29- unsigned blockIdx_x = blockIdx.x - idz * blocks_x;
30- unsigned blockIdx_y = blockIdx.y - idw * blocks_y;
29+ const dim_t blockIdx_x = blockIdx.x - idz * blocks_x;
30+ const dim_t blockIdx_y = blockIdx.y - idw * blocks_y;
3131
32- unsigned idx = threadIdx.x + blockIdx_x * blockDim.x ;
33- unsigned idy = threadIdx.y + blockIdx_y * blockDim.y ;
32+ const dim_t idx = threadIdx.x + blockIdx_x * blockDim.x ;
33+ const dim_t idy = threadIdx.y + blockIdx_y * blockDim.y ;
3434
3535 if (idx >= out.dims [0 ] ||
3636 idy >= out.dims [1 ] ||
3737 idz >= out.dims [2 ] ||
3838 idw >= out.dims [3 ])
3939 return ;
4040
41+ const T one = scalar<T>(1 );
42+ const T zero = scalar<T>(0 );
43+
4144 T *ptr = out.ptr + idz * out.strides [2 ] + idw * out.strides [3 ];
42- T val = (idx == idy) ? scalar<T>( 1 ) : scalar<T>( 0 ) ;
45+ T val = (idx == idy) ? one : zero ;
4346 ptr[idx + idy * out.strides [1 ]] = val;
4447 }
4548
You can’t perform that action at this time.
0 commit comments