
Commit 2180f34

[SDPA] Fix bug in parsing scaled_dot_product_attention arguments (#95311) (#95397)
Fixes #95266
Pull Request resolved: #95311
Approved by: https://github.com/cpuhrsch
1 parent a90b4f0 commit 2180f34
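
For context: before this change, Dynamo's special-case handling of scaled_dot_product_attention read query, key, and value from fixed positional slots (args[0] through args[2], as visible in the removed lines of torch/_dynamo/variables/torch.py below) and only consulted kwargs or args[4] for dropout_p, so a call that passes these arguments by keyword could be mis-parsed or fail during tracing. A minimal sketch of such a keyword-style call, assuming a CUDA device with fused SDPA support and mirroring the shapes used in the new test below:

import torch
import torch.nn.functional as F

def fn(query, key, value):
    # All arguments are passed by keyword, so the traced call sees an
    # empty positional args tuple.
    return F.scaled_dot_product_attention(
        query=query, key=key, value=value, attn_mask=None, dropout_p=0, is_causal=True
    )

# (batch, heads, seq_len, head_dim), matching the shapes in the new test
q = k = v = torch.ones(1, 8, 1, 8, device="cuda", dtype=torch.float16)
opt_fn = torch._dynamo.optimize("inductor")(fn)
opt_fn(q, k, v)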

3 files changed: +82 -10 lines changed

test/dynamo/test_dynamic_shapes.py

Lines changed: 5 additions & 0 deletions
@@ -60,6 +60,11 @@ def make_dynamic_cls(cls):
     # Cannot call sizes() on tensor with symbolic sizes/strides
 )
 
+unittest.expectedFailure(
+    DynamicShapesMiscTests.test_parsing_sdpa_dynamic_shapes
+    # Cannot call sizes() on tensor with symbolic sizes/strides
+)
+
 
 # DynamicShapesSubGraphTests
 unittest.expectedFailure(

test/dynamo/test_misc.py

Lines changed: 47 additions & 0 deletions
@@ -3145,6 +3145,53 @@ def forward(self, query, key, value):
         self.assertEqual(compiled.device.index, 0)
         self.assertEqual(compiled.dtype, torch.float16)
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
+        "Can't run fused SDPA on this platform",
+    )
+    def test_parsing_sdpa(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, query, key, value):
+                out = F.scaled_dot_product_attention(query, key, value, None, 0, True)
+                out = F.scaled_dot_product_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    attn_mask=None,
+                    dropout_p=0,
+                    is_causal=True,
+                )
+                out = F.scaled_dot_product_attention(
+                    query,
+                    key=key,
+                    value=value,
+                    attn_mask=None,
+                    dropout_p=0,
+                    is_causal=True,
+                )
+                out = F.scaled_dot_product_attention(
+                    query, key, value, None, dropout_p=0, is_causal=True
+                )
+                return out
+
+        device = "cuda"
+        dtype = torch.float16
+        seq_len_q = 1
+        seq_len_k = 1
+        head_dim = 8
+        query = torch.ones(
+            1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True
+        )
+        key = torch.ones(
+            1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
+        )
+        value = torch.ones(
+            1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
+        )
+        module = MyModule()
+        opt_mod = torch._dynamo.optimize("inductor")(module)
+        opt_mod(query, key, value)
+
     def test_autocast_cpu(self):
         class MyModule(torch.nn.Module):
             def forward(self, x):
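
To exercise the new test in isolation, an invocation along these lines should work (assuming a CUDA build with fused SDPA support; the -k filter is plain pytest behaviour, not anything added by this PR):

pytest test/dynamo/test_misc.py -k test_parsing_sdpa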

torch/_dynamo/variables/torch.py

Lines changed: 30 additions & 10 deletions
@@ -481,9 +481,34 @@ def get_state_from_generator():
         if self.value == torch._C._nn.scaled_dot_product_attention:
             # See:[Note] SDPA_flash's meta function returns incorrect Philox seed and offset
             # in pytorch/torch/_meta_registrations.py
-            fake_query = args[0].as_proxy().node.meta["example_value"]
-            fake_key = args[1].as_proxy().node.meta["example_value"]
-            fake_value = args[2].as_proxy().node.meta["example_value"]
+            all_kwargs = kwargs.copy()
+            all_kwargs.update(
+                dict(
+                    zip(
+                        (
+                            "query",
+                            "key",
+                            "value",
+                            "attn_mask",
+                            "dropout_p",
+                            "is_causal",
+                        ),
+                        args,
+                    )
+                )
+            )
+            fake_query = all_kwargs["query"].as_proxy().node.meta["example_value"]
+            fake_key = all_kwargs["key"].as_proxy().node.meta["example_value"]
+            fake_value = all_kwargs["value"].as_proxy().node.meta["example_value"]
+            fake_mask = all_kwargs.get("attn_mask")
+            if isinstance(fake_mask, TensorVariable):
+                fake_mask = fake_mask.as_proxy().node.meta["example_value"]
+            else:
+                fake_mask = None
+            dropout_p = kwargs.get("dropout_p")
+            dropout_p = dropout_p.value if dropout_p is not None else 0.0
+            is_causal = kwargs.get("is_causal")
+            is_causal = is_causal.value if is_causal is not None else False
             # We look through the stack to find a cuda autocast context
             # If we do we will convert the fake tensors to torch.float16
             is_cuda_autocast_context = False
@@ -502,15 +527,10 @@ def get_state_from_generator():
                 fake_value = fake_value.clone().to(amp_dtype)
 
             backend_choice = torch._fused_sdp_choice(
-                fake_query, fake_key, fake_value
+                fake_query, fake_key, fake_value, fake_mask, dropout_p, is_causal
             )
             if backend_choice == torch.backends.cuda.SDPBackend.FLASH_ATTENTION:
-                dropout_p = kwargs.get("dropout_p")
-                # Lets see if they passed it in as not an arg
-                if len(args) >= 5:
-                    dropout_p = args[4]
-
-                if dropout_p is not None and dropout_p.value != 0.0:
+                if dropout_p is not None and dropout_p != 0.0:
                     unimplemented(
                         "FlashAttention with dropout is not supported in cuda graphs"
                     )
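
The core of the change is the positional-to-keyword normalization at the top of the SDPA branch: positional arguments are zipped against the parameter names of scaled_dot_product_attention and merged with the caller-supplied kwargs, so everything downstream can be looked up by name no matter how the call was spelled. A self-contained sketch of that pattern with plain Python values instead of Dynamo VariableTrackers (the function and constant names here are illustrative, not part of the PR):

# Illustrative only: the same zip-and-merge idea, outside of Dynamo.
SDPA_PARAM_NAMES = ("query", "key", "value", "attn_mask", "dropout_p", "is_causal")

def normalize_sdpa_call(args, kwargs):
    # Map whatever arrived positionally onto its parameter name, then fold in
    # the explicit keyword arguments (Python forbids passing the same
    # parameter both ways, so the two dicts never conflict).
    all_kwargs = dict(kwargs)
    all_kwargs.update(dict(zip(SDPA_PARAM_NAMES, args)))
    return all_kwargs

# These spellings, mirroring the calls in the new test, normalize identically:
a = normalize_sdpa_call(("q", "k", "v", None, 0.0, True), {})
b = normalize_sdpa_call(
    ("q",),
    {"key": "k", "value": "v", "attn_mask": None, "dropout_p": 0.0, "is_causal": True},
)
c = normalize_sdpa_call(
    (),
    {"query": "q", "key": "k", "value": "v", "attn_mask": None,
     "dropout_p": 0.0, "is_causal": True},
)
assert a == b == c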
