Commit 95a69a7

Natalia Gimelshein authored and facebook-github-bot committed
adds list_gpu_processes function (#44616)
Summary: per title, to make it easier to track the creation of stray contexts:

```
python -c "import torch; a=torch.randn(1, device='cuda'); print(torch.cuda.memory.list_gpu_processes(0)); print(torch.cuda.memory.list_gpu_processes(1))"
GPU:0
process      79749 uses      601.000 MB GPU memory
GPU:1
no processes are running
```

Pull Request resolved: #44616
Reviewed By: mruberry
Differential Revision: D23675739
Pulled By: ngimel
fbshipit-source-id: ffa14cad9d7144e883de13b1c2c6817bd432f53a
1 parent 105132b commit 95a69a7

File tree

1 file changed (+35, -0 lines)

torch/cuda/memory.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -467,3 +467,38 @@ def _format_count(cnt, pref_cnt):
     for k, v in stats.items():
         fmt_dict[k.replace(".", "-")] = v
     return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
+
+
+def list_gpu_processes(device: Union[Device, int] = None) -> str:
+    r"""Returns a human-readable printout of the running processes
+    and their GPU memory use for a given device.
+
+    This can be useful to display periodically during training, or when
+    handling out-of-memory exceptions.
+
+    Arguments:
+        device (torch.device or int, optional): selected device. Returns
+            printout for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+    """
+    try:
+        import pynvml  # type: ignore
+    except ModuleNotFoundError:
+        return("pynvml module not found, please install pynvml")
+    from pynvml import NVMLError_DriverNotLoaded
+    try:
+        pynvml.nvmlInit()
+    except NVMLError_DriverNotLoaded:
+        return ("cuda driver can't be loaded, is cuda enabled?")
+    device = _get_device_index(device, optional=True)
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+    lines = []
+    lines.append(f"GPU:{device}")
+    if len(procs) == 0:
+        lines.append("no processes are running")
+    for p in procs:
+        mem = p.usedGpuMemory / (1024 * 1024)
+        lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory")
+    return "\n".join(lines)
```
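The per-process line in the patch converts NVML's byte count to megabytes and pads the fields with fixed f-string widths. A minimal sketch of that formatting step, factored out so it runs without a GPU or pynvml (the helper name `format_gpu_process_line` is hypothetical, not part of torch):

```python
def format_gpu_process_line(pid: int, used_bytes: int) -> str:
    """Format one process line the way the patch does.

    Mirrors the f-string in list_gpu_processes: pid right-aligned to 10
    columns, memory in MB right-aligned to 12 columns with 3 decimals.
    """
    # 1 MB = 1024 * 1024 bytes, matching the division in the patch.
    mem = used_bytes / (1024 * 1024)
    return f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory"


# Reproduces the sample line from the commit summary (601 MB for pid 79749).
print(format_gpu_process_line(79749, 601 * 1024 * 1024))
```

The fixed widths keep the columns aligned when several processes are listed, which is why the function pads rather than printing the raw numbers.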
