
Conversation

@cpuhrsch
Contributor

This PR hooks up TBB with set_num_threads and also cleans up the code within ReduceOps.cpp a bit.

You can see that the restriction works using the following script:

import torch
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Stress torch.sum under a thread limit.')
    parser.add_argument('threads', type=int)
    args = parser.parse_args()

    tv = torch.randn(1000 * 1000 * 10)
    if args.threads > 0:  # 0 keeps the default (all available cores)
        torch.set_num_threads(args.threads)
    for _ in range(10000):
        tv.sum()
(base) [16:44:39: cpuhrsch@devfair0129 benchmarks]$ perf stat python sum_stress.py 5

 Performance counter stats for 'python sum_stress.py 5':

      17327.154266      task-clock (msec)         #    3.721 CPUs utilized          
               962      context-switches          #    0.056 K/sec                  
                90      cpu-migrations            #    0.005 K/sec                  
            22,517      page-faults               #    0.001 M/sec                  
    38,071,989,548      cycles                    #    2.197 GHz                    
   <not supported>      stalled-cycles-frontend  
   <not supported>      stalled-cycles-backend   
    31,333,390,416      instructions              #    0.82  insns per cycle        
     3,732,742,496      branches                  #  215.427 M/sec                  
         8,014,202      branch-misses             #    0.21% of all branches        

       4.656548693 seconds time elapsed

(base) [16:44:44: cpuhrsch@devfair0129 benchmarks]$ perf stat python sum_stress.py 2

 Performance counter stats for 'python sum_stress.py 2':

      16708.310926      task-clock (msec)         #    1.780 CPUs utilized          
             1,893      context-switches          #    0.113 K/sec                  
                85      cpu-migrations            #    0.005 K/sec                  
            21,486      page-faults               #    0.001 M/sec                  
    36,683,508,608      cycles                    #    2.196 GHz                    
   <not supported>      stalled-cycles-frontend  
   <not supported>      stalled-cycles-backend   
    30,414,604,047      instructions              #    0.83  insns per cycle        
     3,536,252,326      branches                  #  211.646 M/sec                  
         5,265,418      branch-misses             #    0.15% of all branches        

       9.386758238 seconds time elapsed

(base) [16:51:54: cpuhrsch@devfair0129 benchmarks]$ perf stat python sum_stress.py 0

 Performance counter stats for 'python sum_stress.py 0':

      47343.842407      task-clock (msec)         #   19.877 CPUs utilized          
           135,891      context-switches          #    0.003 M/sec                  
               623      cpu-migrations            #    0.013 K/sec                  
            21,791      page-faults               #    0.460 K/sec                  
   104,047,864,558      cycles                    #    2.198 GHz                    
   <not supported>      stalled-cycles-frontend  
   <not supported>      stalled-cycles-backend   
    46,576,686,318      instructions              #    0.45  insns per cycle        
     7,054,117,917      branches                  #  148.998 M/sec                  
        33,104,250      branch-misses             #    0.47% of all branches        

       2.381825636 seconds time elapsed

(base) [16:52:18: cpuhrsch@devfair0129 benchmarks]$ perf stat python sum_stress.py 1

 Performance counter stats for 'python sum_stress.py 1':

      16819.557761      task-clock (msec)         #    0.991 CPUs utilized          
               905      context-switches          #    0.054 K/sec                  
                79      cpu-migrations            #    0.005 K/sec                  
            23,798      page-faults               #    0.001 M/sec                  
    36,957,423,434      cycles                    #    2.197 GHz                    
   <not supported>      stalled-cycles-frontend  
   <not supported>      stalled-cycles-backend   
    30,260,079,596      instructions              #    0.82  insns per cycle        
     3,503,067,589      branches                  #  208.273 M/sec                  
         5,049,148      branch-misses             #    0.14% of all branches        

      16.969856749 seconds time elapsed


@cpuhrsch force-pushed the ropcleanup branch 2 times, most recently from ef0e11a to 831945d on March 12, 2018 23:44
@cpuhrsch
Contributor (Author)

Added support for changing the number of threads at runtime. Consider the following script:

import torch
import time
import gc
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Halve the thread count every 10k reductions and time each phase.')
    parser.add_argument('threads', type=int)
    parser.add_argument('--decrease', action='store_true')
    args = parser.parse_args()

    tv = torch.randn(1000 * 1000 * 10)
    num_thread = args.threads
    torch.set_num_threads(num_thread)
    tstart = time.time()
    for i in range(100000):
        tv.sum()
        if args.decrease and i % 10000 == 0:
            gc.collect()
            elapsed = time.time() - tstart
            tstart = time.time()
            num_thread = num_thread // 2  # floor division: stays an int, bottoms out at 0
            print("decreasing to: " + str(num_thread) + " prev elapsed: " + str(elapsed))
            torch.set_num_threads(num_thread)

And its output:

decreasing to: 16 prev elapsed: 0.0332319736481
decreasing to: 8 prev elapsed: 1.24536299706
decreasing to: 4 prev elapsed: 2.19691586494
decreasing to: 2 prev elapsed: 4.26764798164
decreasing to: 1 prev elapsed: 7.82375693321
decreasing to: 0 prev elapsed: 15.4659228325
decreasing to: 0 prev elapsed: 0.84959602356
decreasing to: 0 prev elapsed: 0.842261075974
decreasing to: 0 prev elapsed: 0.804984092712
decreasing to: 0 prev elapsed: 0.823231220245

 Performance counter stats for 'python sum_stress.py 32 --decrease':

     354283.766260      task-clock (msec)         #    9.613 CPUs utilized          
           581,664      context-switches          #    0.002 M/sec                  
             4,625      cpu-migrations            #    0.013 K/sec                  
            49,221      page-faults               #    0.139 K/sec                  
   779,103,968,909      cycles                    #    2.199 GHz                    
   <not supported>      stalled-cycles-frontend  
   <not supported>      stalled-cycles-backend   
   395,023,334,546      instructions              #    0.51  insns per cycle        
    56,253,880,001      branches                  #  158.782 M/sec                  
       165,331,982      branch-misses             #    0.29% of all branches        

      36.854981344 seconds time elapsed

@soumith
Contributor

soumith commented Mar 12, 2018

Can you also make tbb / set_num_threads respect the OMP_NUM_THREADS env variable? It's quite common to use it.
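For reference, a minimal sketch of what honoring that variable could look like. The helper name and the fallback to TBB's default are assumptions for illustration, not the PR's actual code:

#include <cstdlib>
#include <tbb/task_scheduler_init.h>

// Hypothetical helper: prefer OMP_NUM_THREADS when it is set to a
// positive integer, otherwise fall back to TBB's default thread count.
static int initial_num_threads() {
  if (const char* env = std::getenv("OMP_NUM_THREADS")) {
    const int n = std::atoi(env);
    if (n > 0) return n;
  }
  return tbb::task_scheduler_init::default_num_threads();
}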


@cpuhrsch force-pushed the ropcleanup branch 4 times, most recently from fa8f2d4 to 354a2f6 on March 14, 2018 01:58
@cpuhrsch
Copy link
Contributor Author

We came upon an issue related to std::thread and TBB: when running e.g. sum inside a std::thread, the chosen number of threads is not respected. So even if the process is launched with OMP_NUM_THREADS=1, the std::thread will use the default (all available cores). This causes issues with autograd (engine.cpp). The proposed solution is for set_num_threads to only write a global thread-count variable, which the respective threads then read. Each parallel function creates a static TBB init object that initializes itself with the current number of threads, and that object is re-initialized only when the number of threads has changed. There does not appear to be any noticeable performance penalty for this, and it also appears to resolve the issue; the pattern is sketched below.
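A minimal sketch of that pattern, assuming a thread-local scheduler handle; the names (requested_threads, sum) are illustrative, not the PR's actual code:

#include <atomic>
#include <cstddef>
#include <functional>
#include <tbb/blocked_range.h>
#include <tbb/parallel_reduce.h>
#include <tbb/task_scheduler_init.h>

// Illustrative global: set_num_threads only records the request here.
static std::atomic<int> requested_threads{
    tbb::task_scheduler_init::default_num_threads()};

void set_num_threads(int n) { requested_threads.store(n); }

float sum(const float* data, std::size_t n) {
  // Each thread owns a scheduler handle and re-initializes it only
  // when the requested count changed since this thread's last call.
  static thread_local int observed = requested_threads.load();
  static thread_local tbb::task_scheduler_init init(observed);
  const int wanted = requested_threads.load();
  if (wanted != observed && wanted > 0) {
    init.terminate();
    init.initialize(wanted);
    observed = wanted;
  }
  return tbb::parallel_reduce(
      tbb::blocked_range<std::size_t>(0, n), 0.0f,
      [data](const tbb::blocked_range<std::size_t>& r, float acc) {
        for (std::size_t i = r.begin(); i != r.end(); ++i) acc += data[i];
        return acc;
      },
      std::plus<float>());
}

Because the handle lives in the worker thread itself, a sum launched from a std::thread (as in autograd's engine.cpp) picks up the current limit instead of silently defaulting to all cores.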


@cpuhrsch force-pushed the ropcleanup branch 3 times, most recently from d775ee2 to a1cfb83 on March 15, 2018 18:11
@cpuhrsch changed the title from "ReduceOps cleanup and set_num_threads" to "tbb set_num_threads" on Mar 15, 2018
@cpuhrsch force-pushed the ropcleanup branch 2 times, most recently from 9805a21 to adae78e on March 15, 2018 19:49


@cpuhrsch force-pushed the ropcleanup branch 4 times, most recently from 1707b18 to f3937a0 on March 16, 2018 16:22


@cpuhrsch force-pushed the ropcleanup branch 2 times, most recently from 7a66ecc to da33c22 on March 18, 2018 05:11
@cpuhrsch force-pushed the ropcleanup branch 2 times, most recently from 4eb4314 to 07a1663 on March 18, 2018 23:58
@soumith merged commit 84400d5 into pytorch:master on Mar 19, 2018
soumith added a commit that referenced this pull request Mar 19, 2018
soumith added a commit that referenced this pull request Mar 19, 2018
…simple") moving average" (#5892)

* Revert "Port ATen and JIT C++ tests to Catch2 (#5788)"

This reverts commit 6f80023.

* Revert "Fix error message for cat-ing zero-dim tensors (#5819)"

This reverts commit cf2e176.

* Revert "Softmax symbolic should account for negative dim (#5846)"

This reverts commit ba64724.

* Revert "[fft][1 of 3] build system and helpers to support cuFFT and MKL (#5855)"

This reverts commit 22ef8e5.

* Revert "Don't modify requires_grad when running DataParallel in no_grad mode (#5880)"

This reverts commit d11b7fb.

* Revert "fix some methods not showing up in doc (#5882)"

This reverts commit 24fca0e.

* Revert "ReduceOps cleanup and set_num_threads (#5723)"

This reverts commit 84400d5.

* Revert "introduce shape_as_tensor and reshape_from_variable_shape (#5824)"

This reverts commit f446b82.

* Revert "Enable resetting of batchnorm running moments and cumulative ("simple") moving average (#5766)"

This reverts commit 99b1f6c.
jekbradbury pushed a commit to jekbradbury/pytorch that referenced this pull request Mar 21, 2018
…simple") moving average" (pytorch#5892)

* Revert "Port ATen and JIT C++ tests to Catch2 (pytorch#5788)"

This reverts commit 6f80023.

* Revert "Fix error message for cat-ing zero-dim tensors (pytorch#5819)"

This reverts commit cf2e176.

* Revert "Softmax symbolic should account for negative dim (pytorch#5846)"

This reverts commit ba64724.

* Revert "[fft][1 of 3] build system and helpers to support cuFFT and MKL (pytorch#5855)"

This reverts commit 22ef8e5.

* Revert "Don't modify requires_grad when running DataParallel in no_grad mode (pytorch#5880)"

This reverts commit d11b7fb.

* Revert "fix some methods not showing up in doc (pytorch#5882)"

This reverts commit 24fca0e.

* Revert "ReduceOps cleanup and set_num_threads (pytorch#5723)"

This reverts commit 84400d5.

* Revert "introduce shape_as_tensor and reshape_from_variable_shape (pytorch#5824)"

This reverts commit f446b82.

* Revert "Enable resetting of batchnorm running moments and cumulative ("simple") moving average (pytorch#5766)"

This reverts commit 99b1f6c.