Conversation

ssnl (Collaborator) commented on May 19, 2018


The added test runs the same checks both with grad enabled and under torch.no_grad():

test()
with torch.no_grad():
    test()

The change under discussion wraps the lazy_property computation in torch.enable_grad():

if instance is None:
    return self
with torch.enable_grad():
    value = self.wrapped(instance)
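To make the effect concrete, here is a small usage sketch (the Gaussian class is hypothetical; lazy_property is the helper in torch.distributions.utils). Because the first access caches the computed value on the instance, the grad mode in effect at that moment is what every later access sees.

import torch
from torch.distributions.utils import lazy_property

class Gaussian(object):            # hypothetical stand-in for a Distribution
    def __init__(self, scale):
        self.scale = scale

    @lazy_property
    def variance(self):            # evaluated once, then cached on the instance
        return self.scale ** 2

scale = torch.tensor(2.0, requires_grad=True)
g = Gaussian(scale)
with torch.no_grad():
    v = g.variance                 # first access happens under no_grad
print(v.requires_grad)             # True with the enable_grad wrapper; False without it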

apaszke (Contributor) commented on May 20, 2018

cc @fritzo (sorry, didn't notice you're already tagged)

fritzo (Collaborator) commented on May 20, 2018

cc @neerajprad

It seems a little forceful to always enable grad. One option that seems consistent to me is to inherit grad_enabled from its setting at initialization time:

class Distribution(object):
    def __init__(self, ...):
        ...
        self._grad_enabled = torch.is_grad_enabled()  # capture the grad mode at construction

class lazy_property(object):
    ...
    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        with torch.autograd.set_grad_enabled(instance._grad_enabled):
            value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)
        return value

However, it's unclear whether we should be clever here or simply require strict usage: "grad_enabled should have a single value throughout the lifetime of a distribution object". The distributions are intended to be flyweight objects that are cheap to reconstruct in each grad_enabled context.
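For illustration, the flyweight convention described above might look like the following sketch (Normal is just a convenient example; none of this code is from the PR): rebuild the distribution inside each grad context instead of reusing one whose lazy properties were cached under a different mode.

import torch
from torch.distributions import Normal

loc = torch.tensor(0.0, requires_grad=True)

with torch.no_grad():
    d_eval = Normal(loc, 1.0)      # evaluation-only object; cached values carry no history
    x = d_eval.sample()

d_train = Normal(loc, 1.0)         # cheap to reconstruct for the grad-enabled phase
loss = -d_train.log_prob(x)
loss.backward()
print(loc.grad)                    # gradients flow to loc as expected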

apaszke (Contributor) commented on May 20, 2018

I think the "constant grad mode" invariant is a bit too easy to get wrong. Forcing grad mode doesn't seem that bad really. If your distribution parameters don't require grad then it will be a no-op. Otherwise it's quite likely that you will be interested in differentiating those parts, and I think it's better to trade off some memory for compute in this case.

It's a hard problem, because you don't have any information about how the object will be used. You either need to drop the autograd history, hoping that someone who calls .sample() will never need it, or you need to just live with the fact that you might end up wasting some memory, because we temporarily enabled grad. Provided the distribution is short-lived the downside is somewhat irrelevant though, because the properties should get removed quickly. If we find out that this strategy doesn't work too well, we can always add an extra constructor parameter that lets you choose a strategy for caching those.
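The no-op point is easy to check with a standalone snippet (not from the PR): re-enabling grad is harmless when no input requires grad, and only records history when one does.

import torch

x = torch.randn(3)                 # requires_grad=False
with torch.no_grad():
    with torch.enable_grad():
        y = (x * 2).sum()
print(y.requires_grad)             # False: nothing requires grad, so no graph is built

p = torch.randn(3, requires_grad=True)
with torch.no_grad():
    with torch.enable_grad():
        z = (p * 2).sum()
print(z.requires_grad)             # True: history is recorded, trading memory for compute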

ssnl (Collaborator, Author) commented on May 24, 2018

@fritzo @apaszke @neerajprad have we reached consensus on whether we should merge this? :)

fritzo (Collaborator) commented on May 24, 2018

This seems reasonable to me (I'm convinced by @apaszke).

ssnl merged commit c946db1 into pytorch:master on May 24, 2018
ssnl deleted the grad_lazy_prop branch on May 24, 2018 at 15:22
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request on Jun 11, 2018:

Always enable grad when calculating lazy_property (pytorch#7708)

* Always enable grad when calculating lazy_property
* Add test with MultivariateNormal

Successfully merging this pull request may close these issues.

[distributions] rsample().detach() and sample() yields different gradients
