🐛 Bug
(Pardon my posts from a few days back about memory leaks; I'm in the middle of hunting memory leaks in my applications.)
I extended PyTorch with torch.autograd.Function to write the gradient of some functionals (i.e. functions that receive other functions as their inputs), but I found a memory leak when there is a circular reference involving the functional's output.
To Reproduce
```python
import torch

def functional(fcn, y0):
    params = fcn.__self__.parameters()  # assuming fcn is a method of `torch.nn.Module`
    return Functional.apply(fcn, y0, *params)  # NO_MEMLEAK_IF: params is an empty list

class Functional(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fcn, y0, *params):
        y = fcn(y0)
        ctx.fcn = fcn  # NO_MEMLEAK_IF: removing this line, but fcn is needed in backward
        return y

class DummyModule(torch.nn.Module):
    def __init__(self, a):
        super().__init__()
        self.a = torch.nn.Parameter(a)
        x0 = torch.ones_like(a)
        xsol = functional(self.forward, x0)  # NO_MEMLEAK_IF: changing this line to self.forward(x0)
        self.xsol = xsol  # NO_MEMLEAK_IF: removing this line or using xsol.detach()

    def forward(self, x):
        return self.a * x

def test_functional():
    a = torch.ones((200000000,), dtype=torch.double, device=torch.device("cuda"))
    model = DummyModule(a)

for i in range(5):
    test_functional()
    torch.cuda.empty_cache()
    print("memory allocated:", float(torch.cuda.memory_allocated() / (1024 ** 2)), "MiB")
```

When I run the code above, I get:
```
$ python memtest.py
memory allocated: 3052.0 MiB
memory allocated: 6104.0 MiB
memory allocated: 9156.0 MiB
Traceback (most recent call last):
  File "memtest.py", line 34, in <module>
    test_functional()
  File "memtest.py", line 31, in test_functional
    model = DummyModule(a)
  File "memtest.py", line 20, in __init__
    x0 = torch.ones_like(a)
RuntimeError: CUDA out of memory. Tried to allocate 1.49 GiB (GPU 0; 11.91 GiB total capacity; 10.43 GiB already allocated; 918.25 MiB free; 10.43 GiB reserved in total by PyTorch)
```
Expected behavior
No memory leak should happen.
In the example code above, I put NO_MEMLEAK_IF comments to indicate changes that remove the memory leak; applying any single one of them is enough.
However, the memory leak should not appear even with the code exactly as written above.
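Since the leak seems to come from a reference cycle (the module holds xsol, xsol's graph holds ctx, and ctx holds the module's bound method), one side check is whether forcing Python's cycle collector between iterations changes the numbers. This is only a diagnostic sketch on top of the repro above, not a fix:

```python
import gc

import torch

# Assumes functional, Functional, DummyModule, and test_functional
# from the repro above are already defined in this script.
for i in range(5):
    test_functional()
    gc.collect()              # force Python's cycle collector to run
    torch.cuda.empty_cache()
    print("memory allocated:", torch.cuda.memory_allocated() / (1024 ** 2), "MiB")
```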
Environment
- PyTorch Version (e.g., 1.0): 1.8.0a0+4803eaf
- OS (e.g., Linux): Linux
- How you installed PyTorch (`conda`, `pip`, source): source
- Build command you used (if compiling from source): `python setup.py install`
- Python version: 3.8.5
- CUDA/cuDNN version: 11.1
- GPU models and configuration: TITAN X
- Any other relevant information: -
Additional context
I cannot do `ctx.save_for_backward(fcn)` because `fcn` is a method, not a tensor.
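A possible workaround I am considering (a rough sketch using weakref.WeakMethod; I have not verified that it removes the leak) is to avoid storing a strong reference to the bound method on ctx, so that ctx does not keep the module alive:

```python
import weakref

import torch

class Functional(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fcn, y0, *params):
        y = fcn(y0)
        # Hold only a weak reference to the bound method, so that ctx does not
        # participate in the module -> xsol -> graph -> ctx -> module cycle.
        ctx.fcn_ref = weakref.WeakMethod(fcn)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        fcn = ctx.fcn_ref()  # bound method, or None if the owning module was freed
        # ... compute gradients w.r.t. y0 and *params using fcn here ...
        raise NotImplementedError  # backward omitted in this repro, as above
```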
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @jlin27 @mruberry