
Conversation

Contributor

@izdeby izdeby commented Aug 24, 2020

Stack from ghstack:

Differential Revision: D23331893

Motivation
GitHub issue
Current PyTorch optimizer implementations are inefficient when we work with a lot of small feature tensors: launching a separate kernel for every tensor slows the whole step down, so we need to reduce the number of kernel launches.
NVIDIA's Apex is a good reference point for this approach.
To track progress, we will take PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark the model against the original version, Apex's version with the original Adam optimizer, and Apex's FusedAdam optimizer.

Current API restrictions

  • The list can't be empty (this will be fixed in upcoming PRs).
  • All tensors in the list must have the same dtype, device and size.

Broadcasting
At this point we don't support broadcasting.

What are the 'fast' and 'slow' routes
In some cases we can't process an op with a fast fused CUDA kernel over the whole list. We can still fall back to a regular for-loop in which the op is applied to each tensor individually through the usual dispatch mechanism. A few checks decide whether an op takes the 'fast' or the 'slow' path; a small sketch of these checks follows after the list below.
To take the fast route:

  • All tensors must have a strided layout.
  • All tensors must be dense and must not have overlapping memory.
  • All tensors (and the resulting tensor) must have the same dtype.
  • All tensors must be on the same device.
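For illustration only, here is a rough Python sketch of these conditions. The helper name can_use_fast_route is made up for this example, and it is simplified; the real checks live in the foreach kernels' C++ dispatch code.

import torch

def can_use_fast_route(tensors):
    # Simplified sketch of the fast-path conditions listed above.
    ref = tensors[0]
    return all(
        t.layout == torch.strided      # strided layout
        and not t.is_sparse            # dense
        and t.dtype == ref.dtype       # same dtype
        and t.device == ref.device     # same device
        for t in tensors
    )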

In this PR

  • We are introducing a new namespace under torch.optim, torch.optim.multi_tensor, which will hold optimizers rewritten with the foreach* APIs.
  • The Adam optimizer is rewritten with the foreach* APIs (a rough sketch of the idea follows below).
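As a rough illustration (not the exact code from this PR), the difference between a per-parameter loop and a batched foreach call looks like this; exact _foreach_* signatures have changed a little across releases, so treat this as a sketch:

import torch

params = [torch.randn(5) for _ in range(10)]
grads = [torch.randn(5) for _ in range(10)]
lr = 0.1

# Loop version: one kernel launch per tensor.
for p, g in zip(params, grads):
    p.add_(g, alpha=-lr)

# foreach version: the whole list is updated with far fewer kernel launches.
# (In real code you would use one of the two, not both.)
torch._foreach_add_(params, grads, alpha=-lr)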


dr-ci bot commented Aug 24, 2020

💊 CI failures summary and remediations

As of commit 1402b87 (more details on the Dr. CI page):



🕵️ 9 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_ge_config_legacy_test (1/9)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 23 22:55:19 ERROR [0.002s]: test_torchbind_tracing_nested (jit.test_torchbind.TestTorchbind)
Sep 23 22:55:19 Traceback (most recent call last): 
Sep 23 22:55:19   File "/var/lib/jenkins/workspace/test/jit/test_torchbind.py", line 31, in setUp 
Sep 23 22:55:19     torch.ops.load_library(str(p)) 
Sep 23 22:55:19   File "/opt/conda/lib/python3.6/site-packages/torch/_ops.py", line 105, in load_library 
Sep 23 22:55:19     ctypes.CDLL(path) 
Sep 23 22:55:19   File "/opt/conda/lib/python3.6/ctypes/__init__.py", line 348, in __init__ 
Sep 23 22:55:19     self._handle = _dlopen(self._name, mode) 
Sep 23 22:55:19 OSError: /var/lib/jenkins/workspace/build/lib/libtorchbind_test.so: undefined symbol: _ZTIN7testing4TestE 
Sep 23 22:55:19  
Sep 23 22:55:19 ====================================================================== 
Sep 23 22:55:19 ERROR [0.002s]: test_torchbind_tracing_nested (jit.test_torchbind.TestTorchbind) 
Sep 23 22:55:19 ---------------------------------------------------------------------- 
Sep 23 22:55:19 Traceback (most recent call last): 
Sep 23 22:55:19   File "/var/lib/jenkins/workspace/test/jit/test_torchbind.py", line 31, in setUp 
Sep 23 22:55:19     torch.ops.load_library(str(p)) 
Sep 23 22:55:19   File "/opt/conda/lib/python3.6/site-packages/torch/_ops.py", line 105, in load_library 
Sep 23 22:55:19     ctypes.CDLL(path) 
Sep 23 22:55:19   File "/opt/conda/lib/python3.6/ctypes/__init__.py", line 348, in __init__ 
Sep 23 22:55:19     self._handle = _dlopen(self._name, mode) 
Sep 23 22:55:19 OSError: /var/lib/jenkins/workspace/build/lib/libtorchbind_test.so: undefined symbol: _ZTIN7testing4TestE 
Sep 23 22:55:19  

See CircleCI build pytorch_linux_xenial_py3_clang5_asan_test2 (2/9)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 23 22:50:45 SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:11:3 in
Sep 23 22:50:45     #7 0x55a8090c570b in PyEval_EvalCode /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:731 
Sep 23 22:50:45     #8 0x55a809145573 in run_mod /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:1025 
Sep 23 22:50:45     #9 0x55a80914560c in PyRun_StringFlags /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:949 
Sep 23 22:50:45     #10 0x55a80914566e in PyRun_SimpleStringFlags /tmp/build/80754af9/python_1599604603603/work/Python/pythonrun.c:445 
Sep 23 22:50:45     #11 0x55a809149472 in run_command /tmp/build/80754af9/python_1599604603603/work/Modules/main.c:301 
Sep 23 22:50:45     #12 0x55a809149472 in Py_Main /tmp/build/80754af9/python_1599604603603/work/Modules/main.c:749 
Sep 23 22:50:45     #13 0x55a80901343d in main /tmp/build/80754af9/python_1599604603603/work/Programs/python.c:69 
Sep 23 22:50:45     #14 0x7f7814aa583f in __libc_start_main /build/glibc-e6zv40/glibc-2.23/csu/../csu/libc-start.c:291 
Sep 23 22:50:45     #15 0x55a8090f2d0a in _start /home/rdonnelly/mc/conda-bld/compilers_linux-64_1534865402226/work/.build/src/glibc-2.12.2/csu/../sysdeps/x86_64/elf/start.S:103 
Sep 23 22:50:45  
Sep 23 22:50:45 SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:11:3 in  
Sep 23 22:50:45 + retcode=1 
Sep 23 22:50:45 + set -e 
Sep 23 22:50:45 + return 1 
Sep 23 22:50:45 + [[ pytorch-linux-xenial-py3-clang5-asan-test2 == *-NO_AVX-* ]] 
Sep 23 22:50:45 + [[ pytorch-linux-xenial-py3-clang5-asan-test2 == *-NO_AVX2-* ]] 
Sep 23 22:50:45 + '[' -n https://github.com/pytorch/pytorch/pull/43507 ']' 
Sep 23 22:50:45 + [[ pytorch-linux-xenial-py3-clang5-asan-test2 != *coverage* ]] 
Sep 23 22:50:45 ++ mktemp 
Sep 23 22:50:45 + DETERMINE_FROM=/tmp/tmp.L1uKF74yUQ 
Sep 23 22:50:45 + file_diff_from_base /tmp/tmp.L1uKF74yUQ 

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_ge_config_profiling_test (3/9)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 23 22:54:17 ERROR [0.002s]: test_torchbind_tracing_nested (jit.test_torchbind.TestTorchbind)
Sep 23 22:54:17 Traceback (most recent call last): 
Sep 23 22:54:17   File "/var/lib/jenkins/workspace/test/jit/test_torchbind.py", line 31, in setUp 
Sep 23 22:54:17     torch.ops.load_library(str(p)) 
Sep 23 22:54:17   File "/opt/conda/lib/python3.6/site-packages/torch/_ops.py", line 105, in load_library 
Sep 23 22:54:17     ctypes.CDLL(path) 
Sep 23 22:54:17   File "/opt/conda/lib/python3.6/ctypes/__init__.py", line 348, in __init__ 
Sep 23 22:54:17     self._handle = _dlopen(self._name, mode) 
Sep 23 22:54:17 OSError: /var/lib/jenkins/workspace/build/lib/libtorchbind_test.so: undefined symbol: _ZTIN7testing4TestE 
Sep 23 22:54:17  
Sep 23 22:54:17 ====================================================================== 
Sep 23 22:54:17 ERROR [0.002s]: test_torchbind_tracing_nested (jit.test_torchbind.TestTorchbind) 
Sep 23 22:54:17 ---------------------------------------------------------------------- 
Sep 23 22:54:17 Traceback (most recent call last): 
Sep 23 22:54:17   File "/var/lib/jenkins/workspace/test/jit/test_torchbind.py", line 31, in setUp 
Sep 23 22:54:17     torch.ops.load_library(str(p)) 
Sep 23 22:54:17   File "/opt/conda/lib/python3.6/site-packages/torch/_ops.py", line 105, in load_library 
Sep 23 22:54:17     ctypes.CDLL(path) 
Sep 23 22:54:17   File "/opt/conda/lib/python3.6/ctypes/__init__.py", line 348, in __init__ 
Sep 23 22:54:17     self._handle = _dlopen(self._name, mode) 
Sep 23 22:54:17 OSError: /var/lib/jenkins/workspace/build/lib/libtorchbind_test.so: undefined symbol: _ZTIN7testing4TestE 
Sep 23 22:54:17  

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (4/9)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

RuntimeError: test_optim failed!
 
Generating XML reports... 
Generated XML report: test-reports\python-unittest\TEST-TestLRScheduler-20200923233632.xml 
Generated XML report: test-reports\python-unittest\TEST-TestOptim-20200923233632.xml 
Generated XML report: test-reports\python-unittest\TEST-TestSWAUtils-20200923233632.xml 
Traceback (most recent call last): 
  File "run_test.py", line 742, in <module> 
    main() 
  File "run_test.py", line 725, in main 
    raise RuntimeError(err_message) 
RuntimeError: test_optim failed! 
 
(base) circleci@PACKER-5F5FCBA1 C:\Users\circleci\project\test>if ERRORLEVEL 1 exit /b 1  
+ cleanup
+ retcode=1
+ set +x

See CircleCI build pytorch_macos_10_13_py3_test (5/9)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Sep 23 23:30:53 AssertionError: Torch not compiled with CUDA enabled
Sep 23 23:30:53 TypeError: expected Tensor as element 0 in argument 1, but got float 
Sep 23 23:30:53  
Sep 23 23:30:53 ====================================================================== 
Sep 23 23:30:53 FAIL [0.004s]: test_adam_step (__main__.TestOptim) 
Sep 23 23:30:53 ---------------------------------------------------------------------- 
Sep 23 23:30:53 Traceback (most recent call last): 
Sep 23 23:30:53   File "test_optim.py", line 302, in test_adam_step 
Sep 23 23:30:53     weight_base = torch.randn(10, 5, requires_grad=True, device='cuda') 
Sep 23 23:30:53   File "/Users/distiller/workspace/miniconda3/lib/python3.7/site-packages/torch/cuda/__init__.py", line 165, in _lazy_init 
Sep 23 23:30:53     raise AssertionError("Torch not compiled with CUDA enabled") 
Sep 23 23:30:53 AssertionError: Torch not compiled with CUDA enabled 
Sep 23 23:30:53  
Sep 23 23:30:53 ---------------------------------------------------------------------- 
Sep 23 23:30:53 Ran 103 tests in 29.526s 
Sep 23 23:30:53  
Sep 23 23:30:53 FAILED (failures=1, errors=1) 
Sep 23 23:30:53  
Sep 23 23:30:53 Generating XML reports... 
Sep 23 23:30:53 Generated XML report: test-reports/dist-gloo/TEST-TestLRScheduler-20200923233023.xml 
Sep 23 23:30:53 Generated XML report: test-reports/dist-gloo/TEST-TestOptim-20200923233023.xml 
Sep 23 23:30:53 Generated XML report: test-reports/dist-gloo/TEST-TestSWAUtils-20200923233023.xml 

See CircleCI build pytorch_linux_bionic_py3_6_clang9_test (6/9)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 23 23:23:25 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n ^\n" }
Sep 23 23:23:25     raise RuntimeError(err_message) 
Sep 23 23:23:25 RuntimeError: test_optim failed! 
Sep 23 23:23:25  
Sep 23 23:23:25 real	24m57.321s 
Sep 23 23:23:25 user	30m56.667s 
Sep 23 23:23:25 sys	5m27.241s 
Sep 23 23:23:25 + cleanup 
Sep 23 23:23:25 + retcode=1 
Sep 23 23:23:25 + set +x 
Sep 23 23:23:25 =================== sccache compilation log =================== 
Sep 23 23:23:25 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Sep 23 23:23:25  
Sep 23 23:23:25 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 23 23:23:25 Compile requests                327 
Sep 23 23:23:25 Compile requests executed        35 
Sep 23 23:23:25 Cache hits                       27 
Sep 23 23:23:25 Cache misses                      7 
Sep 23 23:23:25 Cache timeouts                    0 
Sep 23 23:23:25 Cache read errors                 0 
Sep 23 23:23:25 Forced recaches                   0 
Sep 23 23:23:25 Cache write errors                0 

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (7/9)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 23 23:28:16 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n ^\n" }
Sep 23 23:28:16 Traceback (most recent call last): 
Sep 23 23:28:16   File "test/run_test.py", line 742, in <module> 
Sep 23 23:28:16     main() 
Sep 23 23:28:16   File "test/run_test.py", line 725, in main 
Sep 23 23:28:16     raise RuntimeError(err_message) 
Sep 23 23:28:16 RuntimeError: test_jit failed! 
Sep 23 23:28:16 + cleanup 
Sep 23 23:28:16 + retcode=1 
Sep 23 23:28:16 + set +x 
Sep 23 23:28:16 =================== sccache compilation log =================== 
Sep 23 23:28:16 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Sep 23 23:28:16  
Sep 23 23:28:16 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 23 23:28:16 Compile requests                327 
Sep 23 23:28:16 Compile requests executed        35 
Sep 23 23:28:16 Cache hits                       27 
Sep 23 23:28:16 Cache misses                      7 
Sep 23 23:28:16 Cache timeouts                    0 
Sep 23 23:28:16 Cache read errors                 0 
Sep 23 23:28:16 Forced recaches                   0 
Sep 23 23:28:16 Cache write errors                0 

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test (8/9)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 23 23:23:57 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n ^\n" }
Sep 23 23:23:57 Traceback (most recent call last): 
Sep 23 23:23:57   File "test/run_test.py", line 742, in <module> 
Sep 23 23:23:57     main() 
Sep 23 23:23:57   File "test/run_test.py", line 725, in main 
Sep 23 23:23:57     raise RuntimeError(err_message) 
Sep 23 23:23:57 RuntimeError: test_jit failed! 
Sep 23 23:23:57 =================== sccache compilation log =================== 
Sep 23 23:23:57 + cleanup 
Sep 23 23:23:57 + retcode=1 
Sep 23 23:23:57 + set +x 
Sep 23 23:23:57 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Sep 23 23:23:57  
Sep 23 23:23:57 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 23 23:23:57 Compile requests                327 
Sep 23 23:23:57 Compile requests executed        35 
Sep 23 23:23:57 Cache hits                       27 
Sep 23 23:23:57 Cache misses                      7 
Sep 23 23:23:57 Cache timeouts                    0 
Sep 23 23:23:57 Cache read errors                 0 
Sep 23 23:23:57 Forced recaches                   0 
Sep 23 23:23:57 Cache write errors                0 

See CircleCI build pytorch_linux_bionic_py3_8_gcc9_coverage_test (9/9)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Sep 23 23:28:59 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function ‘int main()’:\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:22: error: expected ‘;’ before ‘}’ token\n 2 | int main() { return 0 }\n | ^~\n | ;\n" }
Sep 23 23:28:59     raise RuntimeError(err_message) 
Sep 23 23:28:59 RuntimeError: test_optim failed! 
Sep 23 23:28:59  
Sep 23 23:28:59 real	23m30.062s 
Sep 23 23:28:59 user	28m32.428s 
Sep 23 23:28:59 sys	3m9.095s 
Sep 23 23:28:59 + cleanup 
Sep 23 23:28:59 + retcode=1 
Sep 23 23:28:59 + set +x 
Sep 23 23:28:59 =================== sccache compilation log =================== 
Sep 23 23:28:59 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function ‘int main()’:\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:22: error: expected ‘;’ before ‘}’ token\n    2 | int main() { return 0 }\n      |                      ^~\n      |                      ;\n" } 
Sep 23 23:28:59  
Sep 23 23:28:59 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Sep 23 23:28:59 Compile requests                327 
Sep 23 23:28:59 Compile requests executed        35 
Sep 23 23:28:59 Cache hits                       27 
Sep 23 23:28:59 Cache misses                      7 
Sep 23 23:28:59 Cache timeouts                    0 
Sep 23 23:28:59 Cache read errors                 0 
Sep 23 23:28:59 Forced recaches                   0 
Sep 23 23:28:59 Cache write errors                0 

❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) ❄️

Sep 23 23:59:18 ConnectionResetError: [Errno 104] Connection reset by peer
Sep 23 23:59:18   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 455, in accept 
Sep 23 23:59:18     deliver_challenge(c, self._authkey) 
Sep 23 23:59:18   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 722, in deliver_challenge 
Sep 23 23:59:18     response = connection.recv_bytes(256)        # reject large message 
Sep 23 23:59:18   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes 
Sep 23 23:59:18     buf = self._recv_bytes(maxlength) 
Sep 23 23:59:18   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes 
Sep 23 23:59:18     buf = self._recv(4) 
Sep 23 23:59:18   File "/opt/conda/lib/python3.6/multiprocessing/connection.py", line 379, in _recv 
Sep 23 23:59:18     chunk = read(handle, remaining) 
Sep 23 23:59:18 ConnectionResetError: [Errno 104] Connection reset by peer 
Sep 23 23:59:18 /opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 14 leaked semaphores to clean up at shutdown 
Sep 23 23:59:18   len(cache)) 
Sep 23 23:59:21 Process ErrorTrackingProcess-380: 
Sep 23 23:59:21 Traceback (most recent call last): 
Sep 23 23:59:21   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap 
Sep 23 23:59:21     self.run() 
Sep 23 23:59:21   File "/var/lib/jenkins/workspace/test/test_dataloader.py", line 361, in run 
Sep 23 23:59:21     super(ErrorTrackingProcess, self).run() 
Sep 23 23:59:21   File "/opt/conda/lib/python3.6/multiprocessing/process.py", line 93, in run 
Sep 23 23:59:21     self._target(*self._args, **self._kwargs) 

Extra GitHub checks: 1 failed


ci.pytorch.org: 1 failed



izdeby pushed a commit that referenced this pull request Aug 24, 2020
ghstack-source-id: 6231ead
Pull Request resolved: #43507
@izdeby izdeby changed the title [WIP] Rewrote adam optimizer with foreach APIs Rewrote adam optimizer with foreach APIs Sep 8, 2020
Collaborator

@ngimel ngimel left a comment


This generally looks good, but it reveals that we need a few more (TensorList, ScalarList) operations, because now there are still a few loops over parameters/grads.

for p in group['params']:
    if p.grad is not None:
        params_with_grad.append(p)
        grads.append(p.grad)
Collaborator

you could check for sparse gradients here, not in a separate loop
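For example, a sketch of that suggestion, reusing the variable names from the snippet above:

for p in group['params']:
    if p.grad is not None:
        if p.grad.is_sparse:
            raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
        params_with_grad.append(p)
        grads.append(p.grad)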

Contributor Author

Done

grads.append(p.grad)

for p in params_with_grad:
    for g in grads:
Collaborator

Remove loop over grads, check them earlier.

Contributor Author

Done

bias_correction1 = [1 - beta1 ** state['step'] for state in states]
bias_correction2 = [1 - beta2 ** state['step'] for state in states]
if group['weight_decay'] != 0:
    torch._foreach_add_(grads, group['params'], group['weight_decay'])
Collaborator

In the original optimizer the line is

grad = grad.add(p, alpha=group['weight_decay'])

This is not in-place, so the original grad attributes (p.grad) aren't mutated. Here you are mutating p.grad in place (I'm surprised the tests don't catch it).
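For illustration, a sketch of an out-of-place variant that leaves p.grad untouched (assuming the out-of-place torch._foreach_add overload is available):

if group['weight_decay'] != 0:
    # Returns new tensors; the original p.grad tensors are not modified.
    grads = torch._foreach_add(grads, params_with_grad, alpha=group['weight_decay'])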

Contributor Author

Done


if amsgrad:
    # Maintains the maximum of all 2nd moment running avg. till now
    max_exp_avg_sq = [torch.max(a, b) for a, b in zip(max_exp_avg_sq, exp_avg_sq)]
Collaborator

Aha, so ideally we also need a foreach max, because now it will be a loop?
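A dedicated list op would collapse that loop into a single call, roughly like the sketch below; a _foreach_maximum op did not exist at the time of this PR, so the name and in-place variant are assumptions:

# One batched call instead of a Python loop of torch.max:
torch._foreach_maximum_(max_exp_avg_sq, exp_avg_sq)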

# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
max_exp_avg_sq_sqrt = [torch.div(a, b) for a, b in zip(max_exp_avg_sq_sqrt, bias_correction_sqrt)]
Collaborator

OK, so this is a loop because we don't have op(TensorList, ScalarList)? That's unfortunate; it looks like we really need it.

Collaborator

Per the conversation in Slack, we can check whether all steps are the same; that guarantees that all bias_corrections are the same, and we can use foreach here and in other places when that's the case (which should be common).
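For reference, a sketch of what that check might look like, reusing names from the snippets above (assumes math is imported); when the steps differ, the per-tensor loop above would still be used:

steps = [state['step'] for state in states]
if all(step == steps[0] for step in steps):
    # All bias corrections are the same scalar, so plain foreach ops suffice.
    bias_correction2 = 1 - beta2 ** steps[0]
    max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
    torch._foreach_div_(max_exp_avg_sq_sqrt, math.sqrt(bias_correction2))
else:
    # Fall back to the per-tensor loop shown above.
    pass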

Contributor

Are you suggesting a change in the underlying algorithm? What does that mean, and what would happen if not all steps are the same?

step_size = [group['lr'] / bc for bc in bias_correction1]

for i in range(len(step_size)):
    params_with_grad[i].addcdiv_(exp_avg[i], denom[i], value=-step_size[i])
Collaborator

And this is another case where we need op(TensorList, ScalarList), because params_with_grad, exp_avg, and denom are TensorLists and value is a ScalarList?
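Once such an overload exists, the loop above could become a single call along these lines (a hypothetical signature at the time of this PR):

# step_size is a list of Python scalars, one per parameter.
torch._foreach_addcdiv_(params_with_grad, exp_avg, denom, [-s for s in step_size])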

max_exp_avg_sq = [torch.max(a, b) for a, b in zip(max_exp_avg_sq, exp_avg_sq)]
# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
Collaborator

Why is this not a for_each_sqrt?

Contributor Author

In this context, bias_correction2 is a list of scalars, and there is no foreach API that supports lists of scalars yet. I'm working on those APIs right now.

Collaborator

This particular case is OK: those are Python scalars and this operation is reasonably fast. However, other cases where CUDA tensors interact with a list of scalars are more problematic. We could get around them for now by checking whether all bias_corrections have the same value (which should be a common case).

Contributor Author

I gave myself two days to fight codegen and make ScalarList a reality. If it works out, I will just add new _foreach_op(TensorList, ScalarList) APIs. If it turns out to be too complex, I will make a workaround and add this to the TODO list.

Contributor

vincentqb commented Sep 10, 2020

In case there is a way of avoiding duplicating code:

  • What would be the advantage of using the previous version? One advantage, noted in the description, is that the new implementation requires all tensors to have the same dtype. (We did mention offline that MultiTensor would not be differentiable, for instance.)
  • Can TensorList be reduced to a plain list of tensors if a list of tensors is provided instead? If so, could we consider having only this new implementation for each optimizer?
  • Should we consider having a wrapper that swaps between the two current implementations depending on the type?

Broadcasting
At this point we don't support broadcasting.

This means some operations could be slower with MultiTensor for now, though, right? Has this been measured?

@vincentqb
Contributor

GitHub issue
Current PyTorch optimizer implementations are inefficient when we work with a lot of small feature tensors: launching a separate kernel for every tensor slows the whole step down, so we need to reduce the number of kernel launches.
NVIDIA's Apex is a good reference point for this approach.
To track progress, we will take PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark the model against the original version, Apex's version with the original Adam optimizer, and Apex's FusedAdam optimizer.

Now that we have the implementation, have you been able to quantify the performance scaling?


with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
    optim.Adam(None, lr=1e-2, weight_decay=-1)
for optimizer in [optim.Adam, optim_mt.Adam]:
Contributor

Note: the current tests don't guarantee that the algorithm is the same as the original one, or even that there's convergence in either case.

@zou3519 @anjali411 -- how are the C++ APIs tested?

Contributor

C++ Optimizer API logic is tested here: https://github.com/pytorch/pytorch/blob/master/test/cpp/api/optim.cpp#L311. These tests compare the C++ optimizers' results to the Python API optimizers' results prewritten in this file: https://github.com/pytorch/pytorch/blob/master/test/cpp/api/optim_baseline.h

Collaborator

Ok, so here we'll need the tests comparing optim.Adam results with optim_mt.Adam, in addition to _test_basic_cases?

Contributor

Yes, since the contract here is that the MultiTensor implementation is the same as the original.

Contributor

Is there also a C++ equivalent with the foreach API for the C++ optimizer?

Contributor Author

@vincentqb, there will be C++ optimizers as well, but a bit later. We decided to start with the Python ones first.

Re testing: is there anything specific you would suggest testing?

Contributor

Ideally, we'd have a test as mentioned by @anjali411 above that checks that the two implementations give the exact same answer in some case.

Note that if this implementation were directly replacing the current one, the C++ test would also tell us that the implementations are still aligned :) Maybe there's a way of leveraging those tests already there?

Collaborator

There are tests in Apex comparing optimizer implementations; they can be adopted here if the C++ tests are hard to use for some reason: https://github.com/NVIDIA/apex/blob/master/tests/L0/run_optimizers/test_fused_optimizer.py
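A minimal version of such a comparison test could look like the sketch below (assumptions: it runs inside a unittest TestCase, torch is imported, and optim / optim_mt are the two namespaces from the test snippet above):

def test_multi_tensor_adam_matches_adam(self):
    torch.manual_seed(0)
    params_ref = [torch.randn(10, 5, requires_grad=True) for _ in range(3)]
    params_mt = [p.detach().clone().requires_grad_(True) for p in params_ref]

    opt_ref = optim.Adam(params_ref, lr=1e-3)
    opt_mt = optim_mt.Adam(params_mt, lr=1e-3)

    for _ in range(5):
        # Feed both optimizers identical gradients each step.
        for p_ref, p_mt in zip(params_ref, params_mt):
            g = torch.randn_like(p_ref)
            p_ref.grad = g.clone()
            p_mt.grad = g.clone()
        opt_ref.step()
        opt_mt.step()

    for p_ref, p_mt in zip(params_ref, params_mt):
        self.assertTrue(torch.allclose(p_ref, p_mt))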

@izdeby izdeby changed the title Rewrote adam optimizer with foreach APIs [WIP] Rewrote adam optimizer with foreach APIs Sep 10, 2020
Contributor

stas00 commented Feb 2, 2021

This is wonderful, thank you!

I don't see any user docs discussing this improvement.

What is the implication for the user?

Should we switch to using torch.optim.multi_tensor if the version is right, or is the intention for it to eventually take over torch.optim, so that no changes need to be made to user code?

Thank you.

Contributor Author

izdeby commented Feb 2, 2021

Hi @stas00,
First of all, if you would like to track the progress of this work, please take a look at this stack.

Answering your question: you can try using the optimizers from the torch.optim._multi_tensor namespace, but please keep in mind that they are in an alpha state and will be updated. Eventually, they will replace the ones in torch.optim. Once that happens, users won't have to do anything at all; the changes will happen under the hood.
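For example, trying them out today is just a change of import (a sketch; the _multi_tensor namespace is experimental and may change):

import torch
from torch.optim import _multi_tensor as optim_mt

model = torch.nn.Linear(10, 10)
optimizer = optim_mt.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()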

Contributor

stas00 commented Feb 2, 2021

Thank you very much for the clarification and the stack, @izdeby!

So basically we have the option to adopt these early for those who need the speedup sooner, but otherwise there is nothing to be done.

Excellent!

@Wesley-Jzy

If I want to implement global operators like grad_clip myself to reduce kernel launches, may I use multi_tensor to do it, or does PyTorch already provide similar interfaces?

Collaborator

ngimel commented Feb 27, 2023

PyTorch's clip_grad_norm uses multi_tensor; if you want to write your own utility, you can use multi_tensor yourself.
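For illustration, a rough sketch of such a utility built on the foreach ops (the availability of _foreach_norm varies by PyTorch version, so treat the exact calls as assumptions):

import torch

def clip_grads_(parameters, max_norm):
    # Multi-tensor sketch of gradient clipping by global norm.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return
    norms = torch._foreach_norm(grads)                        # one norm per gradient
    total_norm = torch.linalg.vector_norm(torch.stack(norms)).item()
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        torch._foreach_mul_(grads, clip_coef)                 # one batched scale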

@Wesley-Jzy

PyTorch's clip_grad_norm uses multi_tensor; if you want to write your own utility, you can use multi_tensor yourself.

Thank you so much. I think multi_tensor is what I'm looking for.
