Skip to content

Conversation

@Naman-ntc
Copy link
Contributor

Added randint functionality as requested in [#5874].

@ssnl
Copy link
Collaborator

ssnl commented Mar 30, 2018

This needs documentation in _torch_docs.py and tests in test_torch.py.


Tensor randint(const Type& dtype, IntList size, Generator* generator) {
Tensor result = dtype.tensor(size);
return result.random_(0, 1, generator);

This comment was marked as off-topic.

This comment was marked as off-topic.

return result;
}

Tensor randint(const Type& dtype, IntList size, int64_t low, int64_t high, Generator* generator) {

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

- func: randperm_out(Tensor result, int64_t n, *, Generator* generator=nullptr) -> Tensor
variants: function

- func: randint(Type dtype, IntList size, int64_t low, int64_t high, *, Generator* generator=nullptr) -> Tensor

This comment was marked as off-topic.

This comment was marked as off-topic.

@Naman-ntc
Copy link
Contributor Author

Naman-ntc commented Mar 31, 2018

These are the behaviours on corner cases otherwise working fine :

=>torch.randint(5,(3))
Traceback (most recent call last):
File "", line 1, in
TypeError: randint(): argument 'size' (position 2) must be tuple of ints, not int

=>torch.randint(5,3)
Traceback (most recent call last):
File "", line 1, in
TypeError: randint(): argument 'size' (position 2) must be tuple of ints, not int

=>torch.randint(5,10,3)
Traceback (most recent call last):
File "", line 1, in
TypeError: randint() received an invalid combination of arguments - got (int, int, int), but expected one of:
(int high, tuple of ints size, torch.Generator generator, torch.dtype dtype, int device, bool requires_grad)
(int low, int high, tuple of ints size, torch.Generator generator, torch.dtype dtype, int
device, bool requires_grad)

=>torch.randint(3,5,())
4
[torch.FloatTensor of size ()]

The first outcome is bothering me. Any fix? (Others are fine in my opinion)

@ezyang
Copy link
Contributor

ezyang commented Apr 1, 2018

@pytorchbot test this please

The behavior in the first case is just a Python-ism. (3) is the same as 3; the one element tuple is (3, ).

@Naman-ntc
Copy link
Contributor Author

Naman-ntc commented Apr 1, 2018

@ezyang yeah that case is working (I should have included it in my examples). Also torch.randn((3)) works right, so therefore torch.randint((3,5,(4)) must also work for sake of continuity right?

EDIT :
Some additional inputs

torch.randint(6,22,(3,))
7
21
6
[torch.FloatTensor of size (3,)]

torch.randint(6,22,(2,2))
19 11
18 9
[torch.FloatTensor of size (2,2)]

@ezyang
Copy link
Contributor

ezyang commented Apr 2, 2018

@pytorchbot test this please

@ezyang
Copy link
Contributor

ezyang commented Apr 2, 2018

So, perhaps what you are suggesting is that randint should have an option randint(low, high, *args), where args are sizes?

@Naman-ntc
Copy link
Contributor Author

Naman-ntc commented Apr 2, 2018

@ezyang No we need to pass the size as a tuple only as we provide
randint(5,(2,2)) to give random integers in range 0-5
If we dont enclose size in tuple then we cant differentiate between randint(5,2,(2)) and randint(5,(2,2))

EDIT : I am suggesting to have option for Intlist to accept tuple of size 1 as well, which it somehow does for torch.randn

@Naman-ntc Naman-ntc changed the title added randint function in ATEN yaml as well as Tensorfactories.cpp randint function Apr 2, 2018
@ezyang
Copy link
Contributor

ezyang commented Apr 2, 2018

No, what I am saying is that (3) is not a tuple of size one.

>>> (3)
3
>>> (3,)
(3,)

It accidentally works in torch.randn((3)) case because randn has an overload torch.randn(*size_args).

Best not to get too crazy with the overloads here, IMO.

@Naman-ntc
Copy link
Contributor Author

Naman-ntc commented Apr 2, 2018

@ezyang
Ohh i see!! So i believe currently it's working fine, as we cant provide size_args overloading coz it'll complicate things. Will write the documentation and tests then, and also mention in docs specifically that for tensor of size 3 use sizes as (3,)

By the way if I overload it as torch.randint(int,int,int) then torch.randint(1,5,(3)) would work right. I can do that if you want. (Hacky I know :P )

@ezyang
Copy link
Contributor

ezyang commented Apr 2, 2018

I hate it! Leave it the way it is ;)

@ssnl
Copy link
Collaborator

ssnl commented Apr 2, 2018

Looks reasonable, but this still needs doc and tests.

@Naman-ntc
Copy link
Contributor Author

Will commit by tommorow @ssnl :)

@ezyang ezyang changed the title randint function [WIP] randint function Apr 3, 2018
@Naman-ntc
Copy link
Contributor Author

Naman-ntc commented Apr 4, 2018

I am having doubt how should I document the function.
As I have used function overloading I can create a function signature like this :

torch.randint(low=0,high,sizes) (low can be optional)

Now I think it might confuse some people as (in general) it's not possible for first argument to be optional before second and third arguments (They dont know I used function overloading for the feature).

Do you think it's alright to document it like this??

@ssnl
Copy link
Collaborator

ssnl commented Apr 4, 2018

We already have docs where first arg is written as optional, e.g. http://pytorch.org/docs/master/torch.html#torch.arange. So it should be fine.

@Naman-ntc
Copy link
Contributor Author

@ssnl I guess it's ready to merge now?

@Naman-ntc
Copy link
Contributor Author

@ssnl is there anything more to add or correct now?

@ezyang
Copy link
Contributor

ezyang commented Apr 5, 2018

@ssnl

Copy link
Collaborator

@ssnl ssnl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally looks great except for two minor things.

The shape of the tensor is defined by the variable argument :attr:`sizes`.
Args:
low (int, optional): Lowest (positive) integer to be drawn from the distribution. Default: 0.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

- func: randn_like(Tensor self, *, Type dtype) -> Tensor
variants: function

- func: randint(Type dtype, int64_t high, IntList size, *, Generator* generator=nullptr) -> Tensor

This comment was marked as off-topic.


add_docstr(torch.randint,
r"""
randint(low=0, high, sizes, out=None) -> Tensor

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Collaborator

@ssnl ssnl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but need some owner to look at this.

high (int): One above the highest integer to be drawn from the distribution.
sizes (tuple): a tuple defining the shape of the output tensor.
out (Tensor, optional): the output tensor
dtype (torch.dpython:type, optional) – the desired type of returned Tensor. Default: torch.float32

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@Naman-ntc
Copy link
Contributor Author

Can we merge this now?

Copy link
Member

@colesbury colesbury left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@colesbury colesbury merged commit acb7df1 into pytorch:master Apr 10, 2018
@vadimkantorov
Copy link
Contributor

@Naman-ntc
Copy link
Contributor Author

It'll probably come in next release!

@vadimkantorov
Copy link
Contributor

I think some autofunction entry is needed, so that Sphinx would pull up the docs from the doc string

@fmassa
Copy link
Member

fmassa commented Apr 11, 2018

Looks like @vadimkantorov is right.
@Naman-ntc can you add an entry for randint in https://github.com/pytorch/pytorch/blob/master/docs/source/torch.rst ? Probably something like

.. autofunction:: randint

in the random functions section, look here

@Naman-ntc
Copy link
Contributor Author

Ohh I see!
Makes sense, will make a new PR for it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants