Conversation

@ssnl (Collaborator) commented Apr 27, 2018

addresses #7002

@zou3519 (Contributor) left a comment

lgtm. Add a test maybe (and comments!), in case some confused soul reverts this?

  self: grad * ((self >= min) * (self <= max)).type_as(grad)

- name: clamp_min(Tensor self, Scalar min)
  self: grad * (self > min).type_as(grad)

@ssnl (Collaborator, Author) commented Apr 28, 2018

@zou3519 Updated clamp_min and clamp_max as well. However, I'm not sure if we should add a test for this as it is mathematically correct either way. @apaszke , any ideas?
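
A minimal sketch of the kind of boundary check being discussed, using the public torch.clamp API with illustrative values (this is not a test that was actually added in the PR):

import torch

# Inputs that sit exactly on the clamp boundaries, plus one interior point.
x = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)
torch.clamp(x, min=0.0, max=1.0).sum().backward()

# With this PR's >= / <= formula the boundary entries get gradient 1: tensor([1., 1., 1.])
# With the old strict > / < comparisons they were 0:                  tensor([0., 1., 0.])
print(x.grad)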

@ezyang (Contributor) commented Apr 30, 2018

Yes, this subtle formula tweak definitely merits a comment. A link to the issue is a good start, and maybe something like "Gradient is not defined at the boundaries, but empirically it's helpful to be able to get gradient on min and max."

EDIT: And perhaps it points the way to a more general principle: if you need to return the gradient at a nondifferentiable point, and you have lim x->0+ f'(x) = 0 and lim x->0- f'(x) != 0, you should always take the nonzero gradient.
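
Spelled out for clamp at its lower boundary, the one-sided derivatives are

$$
f(x) = \operatorname{clamp}(x, \mathit{min}, \mathit{max}), \qquad
\lim_{x \to \mathit{min}^-} f'(x) = 0, \qquad
\lim_{x \to \mathit{min}^+} f'(x) = 1,
$$

so the "self >= min" comparison in this PR makes backward return the nonzero one-sided value, 1, exactly at the boundary.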

@apaszke (Contributor) left a comment

I agree a comment would be nice. @ezyang I don't think we should commit to that principle. We could use it to guide some choices, but I don't want to end up with ops that have to do an extra pass to check for boundary conditions and fill in the gradients when that pass is expensive.

ssnl merged commit d9aeb7e into pytorch:master Apr 30, 2018
ssnl deleted the clamp_g branch April 30, 2018 13:22
Jorghi12 pushed a commit to wsttiger/pytorch that referenced this pull request May 10, 2018
* subgradient 1 at min and max for clamp

* clamp max and clamp min too

* add comment
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018
* subgradient 1 at min and max for clamp

* clamp max and clamp min too

* add comment