Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <TH/THTensor.hpp>

#include <algorithm>
#include <vector>

namespace at {
namespace native {
Expand Down Expand Up @@ -604,5 +605,30 @@ int64_t numel(const Tensor& self) {
return self.pImpl->numel();
}

std::vector<Tensor> meshgrid(TensorList tensors) {
int64_t size = tensors.size();
AT_CHECK(size > 0, "meshgrid expects a non-empty TensorList");
std::vector<int64_t> shape(size);
for(int64_t i = 0; i < size; i++) {
switch (tensors[i].dim()) {
case 0:
shape[i] = 1;
break;
case 1:
shape[i] = tensors[i].size(0);
break;
default:
AT_ERROR("Expected scalar or 1D tensor in the tensor list but got: ", tensors[i]);
}
}
std::vector<Tensor> grids;
for(int64_t i = 0; i < size; i++) {
std::vector<int64_t> view_shape(size, 1);
view_shape[i] = -1;
grids.push_back(tensors[i].view(view_shape).expand(shape));
}
return grids;
}

}
}
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1805,3 +1805,6 @@

- func: get_device(Tensor self) -> int64_t
device_guard: False

- func: meshgrid(TensorList tensors) -> TensorList
variants: function
19 changes: 19 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7750,6 +7750,25 @@ def test_is_nonzero(self):
self.assertFalse(torch.tensor([[0]]).is_nonzero())
self.assertTrue(torch.tensor([[1]]).is_nonzero())

def test_meshgrid(self):
a = torch.tensor(1)
b = torch.tensor([1, 2, 3])
c = torch.tensor([1, 2])
grid_a, grid_b, grid_c = torch.meshgrid([a, b, c])
self.assertEqual(grid_a.shape, torch.Size([1, 3, 2]))
self.assertEqual(grid_b.shape, torch.Size([1, 3, 2]))
self.assertEqual(grid_c.shape, torch.Size([1, 3, 2]))
expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64)
expected_grid_b = torch.tensor([[[1, 1],
[2, 2],
[3, 3]]])
expected_grid_c = torch.tensor([[[1, 2],
[1, 2],
[1, 2]]])
self.assertTrue(grid_a.equal(expected_grid_a))
self.assertTrue(grid_b.equal(expected_grid_b))
self.assertTrue(grid_c.equal(expected_grid_c))


# Functions to test negative dimension wrapping
METHOD = 1
Expand Down
33 changes: 33 additions & 0 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5833,3 +5833,36 @@ def parse_kwargs(desc):
Tensor: A 1-D tensor of size :math:`(\text{{window_length}},)` containing the window

""".format(**factory_common_args))


add_docstr(torch.meshgrid,
r"""
meshgrid(seq) -> seq

Take a sequence of :math:`N` tensors, each of which can be either scalar or 1-dimensional
vector, and create :math:`N` N-dimensional grids, where the :math:`i`th grid is defined by
expanding the :math:`i`th input over dimensions defined by other inputs.

Arguments:
seq (sequence of Tensors): sequence of scalars or 1 dimensional tensors. Scalars will be
treated as tensors of size :math:`(1,)` automatically.

Returns:
seq (sequence of Tensors): If the input has :math:`k` tensors of size
:math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also has :math:`k` tensors,

This comment was marked as off-topic.

where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`.

Example::

>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([4, 5, 6])
>>> grid_x, grid_y = torch.meshgrid([x, y])
>>> grid_x
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
>>> grid_y
tensor([[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
""")