>>> x = torch.zeros(2, requires_grad=True)
>>> xx = x.expand(3, 2)
>>> z = torch.randn(3, 2)
>>> torch.autograd.grad((xx * z).mean(), x)[0]
tensor([ 0.4419, -0.1242])
>>> torch.autograd.grad((xx.as_strided([3,2], xx.stride()) * z).mean(), x)[0]  # using xx.reshape(3, 2) instead of as_strided works too
tensor([ 0.5057, -0.2912])