@@ -61,6 +61,39 @@ T blockReduceSum(T val)
6161
6262 return val;
6363}
64+
/* Warp-wide max reduction via butterfly shuffle.
 * Returns the maximum of `val` over all 32 lanes; the result is
 * valid in every participating lane. */
template <typename T>
__inline__ __device__
T warpReduceMax(T val)
{
  #pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1)
    val = max(val, __shfl_xor_sync(FINAL_MASK, val, offset, 32));
  return val;
}
73+
/* Calculate the maximum of all elements in a block.
 *
 * Two-stage scheme: each warp reduces its own lanes, lane 0 of each warp
 * publishes the partial max to shared memory, and after a block barrier
 * every warp reduces the published partials (so the result is valid in
 * all threads, not just warp 0).
 *
 * Fixes vs. the previous version:
 *  - Warp count is rounded UP: the old `blockDim.x >> 5` was 0 for blocks
 *    smaller than a warp (discarding every partial) and dropped the last
 *    partial warp when blockDim.x was not a multiple of 32.
 *  - Out-of-range lanes are padded with shared[0] — a genuine partial max,
 *    hence neutral under max for any T — instead of 0, which corrupted the
 *    result whenever all inputs were negative.
 *
 * Precondition: blockDim.x <= 1024 (at most 32 warps -> shared[32]).
 */
template <typename T>
__inline__ __device__
T blockReduceMax(T val)
{
  static __shared__ T shared[32];     // one partial max per warp
  int lane = threadIdx.x & 0x1f;      // lane index inside the warp
  int wid  = threadIdx.x >> 5;        // warp index inside the block

  val = warpReduceMax(val);           // stage 1: max within each warp

  if (lane == 0)                      // lane 0 records its warp's max
    shared[wid] = val;

  __syncthreads();

  // stage 2: reduce the per-warp partials. Ceil-divide so a trailing
  // partial warp still contributes; shared[0] is always written (warp 0,
  // lane 0) and is safe padding for the unused lanes.
  int num_warps = (blockDim.x + 31) >> 5;
  val = (lane < num_warps) ? shared[lane] : shared[0];
  val = warpReduceMax(val);

  return val;
}
96+
6497 __inline__ __device__
6598int target_index (int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
6699{
@@ -162,16 +195,25 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
162195 int qk_offset = blockIdx .x * seq_len * seq_len;
163196 int mask_offset = batch_id * seq_len * seq_len;
164197
165- __shared__ float s_sum;
198+ __shared__ float s_sum, s_max ;
166199
167200 for (int i = 0 ; i < seq_len; ++i)
168201 {
169- T qk = threadIdx .x < seq_len ? qk_buf_[threadIdx .x + qk_offset] : (T)( 0 .0f ) ;
170- T mask_val = threadIdx .x < seq_len ? attr_mask[threadIdx .x + mask_offset] : (T)( 0 .0f ) ;
202+ float qk = threadIdx .x < seq_len ? ( float ) qk_buf_[threadIdx .x + qk_offset] : 0 .0f ;
203+ float mask_val = threadIdx .x < seq_len ? ( float ) attr_mask[threadIdx .x + mask_offset] : 0 .0f ;
171204
172- mask_val = ((T)1 .0f - mask_val) * (T)(-10000 .0f );
205+ mask_val = (1 .0f - mask_val) * -10000 .0f ;
206+
207+ float tmp = threadIdx .x < seq_len ? (float )(qk * (float )scaler + mask_val): -1e-20f ;
208+
209+ float max_val = blockReduceMax<float >(tmp);
210+
211+ if (threadIdx .x == 0 )
212+ s_max = max_val;
213+ __syncthreads ();
214+
215+ qk = threadIdx .x < seq_len ? __expf (tmp - s_max) : 0 .0f ;
173216
174- qk = __expf ((float )(qk * scaler + mask_val));
175217 float sum_val = blockReduceSum<float >(qk);
176218
177219 if (threadIdx .x == 0 )
@@ -181,7 +223,7 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
181223 __syncthreads ();
182224
183225 if (threadIdx .x < seq_len)
184- qk_buf_[threadIdx .x + qk_offset] = qk / (T) s_sum;
226+ qk_buf_[threadIdx .x + qk_offset] = (T)( qk / s_sum) ;
185227
186228 qk_offset += seq_len;
187229 mask_offset += seq_len;
@@ -192,21 +234,27 @@ void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const
192234template <typename T>
193235__global__
194236void softmax_kernel_v2 (T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num,
195- const int seq_len, const T scaler)
237+ const int seq_len, const float scaler)
196238{
197239 int batch_id = blockIdx .x / head_num / seq_len;
198240 int seq_id = blockIdx .x % seq_len;
199241 int qk_offset = blockIdx .x * seq_len;
200242 int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;
201243
202- __shared__ float s_sum;
244+ __shared__ float s_sum, s_max ;
203245
204- T qk = threadIdx .x < seq_len ? qk_buf_[threadIdx .x + qk_offset] : (T)( 0 .0f ) ;
205- T mask_val = threadIdx .x < seq_len ? attr_mask[threadIdx .x + mask_offset] : (T)( 0 .0f ) ;
246+ float qk = threadIdx .x < seq_len ? ( float ) qk_buf_[threadIdx .x + qk_offset] : 0 .0f ;
247+ float mask_val = threadIdx .x < seq_len ? ( float ) attr_mask[threadIdx .x + mask_offset] : 0 .0f ;
206248
207- mask_val = ((T)1 .0f - mask_val) * (T)(-10000 .0f );
249+ mask_val = (1 .0f - mask_val) * -10000 .0f ;
250+
251+ float tmp = threadIdx .x < seq_len ? (float )(qk * (float )scaler + mask_val) : -1e-20f ;
252+ float max_val = blockReduceMax<float >(tmp);
253+ if (threadIdx .x == 0 )
254+ s_max = max_val;
255+ __syncthreads ();
208256
209- float qk_tmp = threadIdx .x < seq_len ? __expf ((float )(qk * scaler + mask_val )) : 0 .0f ;
257+ float qk_tmp = threadIdx .x < seq_len ? __expf ((float )(tmp - s_max )) : 0 .0f ;
210258 float sum_val = blockReduceSum<float >(qk_tmp);
211259
212260 if (threadIdx .x == 0 )
0 commit comments