
Conversation

@bertmaher
Contributor

bertmaher commented Sep 24, 2020

Stack from ghstack:

Summary: While tracking down a recent memory corruption bug, we found that
cuda-memcheck wasn't finding the bad accesses, and @ngimel pointed out that
this is because we use a caching allocator: many "out of bounds" accesses
still land inside a valid slab.

This PR adds a runtime knob (`PYTORCH_CUDA_DEBUG_MEMORY`) that, when set,
bypasses the caching allocator's caching logic so that allocations go straight
to cudaMalloc. This way, cuda-memcheck will actually work.
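
For illustration, here is a minimal sketch of the bypass pattern. This is not the actual CUDACachingAllocator change; the helper names are hypothetical, and the knob name follows the description above (it gets renamed later in this thread).

```
// Sketch only: hypothetical helpers showing how an env-var knob can route
// allocations around a caching layer. Not the real CUDACachingAllocator code.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Read the knob once; any value enables the bypass.
static bool force_uncached_allocations() {
  static const bool force =
      std::getenv("PYTORCH_CUDA_DEBUG_MEMORY") != nullptr;
  return force;
}

// With the knob set, allocations skip the caching pools and go straight to
// the driver, so cuda-memcheck sees the true bounds of each buffer.
void* raw_alloc(std::size_t nbytes) {
  void* ptr = nullptr;
  if (force_uncached_allocations()) {
    cudaError_t err = cudaMalloc(&ptr, nbytes);
    if (err != cudaSuccess) {
      std::fprintf(stderr, "cudaMalloc failed: %s\n", cudaGetErrorString(err));
      return nullptr;
    }
    return ptr;
  }
  // ... otherwise serve the request from the caching allocator's pools ...
  return ptr;
}

void raw_delete(void* ptr) {
  if (force_uncached_allocations()) {
    cudaFree(ptr);  // return the block to the driver immediately
    return;
  }
  // ... otherwise hand the block back to the cache for reuse ...
}
```

Caching the getenv result in a static keeps the hot path to a single branch, rather than re-reading the environment on every allocation.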

Test Plan: Insert some memory errors and run a test under cuda-memcheck;
observe that cuda-memcheck flags an error where expected.

Specifically I removed the output-masking logic here:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/cuda_codegen.cpp#L819-L826

And ran:

PYTORCH_CUDA_DEBUG_MEMORY=1 cuda-memcheck pytest -k test_superslomo test_jit_fuser_te.py

Differential Revision: D23964734
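
As a self-contained illustration of the class of bug involved (not the fuser change linked above), the kernel below makes a deliberate out-of-bounds write of the kind cuda-memcheck can flag once the buffer comes straight from cudaMalloc:

```
// Deliberate out-of-bounds write: 1024 threads store into a 1000-float
// buffer. With an uncached cudaMalloc allocation, cuda-memcheck flags the
// writes from the trailing threads as invalid global writes.
#include <cuda_runtime.h>

__global__ void oob_write(float* buf) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  buf[i] = 1.0f;  // bug: no bounds check on i
}

int main() {
  const int n = 1000;  // deliberately not a multiple of the block size
  float* buf = nullptr;
  cudaMalloc(&buf, n * sizeof(float));
  oob_write<<<(n + 255) / 256, 256>>>(buf);  // 4 blocks x 256 = 1024 threads
  cudaDeviceSynchronize();
  cudaFree(buf);
  return 0;
}
```

Under the caching allocator, an overrun like this would typically land inside a valid cached slab and go unreported; with the knob set, the tensor's storage ends exactly where its own cudaMalloc allocation does.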

bertmaher added a commit that referenced this pull request Sep 24, 2020
ghstack-source-id: 5a28c87
Pull Request resolved: #45294
@bertmaher
Contributor Author

So this is a pretty minimal debug path for the CUDA allocator; I'd love some advice on whether this is a reasonable approach and whether there's anything else I should be doing here. Suggestions for perf testing would also be great: I think we have an operator-overhead benchmark that I'll try out; is there anything else? This seems like a perf-sensitive path, so with any luck the change doesn't slow it down.
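
For a rough sense of the cost (a standalone sketch, not the operator-overhead benchmark mentioned above), timing bare cudaMalloc/cudaFree pairs gives a feel for the per-allocation driver overhead that caching normally hides:

```
// Standalone timing sketch: average latency of a cudaMalloc/cudaFree pair.
// This approximates the extra cost paid per allocation when caching is
// bypassed; it is not a PyTorch benchmark.
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const int iters = 1000;
  const std::size_t nbytes = 4096;

  // Warm up the CUDA context so initialization doesn't skew the timing.
  void* warm = nullptr;
  cudaMalloc(&warm, nbytes);
  cudaFree(warm);

  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    void* p = nullptr;
    cudaMalloc(&p, nbytes);
    cudaFree(p);
  }
  auto end = std::chrono::steady_clock::now();

  double avg_us =
      std::chrono::duration<double, std::micro>(end - start).count() / iters;
  std::printf("avg cudaMalloc+cudaFree: %.2f us\n", avg_us);
  return 0;
}
```

Numbers will vary by driver and GPU; since the knob is strictly a debugging aid, the more important check is that the added branch doesn't slow the default cached path.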

@codecov

codecov bot commented Sep 24, 2020

Codecov Report

Merging #45294 into gh/bertmaher/23/base will increase coverage by 0.00%.
The diff coverage is n/a.


@@                  Coverage Diff                   @@
##           gh/bertmaher/23/base   #45294    +/-   ##
======================================================
  Coverage                 68.05%   68.05%            
======================================================
  Files                       396      393     -3     
  Lines                     51232    50914   -318     
======================================================
- Hits                      34864    34651   -213     
+ Misses                    16368    16263   -105     
| Impacted Files | Coverage Δ |
| --- | --- |
| torch/distributed/rpc/options.py | 33.33% <0.00%> (-50.01%) ⬇️ |
| torch/distributed/rpc/backend_registry.py | 32.35% <0.00%> (-16.04%) ⬇️ |
| torch/utils/_benchmark/utils/common.py | 77.68% <0.00%> (-13.23%) ⬇️ |
| torch/testing/_internal/common_cuda.py | 54.21% <0.00%> (-9.83%) ⬇️ |
| torch/backends/cuda/__init__.py | 62.50% <0.00%> (-8.34%) ⬇️ |
| torch/distributed/optim/optimizer.py | 29.78% <0.00%> (-7.57%) ⬇️ |
| torch/nn/quantized/modules/conv.py | 85.25% <0.00%> (-4.32%) ⬇️ |
| torch/optim/adagrad.py | 79.03% <0.00%> (-4.31%) ⬇️ |
| torch/testing/_internal/dist_utils.py | 33.33% <0.00%> (-2.11%) ⬇️ |
| torch/onnx/symbolic_opset12.py | 25.00% <0.00%> (-1.79%) ⬇️ |
| ... and 37 more | |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Collaborator

@dzhulgakov left a comment

Looks good. Maybe call it `PYTORCH_NO_CUDA_MEMORY_CACHING` or something like that? 'debugging' implies some fancy tool/report.

@ezyang
Contributor

ezyang commented Sep 25, 2020

Yeah, this looks fine, but agreed with Dmytro on env var renaming. You should also think about where to document this option.

@ngimel
Collaborator

ngimel commented Sep 25, 2020

This section https://pytorch.org/docs/master/cuda.html#memory-management looks like a reasonable place for documentation.

bertmaher added a commit that referenced this pull request Sep 28, 2020
Summary: While tracking down a recent memory corruption bug we found that
cuda-memcheck wasn't finding the bad accesses, and @ngimel pointed out that
it's because we use a caching allocator so a lot of "out of bounds" accesses
land in a valid slab.

This PR adds a runtime knob (`PYTORCH_NO_CUDA_MEMORY_CACHING`) that, when set,
bypasses the caching allocator's caching logic so that allocations go straight
to cudaMalloc.  This way, cuda-memcheck will actually work.

Test Plan: Insert some memory errors and run a test under cuda-memcheck;
observe that cuda-memcheck flags an error where expected.

Specifically I removed the output-masking logic here:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/cuda_codegen.cpp#L819-L826

And ran:
```
PYTORCH_NO_CUDA_MEMORY_CACHING=1 cuda-memcheck pytest -k test_superslomo test_jit_fuser_te.py
```

ghstack-source-id: 6b44289
Pull Request resolved: #45294
@facebook-github-bot
Contributor

@bertmaher merged this pull request in 03342af.

@facebook-github-bot deleted the gh/bertmaher/23/head branch October 2, 2020 14:17