Skip to content

Conversation

@mruberry
Copy link
Collaborator

Currently compare_with_numpy requires a device and dtype, but these arguments are ignored if a tensor is provided. This PR updates the function to only take device and dtype if a tensor-like object is given. This should prevent confusion that you could, for example, pass a CPU float tensor but provided a CUDA device and integer dtype.

Several tests are updated to reflect this behavior.

@mruberry mruberry added module: tests Issues related to tests (not the torch.testing module) module: numpy Related to numpy support, and also numpy compatibility of our operators labels Jun 16, 2020
@mruberry mruberry requested a review from ngimel June 16, 2020 00:03
Copy link
Collaborator

@ngimel ngimel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's good we are getting rid of tolist()s.

# which takes care of negative strides if present.
torch_fn, np_fn = funcs
self.compare_with_numpy(torch_fn, np_fn, data, device, dtype)
if dtype.is_floating_point or dtype.is_complex:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this ignored device altogether? Nice.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Historically it would create a list from a tensor and then compare_with_numpy would put that list into a tensor on the appropriate device.

The updated test just has fewer steps to get to the same place.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@mruberry merged this pull request in ebd8691.

@mruberry mruberry deleted the compare_with_numpy_clarity branch June 18, 2020 04:32
xwang233 pushed a commit to xwang233/pytorch that referenced this pull request Jun 20, 2020
Summary:
Currently compare_with_numpy requires a device and dtype, but these arguments are ignored if a tensor is provided. This PR updates the function to only take device and dtype if a tensor-like object is given. This should prevent confusion that you could, for example, pass a CPU float tensor but provided a CUDA device and integer dtype.

Several tests are updated to reflect this behavior.
Pull Request resolved: pytorch#40064

Differential Revision: D22058072

Pulled By: mruberry

fbshipit-source-id: b494bb759855977ce45b79ed3ffb0319a21c324c
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: numpy Related to numpy support, and also numpy compatibility of our operators module: tests Issues related to tests (not the torch.testing module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants