Skip to content

Commit a570975

Browse files
committed
from review
1 parent 74443e5 commit a570975

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

test/functorch/test_vmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def foo(x):
324324
tensor = torch.randn(2, 3)
325325
result = vmap(foo, out_dims=(0, None))(tensor)
326326
self.assertEqual(result[1], 'hello world')
327-
self.assertIsInstance(result[0], torch.Tensor)
327+
self.assertEqual(result[0], tensor)
328328

329329
def foo(x):
330330
x.add_(1)

torch/_functorch/vmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_di
145145
# out_dim is non None
146146
if not isinstance(batched_output, torch.Tensor):
147147
raise ValueError(f'vmap({name}, ...): `{name}` must only return '
148-
f'Tensors, got type {type(batched_output)} as a return. '
148+
f'Tensors, got type {type(batched_output)}. '
149149
'Did you mean to set out_dim= to None for output?')
150150

151151
return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)

0 commit comments

Comments
 (0)