Conversation

@t-vi (Collaborator) commented Apr 27, 2018

Sometimes, people are surprised that things cannot be differentiated w.r.t. integer parameters such as indices.
The following patch takes some steps to prevent them from requiring gradients of non-floating-point Tensors:

  • by setting tensor.requires_grad = True
  • by using tensor.requires_grad_(True)
  • by using factory functions with requires_grad=True

Of course, applying the above with False still needs to be allowed.

As requested by Adam in #7021, this is done at the Python interface level.
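
A minimal sketch of the new behavior from the Python side (assuming, as in current PyTorch, that the rejection surfaces as a RuntimeError; the exact message is illustrative):

```python
import torch

x = torch.zeros(3, dtype=torch.int64)

# all three ways of requesting gradients are rejected for integer tensors
attempts = [
    lambda: setattr(x, 'requires_grad', True),                      # tensor.requires_grad = True
    lambda: x.requires_grad_(True),                                 # tensor.requires_grad_(True)
    lambda: torch.zeros(3, dtype=torch.int64, requires_grad=True),  # factory function
]
for attempt in attempts:
    try:
        attempt()
    except RuntimeError as e:
        print('rejected:', e)

x.requires_grad = False  # setting False remains allowed
```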

@t-vi (Collaborator, Author) commented Apr 27, 2018

Something is up with the pytorch-linux-xenial-py3-clang5-asan test. It seems to hang for 30 minutes in test_multi_drop (test_utils.TestDataLoader).
Is that me or the test?

@t-vi force-pushed the set_requires_grad_only_for_float branch from c7e2383 to 04fcbd6 on April 27, 2018 13:02
@zou3519 (Contributor) commented Apr 27, 2018

@t-vi Probably just the test; I've seen intermittent timeouts there.

@t-vi (Collaborator, Author) commented Apr 27, 2018 via email

```python
for f in [f1, f2, f3]:
    a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu')
    if dt.is_floating_point:
        f()
```


@t-vi (Collaborator, Author) commented Apr 28, 2018

So now macOS has a "CI changed" failure, but I think it works.

@ezyang (Contributor) commented Apr 30, 2018

OS X failure is unrelated, and I think fixed on master. @pytorchbot retest this please

@apaszke apaszke merged commit 8fbab83 into pytorch:master Apr 30, 2018
@apaszke (Contributor) commented Apr 30, 2018

Thanks @t-vi!

@gchanan (Contributor) commented May 2, 2018

This didn't change the constructors in tensor_new.cpp, e.g. `torch.tensor`.

If you implemented this in those constructors, it would get a little awkward when combined with type inference, because you don't know the type of the tensor that will come out, e.g.:

```python
def convert_to_tensors(data0, data1):
    return torch.tensor(data0, requires_grad=True), torch.tensor(data1, requires_grad=True)
```

would not throw an error on `convert_to_tensors([0., 1.], [2., 3.])` but would on `convert_to_tensors([0., 1.], [2, 3])`. Sometimes you want this fail-fast behavior, but sometimes not.
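
A sketch of that asymmetry (assuming, hypothetically, that `torch.tensor` enforced the same check, which per the above it currently does not):

```python
import torch

# dtype is inferred from the data, so the same call site may or may not error:
torch.tensor([0., 1.], requires_grad=True)  # inferred float32 -> fine
torch.tensor([2, 3], requires_grad=True)    # inferred int64 -> would raise
```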

justindujardin added a commit to justindujardin/thinc that referenced this pull request Jan 31, 2020
Only float tensors can be backpropped and PyTorch throws an error if you try: pytorch/pytorch#7034
