
Conversation

@vishwakftw
Contributor

Defined the subgradient of std when result = 0 to be inf.

  1. Added tests in test_autograd
  2. New method std_backward() for computing backward in Functions.cpp

Fixes #4320.
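For reference, the failure mode from #4320 can be reproduced with a snippet like the one below (a minimal sketch, not necessarily the issue's exact repro):

import torch

# std of a constant tensor is 0; backpropagating through it yields NaN
# gradients (0/0 inside the backward formula), which is what this PR addresses.
x = torch.ones(7, dtype=torch.double, requires_grad=True)
torch.std(x).backward()
print(x.grad)
# tensor([nan, nan, nan, nan, nan, nan, nan], dtype=torch.float64)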

@ssnl
Collaborator

ssnl commented Jul 7, 2018

:( this adds 2 extra ops

@vishwakftw
Contributor Author

vishwakftw commented Jul 7, 2018

I can make it one extra op by modifying the entry in derivatives.yaml alone:


- name: std(Tensor self, bool unbiased)
  # current formulation in this PR (new std_backward in Functions.cpp):
  self: std_backward(grad, self, result, unbiased)
  # proposed single-expression alternative:
  self: var_backward(grad / (2 * result), self, unbiased).masked_fill_(result == 0., INFINITY)
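For context, the inline expression relies on the chain rule std = sqrt(var), i.e. d std/dx = (d var/dx) / (2 * std) wherever std is nonzero. A quick numerical check of that identity (my own sketch, not part of the PR):

import torch

# Compare the gradient of std with the gradient of var rescaled by 1 / (2 * std).
x = torch.randn(5, dtype=torch.double, requires_grad=True)

s = x.std(unbiased=True)
s.backward()
grad_std = x.grad.clone()

x.grad.zero_()
x.var(unbiased=True).backward()
grad_var = x.grad.clone()

print(torch.allclose(grad_std, grad_var / (2 * s.detach())))
# True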


@ezyang
Contributor

ezyang commented Jul 9, 2018

@vishwakftw I'm perhaps not the best core dev to ask about this, since I have always leaned toward correctness over perf, but it might help appease fears about slowness to run a quick benchmark before and after to see what the perf impact is.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ssnl
Collaborator

ssnl commented Jul 9, 2018

While we are at it, why do we choose inf? If I understand correctly, any value in [-inf, inf] should be fine, right?

@vishwakftw
Contributor Author

This is the time difference:

# Without masked_fill_(result == 0.0, INFINITY)
81 ms ± 1.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# With masked_fill_(result == 0.0, INFINITY)
128 ms ± 2.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# With recent version (optimized masked_fill)
109 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Regarding the use of inf instead of any other value in [-inf, inf]: I did this to match the behaviour of sqrt.
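For the record, here is the sqrt behaviour being matched (my own snippet, not part of the PR):

import torch

# sqrt's backward is grad / (2 * sqrt(x)), so at x == 0 autograd returns inf.
x = torch.zeros(1, dtype=torch.double, requires_grad=True)
torch.sqrt(x).sum().backward()
print(x.grad)
# tensor([inf], dtype=torch.float64)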

Member

@colesbury colesbury left a comment


inf is not a good choice for the gradient of std() when std() is 0.

You can derive a formula using the limit definition of a derivative and the formula for standard deviation. It's going to depend on the number of elements N.

@colesbury
Member

I think the formula should be:

unbiased: 1/sqrt(n)
biased: sqrt(n-1)/n

You can verify this numerically:

import math
import torch

N = 7
eps = 1e-9
x = torch.zeros(N, dtype=torch.double)
x[0] = eps
print(float(torch.std(x) / eps))
print(float(1/math.sqrt(N)))
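The same check works for the biased case (my own extension of the snippet above, using the sqrt(n-1)/n formula):

import math
import torch

# One-sided numerical derivative of the biased std at a constant input.
N = 7
eps = 1e-9
x = torch.zeros(N, dtype=torch.double)
x[0] = eps
print(float(torch.std(x, unbiased=False) / eps))
print(math.sqrt(N - 1) / N)
# both print ~0.3499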

@vishwakftw
Contributor Author

vishwakftw commented Jul 12, 2018

@colesbury Thanks for the advice. I computed the derivative using the definition and it came out to be sqrt((n-1)/(n*N)), where N = n-1 for unbiased and N = n otherwise.

@ssnl
Collaborator

ssnl commented Jul 12, 2018 via email

@ssnl
Collaborator

ssnl commented Jul 12, 2018 via email

@vishwakftw
Contributor Author

It won't be negative, but I think it'll be higher if you take x - delta instead of x + delta.

@ssnl
Collaborator

ssnl commented Jul 12, 2018 via email

@vishwakftw
Contributor Author

@pytorchbot retest this please

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


Tensor std_backward(const Tensor & grad, const Tensor & self, Tensor result, bool unbiased) {
  Tensor result_zero_mask = (result == 0.);
  if (result_zero_mask.any().toCByte()) {


@ezyang
Contributor

ezyang commented Jul 14, 2018

@ssnl is the math good now?

@ssnl
Collaborator

ssnl commented Jul 15, 2018

The derivative is not the same if we take the limit from the x->0- direction. The change in the output is the same, but the change in the input is negated, so that limit is the negative of the one taken from the x->0+ direction. You can verify this using @colesbury's code above but with eps=-1e-9. So essentially any value in the range [-1/sqrt(N), 1/sqrt(N)] is a valid subgradient.

Since the usage is mostly gradient-based optimization like GD, it makes sense to take one of these two extremes, but it's unclear to me which we should choose. (nan really isn't a bad choice in this sense.) I think if we choose one, we should state the choice in the doc.
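To make the sign flip concrete, here is @colesbury's check rerun with a negative eps (my own variation):

import math
import torch

# Approaching from the negative side: the one-sided limit is -1/sqrt(N).
N = 7
eps = -1e-9
x = torch.zeros(N, dtype=torch.double)
x[0] = eps
print(float(torch.std(x) / eps))
print(-1 / math.sqrt(N))
# both print ~ -0.3780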

@ezyang
Contributor

ezyang commented Jul 17, 2018

@vishwakftw?

@vishwakftw
Contributor Author

I thought we were waiting for @colesbury's call. If not, I am sorry; I will send in the changes as soon as I can.

@ezyang
Contributor

ezyang commented Jul 17, 2018

No, I'm sorry; I wasn't sure what we were waiting on. If you like I can bug him about it tomorrow.

@vishwakftw
Contributor Author

@ezyang Here are the timings you were curious about:

With CUDA
With unconditional masked_fill_ in std_backward() for tensor of size 1000 x 1000 x 10 filled with 1s
33.6 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

With conditional masked_fill_ in std_backward() for tensor of size 1000 x 1000 x 10 filled with 1s
33.7 ms ± 275 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

With unconditional masked_fill_ in std_backward() for tensor of size 1000 x 1000 x 10 filled using randn
31.2 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

With conditional masked_fill_ in std_backward() for tensor of size 1000 x 1000 x 10 filled using randn
30.8 ms ± 209 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The conditional masked_fill_ is the current implementation.
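A rough sketch of how such timings could be reproduced (the exact benchmark script is not in the thread; the tensor shape and fill follow the descriptions above):

import torch

x = torch.ones(1000, 1000, 10, device="cuda", requires_grad=True)
y = x.std()

def backward_step():
    x.grad = None
    y.backward(retain_graph=True)
    torch.cuda.synchronize()  # make the timing reflect the queued GPU work

# In IPython: %timeit backward_step()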

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ailzhang
Contributor

@pytorchbot retest this please

@vishwakftw
Contributor Author

Is this good to go?

@colesbury
Member

I'm a bit wary about this change. It looks like it introduces a CUDA synchronization point in result_zero_mask.any().toCByte() (as ezyang pointed out).

Synchronization points may not slow down an individual call, but they make it harder to hide kernel launch latency and CPU computation in a bigger system.
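Roughly, the concern is a pattern like the following (a sketch of the shape of the code, not the PR's exact implementation): reading the boolean reduction back to the host blocks the CPU until all queued kernels finish.

import torch

if torch.cuda.is_available():
    result = torch.randn(1000, 1000, device="cuda")
    mask = (result == 0.)
    # bool() copies the reduction result to the host, forcing a CUDA sync.
    if bool(mask.any()):
        result.masked_fill_(mask, float("inf"))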

@vishwakftw
Contributor Author

Are you suggesting that the masked_fill_ be done unconditionally?

@ezyang
Contributor

ezyang commented Aug 13, 2018

I hate to say it, but maybe we should just write the fused kernel that does the masked fill unconditionally. Then we don't have to worry about the extra kernel launch overhead.

@YannDubs

YannDubs commented Sep 24, 2018

@vishwakftw @colesbury I'm not sure the derivatives you computed are correct. It seems to me that the formulas you have were derived by adding eps to only one dimension, but what we really want is to add an eps to each dimension. Indeed, we only get NaNs when all the xi's are equal (you don't have 0/0 in the other case).

I.e., what you have:

import torch

N = 7
eps = 1e-9
x0 = torch.zeros(N, dtype=torch.double, requires_grad=True)
epsis = torch.zeros(N, dtype=torch.double)
epsis[0] = eps
x = x0 + epsis
y = torch.std(x)
y.backward()
print(x0.grad)
# tensor([ 0.3780, -0.0630, -0.0630, -0.0630, -0.0630, -0.0630, -0.0630], dtype=torch.float64)

What I think we want to solve:

import torch

N = 7
eps = 1e-9
x0 = torch.zeros(N, dtype=torch.double, requires_grad=True)
x = x0 + eps
y = torch.std(x)
y.backward()
print(x0.grad)
# tensor([nan, nan, nan, nan, nan, nan, nan], dtype=torch.float64)

For the second case, you have to add eps to each xi in the definition of the derivative. The limit should give 0.

Numerical check:

import torch

N = 7
eps = 1e-9
x = torch.zeros(N, dtype=torch.double)
print(float(torch.std(x + eps) / eps))
# 0.0

I hope I understood the issue correctly.

@vishwakftw
Contributor Author

This is stale, closing for now.

@vishwakftw vishwakftw closed this Nov 14, 2018
@vishwakftw vishwakftw deleted the std-gradient-fix branch November 14, 2018 03:20
tony2037 pushed a commit to tony2037/attention-is-all-you-need-pytorch that referenced this pull request May 27, 2020

According to "Attention Is All You Need", the formula applied to each sublayer is supposed to be LayerNorm(x + Sublayer(x)).

As mentioned in [issue#142](jadore801120#142), this implementation results from consideration of the problem in PyTorch (see [pytorch issue#4320](pytorch/pytorch#4320)), which has been fixed in [Fix standard deviation gradient](pytorch/pytorch#9238).