
Commit 0aef791

update accuracy test
1 parent 573a1a5 commit 0aef791

File tree

1 file changed (+31, -15 lines)


test/test_transformers.py

Lines changed: 31 additions & 15 deletions
@@ -1013,37 +1013,51 @@ def rand_tensor(shape):
 
     @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system")
     @parametrize("type", ["dense", "nested"])
-    def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, type: str):
+    @parametrize("fused_kernel", ["flash", "mem_efficeint"])
+    def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, type: str, fused_kernel: str):
         def rand_nt(shape):
             batch, seq_len, num_heads, head_dim = shape
-            return torch.nested.nested_tensor([torch.randn(seq_len, 3 * num_heads * head_dim,
-                                               device="cuda", dtype=torch.float16)*10 for _ in range(batch)])
+            tensors = [torch.randn(seq_len, 3 * num_heads * head_dim) * 10 for _ in range(batch)]
+            return (torch.nested.nested_tensor(tensors, device="cuda", dtype=torch.float32),
+                    torch.nested.nested_tensor(tensors, device="cuda", dtype=torch.float16))
+
         def rand_tensor(shape):
             batch, seq_len, num_heads, head_dim = shape
-            return torch.randn(batch, seq_len, 3 * num_heads * head_dim, device="cuda", dtype=torch.float16) * 10
+            tensor = torch.randn(batch, seq_len, 3 * num_heads * head_dim, device="cuda", dtype=torch.float32) * 10
+            return tensor, tensor.to(dtype=torch.float16)
 
         batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64
         shape = (batch_size, seq_len, num_heads, head_dim)
 
         # Test Packed
-        qkv = rand_tensor(shape) if type == "dense" else rand_nt(shape)
+        qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape)
         query, key, value = qkv.chunk(3, dim=-1)
+        query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1)
 
         query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
-        with sdp_kernel(enable_math=False):
-            actual = torch.nn.functional._scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
-        with sdp_kernel(enable_flash=False):
-            math_query = query.contiguous().to(dtype=torch.float32)
-            math_key = key.contiguous().to(dtype=torch.float32)
-            math_value = value.contiguous().to(dtype=torch.float32)
+        if fused_kernel == "flash":
+            with sdp_kernel(enable_mem_efficient=False, enable_math=False):
+                actual = torch.nn.functional._scaled_dot_product_attention(
+                    query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+        elif fused_kernel == "mem_efficeint":
+            with sdp_kernel(enable_flash=False, enable_math=False):
+                actual = torch.nn.functional._scaled_dot_product_attention(
+                    query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+
+        with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
+            math_query = query.contiguous()
+            math_key = key.contiguous()
+            math_value = value.contiguous()
 
             math_ref = torch.nn.functional._scaled_dot_product_attention(math_query, math_key, math_value,
-                                                                         attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
+                                                                          attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
 
         actual_test = actual[0]
         math_ref_test = math_ref[0]
@@ -1052,7 +1066,9 @@ def rand_tensor(shape):
         actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0)
         math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0)
 
-        self.assertEqual(actual_test.to(dtype=torch.float32).contiguous(),math_ref_test.contiguous(), atol=5e-3, rtol=5e-3)
+        self.assertEqual(actual_test.to(dtype=torch.float32).contiguous(),
+                         math_ref_test.contiguous(), atol=5e-3, rtol=5e-3)
+
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_sdp_runtime_dispatch(self):
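
For context, the updated test follows a common accuracy-check pattern: run scaled dot-product attention through one fused backend with low-precision (float16) inputs, run a float32 reference through the math fallback, and compare the two with loose tolerances. The sketch below illustrates that pattern outside the test harness; it assumes sdp_kernel is the context manager from torch.backends.cuda (as imported by test_transformers.py at this commit) and that the private torch.nn.functional._scaled_dot_product_attention call matches the signature used in the diff above.

# Minimal sketch (not part of the commit): fused-vs-math SDPA accuracy check.
import torch
from torch.backends.cuda import sdp_kernel  # assumed import location

batch_size, num_heads, seq_len, head_dim = 16, 4, 8, 64

# float32 inputs drive the math reference; float16 copies drive the fused kernel.
query = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda")
key = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda")
value = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda")
query_lp, key_lp, value_lp = (t.to(torch.float16) for t in (query, key, value))

# Restrict dispatch to the flash backend only.
with sdp_kernel(enable_mem_efficient=False, enable_math=False):
    actual = torch.nn.functional._scaled_dot_product_attention(
        query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0,
        need_attn_weights=False, is_causal=False)

# Restrict dispatch to the math fallback for the float32 reference.
with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
    math_ref = torch.nn.functional._scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0,
        need_attn_weights=False, is_causal=False)

# The private API returns a tuple; compare the attention output (index 0)
# after upcasting, using the same tolerances as the test.
torch.testing.assert_close(actual[0].to(torch.float32), math_ref[0],
                           atol=5e-3, rtol=5e-3)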
