Update torch.set_default_dtype doc #41263
Conversation
[ghstack-poisoned]
torch/__init__.py
Outdated
    Example::
        >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
        >>> # initial default for floating point is torch.float
nit: this isn't a huge deal, but I think the previous wording was slightly better, because it matches what is returned for the next line. As it is now, people will wonder why torch.float and torch.float32 don't match exactly.
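For reference, torch.float is just an alias for torch.float32, so the two names refer to the same dtype; a quick interpreter check (illustrative, not part of the diff):

    >>> import torch
    >>> torch.float is torch.float32
    True
    >>> torch.tensor([1.2, 3]).dtype   # repr uses the canonical name
    torch.float32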
ahh ok
torch/__init__.py
Outdated
        >>> # initial default for floating point is torch.cfloat
        >>> torch.tensor([1.2, 3j]).dtype
        torch.complex64
        >>> torch.set_default_tensor_type(torch.DoubleTensor)
I don't think we want to document this -- the type interface isn't complete (e.g. there's no quantized, XLA) -- it's really logically deprecated.
yeah that makes sense: updated to torch.set_default_dtype(torch.float64)
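For illustration (not taken from the diff), the dtype-based call gives the same float inference behavior without going through the legacy tensor-type interface:

    >>> import torch
    >>> torch.set_default_dtype(torch.float64)
    >>> torch.tensor([1.2, 3]).dtype
    torch.float64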
torch/__init__.py
Outdated
    used as default floating point type for type inference in
    :func:`torch.tensor`.
    r"""Sets the default floating point dtype to :attr:`d`.
    This type is also used:
this says "also used" but it's sort of defined by how it's used, i.e. "sets the default floating point dtype" doesn't really mean anything outside of how it' s used. I'd just change this to something like:
"This dtype is used as:"
updated
torch/__init__.py
Outdated
    1. As default floating point type for type inference in :func:`torch.tensor`.
    2. To determine the default complex dtype. The default complex dtype is set to
       ``torch.complex128`` if default floating point tensor type is ``torch.float64``,
       else it's set to ``torch.complex64``
nit: else -> otherwise.
is it not true that this is the corresponding complex dtype for the float default?
that's true
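An illustrative check of that correspondence, assuming a fresh interpreter and the complex-inference behavior this PR documents:

    >>> import torch
    >>> torch.tensor([1.2, 3j]).dtype        # default float dtype is float32 -> complex64
    torch.complex64
    >>> torch.set_default_dtype(torch.float64)
    >>> torch.tensor([1.2, 3j]).dtype        # default float dtype is float64 -> complex128
    torch.complex128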
torch/__init__.py
Outdated
    :func:`torch.tensor`.
    r"""Sets the default floating point dtype to :attr:`d`.
    This type is also used:
    1. As default floating point type for type inference in :func:`torch.tensor`.
this could be a little clearer (I know it's coming from the prior writeup). But "default floating point type" isn't really defined yet. Maybe something like:
"As the dtype inferred for python floats" (and similar for complex).
updated
[ghstack-poisoned]
torch/__init__.py
Outdated
| r"""Sets the default floating point dtype to :attr:`d`. | ||
| This dtype is: | ||
| 1. The inferred dtype for python floats in :func:`torch.tensor`. | ||
| 2. Used to determine the default complex dtype. The default complex dtype is set to |
same issue with 1 -- this should say something like "the inferred dtype for python complex numbers"
does this sound better?
Used to infer dtype for python complex numbers. The default complex dtype is set to
``torch.complex128`` if default floating point dtype is ``torch.float64``, otherwise it's set to ``torch.complex64``
lgtm!
    1. The inferred dtype for python floats in :func:`torch.tensor`.
    2. Used to determine the default complex dtype. The default complex dtype is set to
       ``torch.complex128`` if default floating point tensor type is ``torch.float64``,
       otherwise it's set to ``torch.complex64``
I'm curious: why did you word it this way instead of saying something about the corresponding float dtype?
The only reason I didn't frame it that way is that we don't have a complex dtype corresponding to torch.bfloat16 or torch.float16 (technically we do, but support for ComplexHalf is almost non-existent).
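A hypothetical sketch of the rule the final wording describes (illustrative helper only, not part of the torch API):

    import torch

    # Hypothetical helper mirroring the docstring's rule: complex128 when the
    # default float dtype is float64, otherwise complex64.
    def default_complex_dtype(default_float_dtype):
        if default_float_dtype is torch.float64:
            return torch.complex128
        # float32 (and, under the "otherwise" wording, float16/bfloat16) fall
        # back to complex64, since ComplexHalf support is essentially unusable.
        return torch.complex64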
[ghstack-poisoned]
💊 CI failures summary and remediations

As of commit 2030496 (more details on the Dr. CI page):

ci.pytorch.org: 2 failed

This comment was automatically generated by Dr. CI and has been revised 1 time.
    This dtype is:
    1. The inferred dtype for python floats in :func:`torch.tensor`.
    2. Used to infer dtype for python complex numbers. The default complex dtype is set to
       ``torch.complex128`` if default floating point dtype is ``torch.float64``,
"if the default..."
    :func:`torch.tensor`.
    r"""Sets the default floating point dtype to :attr:`d`.
    This dtype is:
    1. The inferred dtype for python floats in :func:`torch.tensor`.
This isn't quite right since it's the inferred type for Python floats in other contexts, too:
    >>> t = torch.tensor(5)
    >>> torch.set_default_dtype(torch.double)
    >>> (t + 5.).dtype
    torch.float64
    >>> torch.full((2,), 2.).dtype
    torch.float64
I think this falls under the torch.tensor case, since when you add a scalar to a tensor we first make a zero-dim tensor containing the scalar (hence the dtype of this new tensor will be torch.float64 in this case).
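A small illustrative check of that claim (values here are just an example, not from the PR):

    >>> import torch
    >>> torch.set_default_dtype(torch.float64)
    >>> torch.tensor(5.).dim(), torch.tensor(5.).dtype   # scalar wrapped as a zero-dim tensor
    (0, torch.float64)
    >>> (torch.tensor(5) + 5.).dtype                     # int64 tensor + python float -> default float dtype
    torch.float64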
That's true in the first case, but it's also an implementation detail I wouldn't expect a user to be aware of. In the second case I don't think torch.tensor() is involved?
| r"""Sets the default floating point dtype to :attr:`d`. | ||
| This dtype is: | ||
| 1. The inferred dtype for python floats in :func:`torch.tensor`. | ||
| 2. Used to infer dtype for python complex numbers. The default complex dtype is set to |
You may want to make this a note for clarity:
"This setting also determines the inferred dtype of complex numbers. If the default floating point dtype is torch.float64 then complex numbers are inferred to have a dtype of torch.complex128, otherwise they are assumed to have a dtype of torch.complex64."
       ``torch.complex128`` if default floating point dtype is ``torch.float64``,
       otherwise it's set to ``torch.complex64``
    The default floating point dtype is initially ``torch.float32``.
This might also be noteworthy?
    Example::
        >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
        >>> # initial default for floating point is torch.float32
Following the above comment, maybe it'd be nice to have an example that isn't torch.tensor based?
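For instance, a factory-function based snippet (a sketch of one possibility, not necessarily what landed in the docstring) could be:

    >>> import torch
    >>> torch.set_default_dtype(torch.float64)
    >>> torch.arange(0., 5.).dtype
    torch.float64
    >>> torch.full((2,), 2.).dtype
    torch.float64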
@anjali411 merged this pull request in e888c3b.
Summary: Pull Request resolved: #41263
Test Plan: Imported from OSS
Differential Revision: D22482989
Pulled By: anjali411
fbshipit-source-id: 2aadfbb84bbab66f3111970734a37ba74d817ffd
Stack from ghstack:
Differential Revision: D22482989