
Commit 78e3259

albanD authored and ezyang committed
Add autograd automatic anomaly detection (#7677)
* add autograd automatic anomaly detection
* python 3 string support
* Fix non python build
* fix typo in doc
* better test and naming fix
* fix no python build and python object handling
* fix missing checks
* clean NO_PYTHON build
* Remove unwanted changes
1 parent 38362fa commit 78e3259

20 files changed: +344, -3 lines

docs/source/autograd.rst

Lines changed: 7 additions & 0 deletions
@@ -98,3 +98,10 @@ and nvprof based (registers both CPU and GPU activity) using
     :members:
 
 .. autofunction:: torch.autograd.profiler.load_nvprof
+
+Anomaly detection
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: detect_anomaly
+
+.. autoclass:: set_detect_anomaly

setup.py

Lines changed: 2 additions & 0 deletions
@@ -737,6 +737,8 @@ def run(self):
     "torch/csrc/autograd/init.cpp",
     "torch/csrc/autograd/aten_variable_hooks.cpp",
     "torch/csrc/autograd/grad_mode.cpp",
+    "torch/csrc/autograd/anomaly_mode.cpp",
+    "torch/csrc/autograd/python_anomaly_mode.cpp",
     "torch/csrc/autograd/engine.cpp",
     "torch/csrc/autograd/function.cpp",
     "torch/csrc/autograd/variable.cpp",

test/test_autograd.py

Lines changed: 36 additions & 1 deletion
@@ -15,7 +15,7 @@
 from torch.autograd.profiler import profile
 from common import TEST_MKL, TestCase, run_tests, skipIfNoLapack, \
     suppress_warnings
-from torch.autograd import Variable, Function
+from torch.autograd import Variable, Function, detect_anomaly
 from torch.autograd.function import InplaceFunction
 from torch.testing import make_non_contiguous, randn_like
 
@@ -2306,6 +2306,41 @@ def test_rnn_backward_to_input_but_not_parameters_cuda(self):
         out.sum().backward()
         self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
 
+    def test_anomaly_detect_nan(self):
+        size = 10
+
+        class MyFunc(Function):
+            @staticmethod
+            def forward(ctx, inp1, inp2, fail_0th):
+                ctx.fail_0th = fail_0th
+                return inp1.sum(0, keepdim=True)
+
+            @staticmethod
+            def backward(ctx, gO):
+                gI = gO.clone().expand(size)
+                gI[0] = 0
+                gI[0] /= 0  # Generate a nan
+                if ctx.fail_0th:
+                    return gI, None, None
+                else:
+                    return None, gI, None
+
+        inp = torch.rand(size, requires_grad=True)
+        out = MyFunc.apply(inp, inp, True)
+        out.backward()  # Should not fail
+
+        inp = torch.rand(size, requires_grad=True)
+        out = MyFunc.apply(inp, inp, True)
+        with self.assertRaisesRegexp(RuntimeError, "Function 'MyFuncBackward' returned nan values in its 0th output."):
+            with detect_anomaly():
+                out.backward()
+
+        inp = torch.rand(size, requires_grad=True)
+        out = MyFunc.apply(inp, inp, False)
+        with self.assertRaisesRegexp(RuntimeError, "Function 'MyFuncBackward' returned nan values in its 1th output."):
+            with detect_anomaly():
+                out.backward()
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
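
Outside the test harness, the same failure mode can be reproduced with a short sketch like the one below. ``BadGrad`` is a hypothetical function written for illustration, not part of this commit; the error message shape ("returned nan values in its 0th output") is the one asserted in the test above.

    import torch
    from torch.autograd import Function, detect_anomaly


    class BadGrad(Function):
        # Hypothetical Function whose backward corrupts the gradient with nan.
        @staticmethod
        def forward(ctx, inp):
            return inp.clone()

        @staticmethod
        def backward(ctx, gO):
            return gO * float('nan')


    inp = torch.rand(10, requires_grad=True)

    BadGrad.apply(inp).sum().backward()   # without detection: nan gradients, no error

    try:
        with detect_anomaly():
            BadGrad.apply(inp).sum().backward()
    except RuntimeError as e:
        print(e)   # e.g. Function 'BadGradBackward' returned nan values in its 0th output.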

tools/cpp_build/libtorch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -201,6 +201,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/autograd/profiler.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/saved_variable.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/grad_mode.cpp
+  ${TORCH_SRC_DIR}/csrc/autograd/anomaly_mode.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/function.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/input_buffer.cpp
   ${TORCH_SRC_DIR}/csrc/autograd/functions/utils.cpp

torch/autograd/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 from .function import Function, NestedIOFunction
 from .gradcheck import gradcheck, gradgradcheck
 from .grad_mode import no_grad, enable_grad, set_grad_enabled
+from .anomaly_mode import detect_anomaly, set_detect_anomaly
 from . import profiler
 
 __all__ = ['Variable', 'Function', 'backward', 'grad_mode']

torch/autograd/anomaly_mode.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+import torch
+
+
+class detect_anomaly(object):
+    r"""Context-manager that enable anomaly detection for the autograd engine.
+
+    This does two things:
+    - Running the forward pass with detection enabled will allow the backward
+      pass to print the traceback of the forward operation that created the failing
+      backward function.
+    - Any backward computation that generate "nan" value will raise an error.
+
+    Example:
+
+        >>> import torch
+        >>> from torch import autograd
+        >>> class MyFunc(autograd.Function):
+        ...     @staticmethod
+        ...     def forward(ctx, inp):
+        ...         return inp.clone()
+        ...     @staticmethod
+        ...     def backward(ctx, gO):
+        ...         # Error during the backward pass
+        ...         raise RuntimeError("Some error in backward")
+        ...         return gO.clone()
+        >>> def run_fn(a):
+        ...     out = MyFunc.apply(a)
+        ...     return out.sum()
+        >>> inp = torch.rand(10, 10, requires_grad=True)
+        >>> out = run_fn(inp)
+        >>> out.backward()
+            Traceback (most recent call last):
+              File "<stdin>", line 1, in <module>
+              File "/your/pytorch/install/torch/tensor.py", line 93, in backward
+                torch.autograd.backward(self, gradient, retain_graph, create_graph)
+              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
+                allow_unreachable=True)  # allow_unreachable flag
+              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
+                return self._forward_cls.backward(self, *args)
+              File "<stdin>", line 8, in backward
+            RuntimeError: Some error in backward
+        >>> with autograd.detect_anomaly():
+        ...     inp = torch.rand(10, 10, requires_grad=True)
+        ...     out = run_fn(inp)
+        ...     out.backward()
+            Traceback of forward call that caused the error:
+              File "tmp.py", line 53, in <module>
+                out = run_fn(inp)
+              File "tmp.py", line 44, in run_fn
+                out = MyFunc.apply(a)
+            Traceback (most recent call last):
+              File "<stdin>", line 4, in <module>
+              File "/your/pytorch/install/torch/tensor.py", line 93, in backward
+                torch.autograd.backward(self, gradient, retain_graph, create_graph)
+              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
+                allow_unreachable=True)  # allow_unreachable flag
+              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
+                return self._forward_cls.backward(self, *args)
+              File "<stdin>", line 8, in backward
+            RuntimeError: Some error in backward
+
+    """
+
+    def __init__(self):
+        self.prev = torch.is_anomaly_enabled()
+
+    def __enter__(self):
+        torch.set_anomaly_enabled(True)
+
+    def __exit__(self, *args):
+        torch.set_anomaly_enabled(self.prev)
+        return False
+
+
+class set_detect_anomaly(object):
+    r"""Context-manager that sets the anomaly detection for the autograd engine on or off.
+
+    ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
+    based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    See ``detect_anomaly`` above for details of the anomaly detection behaviour.
+
+    Arguments:
+        mode (bool): Flag whether to enable anomaly detection (``True``),
+                     or disable (``False``).
+
+    """
+
+    def __init__(self, mode):
+        self.prev = torch.is_anomaly_enabled()
+        torch.set_anomaly_enabled(mode)
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *args):
+        torch.set_anomaly_enabled(self.prev)
+        return False
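
A brief usage sketch (not part of the diff) of the two entry points defined above. As the docstring notes, ``set_detect_anomaly`` can be used either as a plain function call or as a context manager that restores the previous setting on exit:

    import torch
    from torch.autograd import set_detect_anomaly

    # As a function: detection stays enabled for everything that follows.
    set_detect_anomaly(True)
    x = torch.rand(4, requires_grad=True)
    x.sum().backward()           # runs with anomaly checks on
    set_detect_anomaly(False)

    # As a context manager: the previous state is restored on exit.
    with set_detect_anomaly(True):
        y = torch.rand(4, requires_grad=True)
        y.sum().backward()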

torch/csrc/autograd/anomaly_mode.cpp

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+#include "torch/csrc/autograd/anomaly_mode.h"
+
+namespace torch { namespace autograd {
+
+bool AnomalyMode::_enabled = 0;
+
+}}

torch/csrc/autograd/anomaly_mode.h

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#pragma once
+
+namespace torch { namespace autograd {
+
+struct AnomalyMode {
+  static bool is_enabled() {
+    return _enabled;
+  }
+  static void set_enabled(bool enabled) {
+    _enabled = enabled;
+  }
+
+private:
+  static bool _enabled;
+};
+
+
+struct AnomalyMetadata {
+  virtual void store_stack() = 0;
+  virtual void print_stack() = 0;
+};
+
+}}
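
This C++ flag is what the Python-level ``torch.is_anomaly_enabled()`` / ``torch.set_anomaly_enabled()`` calls used in ``anomaly_mode.py`` above read and write. A minimal sketch of toggling it by hand, which is essentially what the two context managers do for you:

    import torch

    prev = torch.is_anomaly_enabled()    # reads AnomalyMode::is_enabled()
    torch.set_anomaly_enabled(True)      # AnomalyMode::set_enabled(true)
    try:
        x = torch.rand(3, requires_grad=True)
        x.sum().backward()               # backward runs with anomaly checks on
    finally:
        torch.set_anomaly_enabled(prev)  # restore the previous state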

torch/csrc/autograd/engine.cpp

Lines changed: 18 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/functions/basic_ops.h"
 #include "torch/csrc/autograd/grad_mode.h"
+#include "torch/csrc/autograd/anomaly_mode.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/utils/auto_gpu.h"
 
@@ -269,6 +270,9 @@ auto Engine::thread_main(GraphTask *graph_task) -> void {
 auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void {
   std::lock_guard<std::mutex> lock(task.base->mutex);
   if (!task.base->has_error.load()) {
+    if (AnomalyMode::is_enabled()) {
+      task.fn->metadata()->print_stack();
+    }
     task.base->exception = std::current_exception();
     task.base->has_error = true;
   }
@@ -373,6 +377,20 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
 
   int num_outputs = outputs.size();
   if (num_outputs == 0) return; // Don't even acquire the mutex
+
+  if (AnomalyMode::is_enabled()) {
+    AutoGradMode grad_mode(false);
+    for (int i = 0; i < num_outputs; ++i) {
+      auto& output = outputs[i];
+      AutoGPU guard(output);
+      if (output.defined() && output.ne(output).any().toCByte()) {
+        std::stringstream ss;
+        ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
+        throw std::runtime_error(ss.str());
+      }
+    }
+  }
+
   std::lock_guard<std::mutex> lock(task.base->mutex);
   for (int i = 0; i < num_outputs; ++i) {
     auto& output = outputs[i];
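
The nan check above relies on the fact that a value compares unequal to itself only when it is nan, which is what ``output.ne(output).any()`` tests (with grad mode turned off so the comparison itself is not recorded). A rough Python illustration of that predicate, not the engine code itself:

    import torch

    def contains_nan(t):
        # Mirrors output.ne(output).any(): only nan values compare unequal to themselves.
        return bool(t.ne(t).any())

    print(contains_nan(torch.tensor([1.0, 2.0])))           # False
    print(contains_nan(torch.tensor([1.0, float('nan')])))  # True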

torch/csrc/autograd/engine.h

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/input_buffer.h"
+#include "torch/csrc/autograd/anomaly_mode.h"
 
 #include <deque>
 #include <exception>
@@ -41,6 +42,9 @@ struct Engine {
       bool keep_graph,
       bool create_graph,
       const edge_list& outputs = {});
+  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
+    return nullptr;
+  }
 
   void queue_callback(std::function<void()> callback);
 
