[CUDAGraph] Silent failure when graphs capture attempted on wrong device  #87894

@Aidyn-A

Description

🐛 Describe the bug

The following code snippet silently fails to capture the graph because memory is allocated on one device while the graph capture is attempted on another.

import torch

device = torch.device("cuda:0")
x = torch.randn(10, dtype=torch.float32, device=device)
y = torch.randn(10, dtype=torch.float32, device=device)
z = torch.zeros(10, dtype=torch.float32, device=device)

with torch.cuda.device('cuda:1'): # Wrong device
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        z = x + y

    for i in range(3):
        x.normal_()
        y.normal_()
        g.replay()
        print(z) # One would expect it to print different values each iteration,
                 # but it does not, because the capture ran under device 1
                 # while all the tensors are on device 0

    print('Test passed')

I believe it should raise an error. However, I couldn't think of an easy way to do it.
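One possible shape for such a check, sketched here as a hypothetical helper (`check_capture_device` is not a real PyTorch API): before capturing, compare the device of every tensor the captured region touches against the device the capture would run on, and raise on mismatch. The sketch uses CPU tensors so it runs without a GPU; constructing a `torch.device("cuda:1")` object for the comparison does not require CUDA to be available.

```python
import torch

def check_capture_device(capture_device, *tensors):
    # Hypothetical pre-capture guard: every tensor used inside the
    # captured region must live on the device the capture runs on.
    mismatched = [t for t in tensors if t.device != capture_device]
    if mismatched:
        raise RuntimeError(
            f"graph capture on {capture_device}, but found tensors on "
            f"{sorted({str(t.device) for t in mismatched})}"
        )

x = torch.randn(10)
y = torch.randn(10)

check_capture_device(torch.device("cpu"), x, y)  # matching devices: no error

try:
    # Simulates the bug above: capture device differs from the tensors'.
    check_capture_device(torch.device("cuda:1"), x, y)
except RuntimeError as e:
    print("caught:", e)
```

A real fix would presumably live inside the capture machinery (e.g. when `torch.cuda.graph` begins capture) rather than in user code, but the comparison itself is this simple.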
cc @ngimel @ptrblck

Versions

Latest


Labels

    module: cuda (Related to torch.cuda, and CUDA support in general)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
