Functionalization behavior change when handling indexed assignment. #91725

@zhxchen17

🐛 Describe the bug

I'm trying to functionalize a graph like the following:

def foo(x, l):
    assert len(x.shape) == 3
    n_elem = x.shape[0]
    tmp = (
        torch.eye(4, dtype=x.dtype, device=x.device)
        .unsqueeze(0)
        .repeat(n_elem, 1, 1)
    )
    tmp = tmp.reshape(n_elem, 4, 4)
    f = x[..., 0, 0]

    tmp[..., 2, 2] = f / l

    return tmp

Before #91029, we could correctly generate a graph like the following:

        select_copy: f32[s1, s13] = torch.ops.aten.select_copy.int(arg1_1, 1, 0);  arg1_1 = None	
        select_copy_1: f32[s1] = torch.ops.aten.select_copy.int(select_copy, 1, 0);  select_copy = None	
        div: f32[s1] = torch.ops.aten.div.Tensor(select_copy_1, 200);  select_copy_1 = None	
        select_copy_2: f32[s1, 4] = torch.ops.aten.select_copy.int(view_copy_1, 1, 2);  view_copy_1 = None	
        select_copy_3: f32[s1] = torch.ops.aten.select_copy.int(select_copy_2, 1, 2);  select_copy_2 = None	
        copy: f32[s1] = torch.ops.aten.copy.default(select_copy_3, div);  select_copy_3 = div = None		
        view_copy_2: f32[s1, 4, 4] = torch.ops.aten.view_copy.default(repeat, [sym_size_4, 4, 4])	
        select_copy_4: f32[s1, 4] = torch.ops.aten.select_copy.int(view_copy_2, 1, 2)	
        select_scatter: f32[s1, 4] = torch.ops.aten.select_scatter.default(select_copy_4, copy, 1, 2);  select_copy_4 = copy = None	
        select_scatter_1: f32[s1, 4, 4] = torch.ops.aten.select_scatter.default(view_copy_2, select_scatter, 1, 2);  view_copy_2 = select_scatter = None	
        sym_size_5: Sym(s1) = torch.ops.aten.sym_size(repeat, 0);  repeat = None	
        view_copy_3: f32[s1, 4, 4] = torch.ops.aten.view_copy.default(select_scatter_1, [sym_size_5, 4, 4]);  select_scatter_1 = sym_size_5 = None	
        view_copy_4: f32[s1, 4, 4] = torch.ops.aten.view_copy.default(view_copy_3, [sym_size_4, 4, 4]);  view_copy_3 = sym_size_4 = None
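The select/copy/select_scatter pattern in this graph is just the out-of-place rewrite of the indexed assignment: select down to the slice being written, copy the new value into it, then scatter the updated slice back up the view chain. A minimal sketch of that semantics in plain NumPy (the `select_scatter` helper below is a hypothetical stand-in for `aten.select_scatter`, not the real op):

```python
import numpy as np

def select_scatter(base, src, dim, index):
    # Out-of-place equivalent of base.select(dim, index).copy_(src):
    # return a copy of `base` whose slice at `index` along `dim` is
    # replaced by `src`, leaving `base` untouched.
    out = base.copy()
    sl = [slice(None)] * base.ndim
    sl[dim] = index
    out[tuple(sl)] = src
    return out

n, l = 3, 200.0
tmp = np.tile(np.eye(4), (n, 1, 1))          # shape (n, 4, 4)
f = np.arange(n, dtype=np.float64)

# Functional form of `tmp[..., 2, 2] = f / l`:
# two selects to reach the slice, two scatters to write it back.
inner = tmp[:, 2, :]                         # select(dim=1, index=2) -> (n, 4)
inner = select_scatter(inner, f / l, 1, 2)   # overwrite column 2 of the slice
result = select_scatter(tmp, inner, 1, 2)    # scatter the slice back into tmp

# Reference: the plain in-place indexed assignment.
expected = tmp.copy()
expected[..., 2, 2] = f / l
assert np.allclose(result, expected)
```

The point of the rewrite is that `tmp` itself is never mutated; the tracer should swap in `result` for all later uses of `tmp`, which is exactly what stops happening after #91029.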

After #91029, it seems like this line

    tmp[..., 2, 2] = f / l

won't be captured at all, causing the tracer to record the original tmp rather than the tmp updated by the assignment.

cc @bdhirsh Thanks!

Versions

PyTorch master
