Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,42 @@ Tensor rand_like(const Tensor& self, const Type& dtype) {
return at::native::rand(dtype, self.sizes());
}

Tensor randint(const Type& dtype, int64_t high, IntList size, Generator* generator) {
Tensor result = dtype.tensor(size);
return result.random_(0, high, generator);
}

Tensor randint(const Type& dtype, int64_t low, int64_t high, IntList size, Generator* generator) {
Tensor result = dtype.tensor(size);
return result.random_(low, high, generator);
}

Tensor& randint_out(Tensor& result, int64_t high, IntList size, Generator* generator) {
result.resize_(size);
return result.random_(0, high, generator);
}

Tensor& randint_out(Tensor& result, int64_t low, int64_t high, IntList size, Generator* generator) {
result.resize_(size);
return result.random_(low, high, generator);
}

Tensor randint_like(const Tensor& self, int64_t high) {
return at::native::randint_like(self, high, self.type());
}

Tensor randint_like(const Tensor& self, int64_t low, int64_t high) {
return at::native::randint_like(self, low, high, self.type());
}

Tensor randint_like(const Tensor& self, int64_t high, const Type& dtype) {
return at::native::randint(dtype, high, self.sizes(), nullptr);
}

Tensor randint_like(const Tensor& self, int64_t low, int64_t high, const Type& dtype) {
return at::native::randint(dtype, low, high, self.sizes(), nullptr);
}

Tensor randn(const Type& dtype, IntList size, Generator* generator) {
Tensor result = dtype.tensor(size);
return result.normal_(0, 1, generator);
Expand Down
24 changes: 24 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,30 @@
- func: rand_like(Tensor self, *, Type dtype) -> Tensor
variants: function

- func: randint(Type dtype, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor
variants: function

- func: randint(Type dtype, int64_t low, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor
variants: function

- func: randint_out(Tensor result, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor
variants: function

- func: randint_out(Tensor result, int64_t low, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor
variants: function

- func: randint_like(Tensor self, int64_t high) -> Tensor
variants: function

- func: randint_like(Tensor self, int64_t low, int64_t high) -> Tensor
variants: function

- func: randint_like(Tensor self, int64_t high, *, Type dtype) -> Tensor
variants: function

- func: randint_like(Tensor self, int64_t low, int64_t high, *, Type dtype) -> Tensor
variants: function

- func: randn(Type dtype, IntList size, *, Generator* generator=nullptr) -> Tensor
variants: function

Expand Down
24 changes: 24 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2541,6 +2541,30 @@ def test_rand(self):
torch.rand(SIZE, SIZE, out=res2)
self.assertEqual(res1, res2)

def test_randint(self):
torch.manual_seed(123456)
res1 = torch.randint(0, 6, (SIZE, SIZE))
res2 = torch.Tensor()
torch.manual_seed(123456)
torch.randint(0, 6, (SIZE, SIZE), out=res2)
torch.manual_seed(123456)
res3 = torch.randint(6, (SIZE, SIZE))
res4 = torch.Tensor()
torch.manual_seed(123456)
torch.randint(6, (SIZE, SIZE), out=res4)
self.assertEqual(res1, res2)
self.assertEqual(res1, res3)
self.assertEqual(res1, res4)
self.assertEqual(res2, res3)
self.assertEqual(res2, res4)
self.assertEqual(res3, res4)
res1 = res1.view(-1)
high = (res1 < 6).type(torch.LongTensor)
low = (res1 >= 0).type(torch.LongTensor)
tensorSize = res1.size()[0]
assert(tensorSize == high.sum())
assert(tensorSize == low.sum())

def test_randn(self):
torch.manual_seed(123456)
res1 = torch.randn(SIZE, SIZE)
Expand Down
39 changes: 39 additions & 0 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4073,6 +4073,45 @@

""")

add_docstr(torch.randint,
r"""
randint(low=0, high, sizes, out=None, dtype=torch.float32) -> Tensor

Returns a tensor filled with random integers generated uniformly
between :attr:`low` (inclusive) and :attr:`high` (exclusive).

The shape of the tensor is defined by the variable argument :attr:`sizes`.

Args:
low (int, optional): Lowest integer to be drawn from the distribution. Default: 0.
high (int): One above the highest integer to be drawn from the distribution.
sizes (tuple): a tuple defining the shape of the output tensor.
out (Tensor, optional): the output tensor
dtype (:class:`torch.dtype`, optional): the desired type of returned Tensor. Default: torch.float32

Example::

>>> torch.randint(3, 5, (3,))

4
4
3
[torch.FloatTensor of size (3,)]

>>> torch.randint(3, 10, (2,2), dtype=torch.long)

7 5
9 4
[torch.LongTensor of size (2,2)]

>>> torch.randint(3, 10, (2,2))

6 8
9 4
[torch.FloatTensor of size (2,2)]

""")

add_docstr(torch.randn,
r"""
randn(*sizes, out=None) -> Tensor
Expand Down