
Commit 488cc11

Author: xjia

Commit message: refine softmax

1 parent b6fb9aa


FasterTransformer/fastertransformer/cuda/open_attention.cu

Lines changed: 60 additions & 12 deletions
@@ -61,6 +61,39 @@ T blockReduceSum(T val)
 
   return val;
 }
+
+template <typename T>
+__inline__ __device__
+T warpReduceMax(T val)
+{
+  for(int mask = 16; mask > 0; mask >>= 1)
+    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
+  return val;
+}
+
+/* Calculate the maximum of all elements in a block */
+template <typename T>
+__inline__ __device__
+T blockReduceMax(T val)
+{
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f; // in-warp idx
+  int wid = threadIdx.x >> 5;    // warp idx
+
+  val = warpReduceMax(val); // get max in each warp
+
+  if(lane == 0) // record the in-warp max, indexed by warp id
+    shared[wid] = val;
+
+  __syncthreads();
+
+
+  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : 0;
+  val = warpReduceMax(val);
+
+  return val;
+}
+
 __inline__ __device__
 int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
 {
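The two new device functions mirror the existing blockReduceSum: warpReduceMax runs a five-step XOR-shuffle butterfly so every lane ends up holding the warp-wide max, and blockReduceMax stages the per-warp results in shared memory before a second warp-level pass over them. Below is a minimal standalone harness, not part of this commit, that exercises the same primitive; the non-template names (warpReduceMaxF, blockReduceMaxF, block_max_test) are illustrative, and FINAL_MASK is assumed to be 0xffffffff as defined earlier in open_attention.cu.

#include <cstdio>
#include <cuda_runtime.h>

#define FINAL_MASK 0xffffffff

__inline__ __device__
float warpReduceMaxF(float val)   // same butterfly as the committed template
{
  for(int mask = 16; mask > 0; mask >>= 1)
    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
  return val;
}

__inline__ __device__
float blockReduceMaxF(float val)
{
  static __shared__ float shared[32];        // one slot per warp
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceMaxF(val);                 // per-warp max
  if(lane == 0) shared[wid] = val;           // stage per-warp results
  __syncthreads();

  // Note: padding unused lanes with 0 assumes the true max is >= 0.
  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : 0.0f;
  return warpReduceMaxF(val);                // reduce the per-warp maxima
}

__global__ void block_max_test(const float* in, float* out)
{
  float m = blockReduceMaxF(in[threadIdx.x]);
  if(threadIdx.x == 0) *out = m;
}

int main()
{
  float h_in[128], h_out, *d_in, *d_out;
  for(int i = 0; i < 128; ++i) h_in[i] = (float)(i % 97);
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  block_max_test<<<1, 128>>>(d_in, d_out);   // 4 warps, block size a multiple of 32
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("block max = %f\n", h_out);         // expect 96
  return 0;
}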
@@ -162,16 +195,25 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
     int qk_offset = blockIdx.x * seq_len * seq_len;
     int mask_offset = batch_id * seq_len * seq_len;
 
-    __shared__ float s_sum;
+    __shared__ float s_sum, s_max;
 
     for(int i = 0; i < seq_len; ++i)
    {
-      T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
-      T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
+      float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
+      float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
 
-      mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
+      mask_val = (1.0f - mask_val) * -10000.0f;
+
+      float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e-20f;
+
+      float max_val = blockReduceMax<float>(tmp);
+
+      if(threadIdx.x == 0)
+        s_max = max_val;
+      __syncthreads();
+
+      qk = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f;
 
-      qk = __expf((float)(qk * scaler + mask_val));
       float sum_val = blockReduceSum<float>(qk);
 
       if(threadIdx.x == 0)
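The substance of the refinement: softmax is invariant under subtracting a constant from every logit, so pulling out the row max before __expf leaves the probabilities unchanged while bounding every exponent argument at zero, where e^x cannot overflow (expf saturates to inf in float32 once its argument exceeds roughly 88.7). A host-side illustration of the identity, not taken from the repository:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
  // A row of logits large enough to overflow expf.
  float x[4] = {100.0f, 90.0f, 80.0f, 70.0f};

  // Naive softmax: expf(100) = inf, so every probability becomes inf/inf = nan.
  float naive_sum = 0.0f;
  for (float v : x) naive_sum += expf(v);
  printf("naive   p[0] = %f\n", expf(x[0]) / naive_sum);

  // Max-shifted softmax: exp(x_i - m) <= 1, the sum is finite, the result correct.
  float m = *std::max_element(x, x + 4);
  float shifted_sum = 0.0f;
  for (float v : x) shifted_sum += expf(v - m);
  printf("shifted p[0] = %f\n", expf(x[0] - m) / shifted_sum);   // ~0.99995
  return 0;
}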
@@ -181,7 +223,7 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
       __syncthreads();
 
       if(threadIdx.x < seq_len)
-        qk_buf_[threadIdx.x + qk_offset] = qk / (T)s_sum;
+        qk_buf_[threadIdx.x + qk_offset] = (T)(qk / s_sum);
 
       qk_offset += seq_len;
       mask_offset += seq_len;
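A precision detail in the write-back: the quotient is now computed entirely in fp32 and rounded to T a single time at the store, whereas the old form cast s_sum to T before dividing, which for half precision rounds the denominator to 10 mantissa bits first. A hypothetical side-by-side, assuming T = __half (half arithmetic requires __CUDA_ARCH__ >= 530):

#include <cuda_fp16.h>

// Illustrative only, not from the commit.
__device__ __half store_old(__half qk, float s_sum)
{
  return qk / (__half)s_sum;     // rounds s_sum to half before dividing
}

__device__ __half store_new(float qk, float s_sum)
{
  return (__half)(qk / s_sum);   // divides in fp32, rounds once at the end
}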
@@ -192,21 +234,27 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
 template <typename T>
 __global__
 void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num,
-  const int seq_len, const T scaler)
+  const int seq_len, const float scaler)
 {
     int batch_id = blockIdx.x / head_num / seq_len;
     int seq_id = blockIdx.x % seq_len;
     int qk_offset = blockIdx.x * seq_len;
     int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;
 
-    __shared__ float s_sum;
+    __shared__ float s_sum, s_max;
 
-    T qk = threadIdx.x < seq_len ? qk_buf_[threadIdx.x + qk_offset] : (T)(0.0f);
-    T mask_val = threadIdx.x < seq_len ? attr_mask[threadIdx.x + mask_offset] : (T)(0.0f);
+    float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
+    float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
 
-    mask_val = ((T)1.0f - mask_val) * (T)(-10000.0f);
+    mask_val = (1.0f - mask_val) * -10000.0f;
+
+    float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e-20f;
+    float max_val = blockReduceMax<float>(tmp);
+    if(threadIdx.x == 0)
+      s_max = max_val;
+    __syncthreads();
 
-    float qk_tmp = threadIdx.x < seq_len ? __expf((float)(qk * scaler + mask_val)) : 0.0f;
+    float qk_tmp = threadIdx.x < seq_len ? __expf((float)(tmp - s_max)) : 0.0f;
     float sum_val = blockReduceSum<float>(qk_tmp);
 
     if(threadIdx.x == 0)
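The index arithmetic at the top of softmax_kernel_v2 (batch_id = blockIdx.x / head_num / seq_len, seq_id = blockIdx.x % seq_len) implies a grid of batch_size * head_num * seq_len blocks, one per attention row, with the block's threads covering a row of length seq_len. The wrapper below is a sketch of that launch shape, not the repository's actual dispatch code; it assumes it is compiled in the same translation unit as the kernel, and that seq_len fits within one block (<= 1024 threads).

#include <cuda_runtime.h>

// Hypothetical launcher, inferred from the kernel's index math.
template <typename T>
void launch_softmax_v2(T* qk_buf, const T* attr_mask, int batch_size,
                       int head_num, int seq_len, float scaler,
                       cudaStream_t stream)
{
  dim3 grid(batch_size * head_num * seq_len);   // one block per attention row
  dim3 block((seq_len + 31) / 32 * 32);         // round up to a warp multiple
                                                // (the block reductions shuffle
                                                // across full 32-lane warps)
  softmax_kernel_v2<T><<<grid, block, 0, stream>>>(
      qk_buf, attr_mask, batch_size, head_num, seq_len, scaler);
}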
