Skip to content

Commit 0e77eb4

Browse files
committed
check result
1 parent a2b7b7e commit 0e77eb4

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

test/functorch/test_vmap.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,14 @@ def foo(x):
350350
def test_out_dims_normal_tensor(self):
351351

352352
def foo(x):
353-
return torch.zeros(3)
353+
return torch.arange(3)
354354

355355
tensor = torch.randn(2, 3)
356-
vmap(foo)(tensor)
357-
vmap(foo, out_dims=None)(tensor)
356+
result = vmap(foo)(tensor)
357+
self.assertEqual(result.shape, [2, 3])
358+
359+
result = vmap(foo, out_dims=None)(tensor)
360+
self.assertEqual(result, torch.arange(3))
358361

359362

360363
def test_pytree_returns(self):

0 commit comments

Comments
 (0)