
as_strided_backward in expanded case & dynamically created grad_fn for views #8626

@ssnl

Description

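The two gradients below should be identical. `xx.as_strided([3, 2], xx.stride())` has the same sizes and strides as `xx`, so it views exactly the same memory. The results nevertheless differ: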
>>> x = torch.zeros(2, requires_grad=True)
>>> xx = x.expand(3, 2)
>>> z = torch.randn(3, 2)
>>> torch.autograd.grad((xx * z).mean(), x)[0]
tensor([ 0.4419, -0.1242])
>>> torch.autograd.grad((xx.as_strided([3,2], xx.stride()) * z).mean(), x)[0]  # reshape(3, 2) works too
tensor([ 0.5057, -0.2912])
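For reference, the correct gradient here is `z.sum(0) / z.numel()`. `mean` contributes `z / 6` to every element, and the backward of `expand` sums over the three aliased rows. Below is a minimal pure-Python sketch, not PyTorch's actual implementation, of one layout-aware way to back-propagate through `as_strided` when the input view itself overlaps in memory. It accumulates the output gradient into storage slots, then splits each slot's gradient evenly among the input elements that alias it, so the subsequent expand-backward sum counts it exactly once. The helper names (`strided_offsets`, `as_strided_backward`, `expand_backward`) and the `(sizes, strides, storage_offset)` tuple are hypothetical and chosen for illustration. `z` is a fixed stand-in for `torch.randn(3, 2)`.

```python
from itertools import product


def strided_offsets(sizes, strides, storage_offset=0):
    """Storage index of every element of a strided view, in row-major order."""
    return [storage_offset + sum(i * s for i, s in zip(idx, strides))
            for idx in product(*(range(n) for n in sizes))]


def as_strided_backward(grad, in_geom, out_geom):
    """Hypothetical sketch: gradient w.r.t. the input view of as_strided.

    grad     -- flat, row-major gradient of the as_strided output
    in_geom  -- (sizes, strides, storage_offset) of the input view
    out_geom -- (sizes, strides, storage_offset) of the output view
    """
    in_idx = strided_offsets(*in_geom)
    out_idx = strided_offsets(*out_geom)
    storage = [0.0] * (max(in_idx + out_idx) + 1)
    # 1) Accumulate output grads into storage slots (+=, not =), so every
    #    output element that maps to the same slot contributes.
    for g, k in zip(grad, out_idx):
        storage[k] += g
    # 2) Count how many input elements alias each storage slot.
    count = [0] * len(storage)
    for k in in_idx:
        count[k] += 1
    # 3) Split each slot's gradient evenly among its aliases, so a later sum
    #    over those aliases (e.g. expand's backward) counts it exactly once.
    return [storage[k] / count[k] for k in in_idx]


def expand_backward(grad, rows, cols):
    """x.expand(rows, cols) for a length-`cols` x: sum the gradient over rows."""
    return [sum(grad[r * cols + c] for r in range(rows)) for c in range(cols)]


# Reproduce the issue's setup: x has 2 elements, xx = x.expand(3, 2).
z = [[0.5, -1.0], [2.0, 0.25], [-1.5, 1.0]]   # fixed stand-in for torch.randn(3, 2)
grad_out = [v / 6 for row in z for v in row]  # d mean(out * z) / d out = z / 6
geom = ((3, 2), (0, 1), 0)                    # sizes, strides, offset of xx
grad_x = expand_backward(as_strided_backward(grad_out, geom, geom), 3, 2)
expected = [sum(z[r][c] for r in range(3)) / 6 for c in range(2)]
print(grad_x, expected)
```

In this sketch, dropping step 3 would make the expand backward add each slot's gradient three times. Replacing `+=` with `=` in step 1 would keep only one of the overlapping contributions. In either case the result no longer equals `z.sum(0) / z.numel()`.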
