@@ -1013,37 +1013,51 @@ def rand_tensor(shape):
 
     @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system")
     @parametrize("type", ["dense", "nested"])
-    def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, type: str):
+    @parametrize("fused_kernel", ["flash", "mem_efficient"])
+    def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, type: str, fused_kernel: str):
         def rand_nt(shape):
             batch, seq_len, num_heads, head_dim = shape
-            return torch.nested.nested_tensor([torch.randn(seq_len, 3 * num_heads * head_dim,
-                                               device="cuda", dtype=torch.float16) * 10 for _ in range(batch)])
+            tensors = [torch.randn(seq_len, 3 * num_heads * head_dim) * 10 for _ in range(batch)]
+            return (torch.nested.nested_tensor(tensors, device="cuda", dtype=torch.float32),
+                    torch.nested.nested_tensor(tensors, device="cuda", dtype=torch.float16))
+
         def rand_tensor(shape):
             batch, seq_len, num_heads, head_dim = shape
-            return torch.randn(batch, seq_len, 3 * num_heads * head_dim, device="cuda", dtype=torch.float16) * 10
+            tensor = torch.randn(batch, seq_len, 3 * num_heads * head_dim, device="cuda", dtype=torch.float32) * 10
+            return tensor, tensor.to(dtype=torch.float16)
 
         batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64
         shape = (batch_size, seq_len, num_heads, head_dim)
 
         # Test Packed
-        qkv = rand_tensor(shape) if type == "dense" else rand_nt(shape)
+        qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape)
         query, key, value = qkv.chunk(3, dim=-1)
+        query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1)
 
         query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
-        with sdp_kernel(enable_math=False):
-            actual = torch.nn.functional._scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
-        with sdp_kernel(enable_flash=False):
-            math_query = query.contiguous().to(dtype=torch.float32)
-            math_key = key.contiguous().to(dtype=torch.float32)
-            math_value = value.contiguous().to(dtype=torch.float32)
+        if fused_kernel == "flash":
+            with sdp_kernel(enable_mem_efficient=False, enable_math=False):
+                actual = torch.nn.functional._scaled_dot_product_attention(
+                    query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+        elif fused_kernel == "mem_efficient":
+            with sdp_kernel(enable_flash=False, enable_math=False):
+                actual = torch.nn.functional._scaled_dot_product_attention(
+                    query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+
+        with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
+            math_query = query.contiguous()
+            math_key = key.contiguous()
+            math_value = value.contiguous()
 
             math_ref = torch.nn.functional._scaled_dot_product_attention(math_query, math_key, math_value,
-                attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+                                                                          attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
 
         actual_test = actual[0]
         math_ref_test = math_ref[0]
@@ -1052,7 +1066,9 @@ def rand_tensor(shape):
             actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0)
             math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0)
 
-        self.assertEqual(actual_test.to(dtype=torch.float32).contiguous(),math_ref_test.contiguous(), atol=5e-3, rtol=5e-3)
+        self.assertEqual(actual_test.to(dtype=torch.float32).contiguous(),
+                         math_ref_test.contiguous(), atol=5e-3, rtol=5e-3)
+
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_sdp_runtime_dispatch(self):
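For reference, the new parametrization selects which fused SDPA backend produces `actual`, while the math path runs on the float32 inputs and serves as the accuracy baseline. Below is a minimal sketch of that backend selection, assuming `sdp_kernel` is `torch.backends.cuda.sdp_kernel` (as imported in this test file) and reusing the private `_scaled_dot_product_attention` signature from the diff; both are version-specific and may differ in other PyTorch releases.

```python
# Sketch only: mirrors how the test forces a specific fused kernel.
# Assumes the same PyTorch build as this PR; the private function and its
# need_attn_weights argument may not exist in other versions.
import torch
from torch.backends.cuda import sdp_kernel

# Shape convention matches the test after view/transpose: (batch, heads, seq, head_dim).
q = k = v = torch.randn(16, 4, 8, 64, device="cuda", dtype=torch.float16)

# Flash attention only: disable the math and memory-efficient backends.
with sdp_kernel(enable_math=False, enable_mem_efficient=False):
    out_flash = torch.nn.functional._scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)

# Memory-efficient attention only: disable flash and math.
with sdp_kernel(enable_flash=False, enable_math=False):
    out_mem_eff = torch.nn.functional._scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)

# As in the test, element 0 of the returned tuple is the attention output.
print(out_flash[0].shape, out_mem_eff[0].shape)
```

Keeping the reference inputs in float32 (`qkv`) while the fused kernels receive the float16 copies (`qkv_low_precision`) means the 5e-3 tolerance measures kernel accuracy against a higher-precision math reference rather than against identically rounded low-precision inputs.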