nickgg commented Aug 5, 2020

Adds a new optimization pass, the Registerizer, which looks for Stores and Loads that share a single item in a buffer and replaces them with a local temporary scalar, which is cheaper to read and write.

For example it can replace:

A[0] = 0;
for (int x = 0; x < 10; x++) {
  A[0] = (A[0]) + x;
}

with:

int A_ = 0;
for (int x = 0; x < 10; x++) {
  A_ = x + A_;
}
A[0] = A_;
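
As a quick sanity check of the rewrite above, here is a minimal C++ sketch that runs both forms on a one-element buffer. The function names are hypothetical and exist only for this illustration; the loop bodies mirror the two snippets in the description.

```cpp
#include <array>

// Original form: every iteration loads and stores A[0] in the buffer.
int sum_through_buffer(std::array<int, 1>& A) {
  A[0] = 0;
  for (int x = 0; x < 10; x++) {
    A[0] = (A[0]) + x;
  }
  return A[0];
}

// Registerized form: accumulate in a local scalar and store once at the end.
int sum_through_scalar(std::array<int, 1>& A) {
  int A_ = 0;
  for (int x = 0; x < 10; x++) {
    A_ = x + A_;
  }
  A[0] = A_;
  return A[0];
}
```

Both functions leave the same value in `A[0]`; the second touches the buffer only once.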

This is particularly useful on GPUs when parallelizing, since after replacing loops with metavars we have a lot of accesses like this. Early tests of simple reductions on a V100 indicate this can speed them up by ~5x.

This diff got a bit unwieldy with the integration code, so that will come in a follow-up.

@nickgg nickgg requested a review from apaszke as a code owner August 5, 2020 17:00
@nickgg nickgg requested review from ZolotukhinM and zheng-xq August 5, 2020 17:05
dr-ci bot commented Aug 5, 2020

💊 CI failures summary and remediations

As of commit fb6a08d (more details on the Dr. CI page):


  • 4/4 failures possibly* introduced in this PR
    • 2/4 non-CircleCI failure(s)

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (1/2)

Step: "Run tests"

Aug 11 05:05:30 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n ^\n" }
Aug 11 05:05:30 Traceback (most recent call last): 
Aug 11 05:05:30   File "test/run_test.py", line 716, in <module> 
Aug 11 05:05:30     main() 
Aug 11 05:05:30   File "test/run_test.py", line 705, in main 
Aug 11 05:05:30     raise RuntimeError(err) 
Aug 11 05:05:30 RuntimeError: test_quantization failed! 
Aug 11 05:05:30 + cleanup 
Aug 11 05:05:30 + retcode=1 
Aug 11 05:05:30 + set +x 
Aug 11 05:05:30 =================== sccache compilation log =================== 
Aug 11 05:05:30 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Aug 11 05:05:30  
Aug 11 05:05:30 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Aug 11 05:05:30 Compile requests                 65 
Aug 11 05:05:30 Compile requests executed        35 
Aug 11 05:05:30 Cache hits                       27 
Aug 11 05:05:30 Cache misses                      7 
Aug 11 05:05:30 Cache timeouts                    0 
Aug 11 05:05:30 Cache read errors                 0 
Aug 11 05:05:30 Forced recaches                   0 
Aug 11 05:05:30 Cache write errors                0 

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test (2/2)

Step: "Run tests"

Aug 11 05:05:20 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n ^\n" }
Aug 11 05:05:20 Traceback (most recent call last): 
Aug 11 05:05:20   File "test/run_test.py", line 716, in <module> 
Aug 11 05:05:20     main() 
Aug 11 05:05:20   File "test/run_test.py", line 705, in main 
Aug 11 05:05:20     raise RuntimeError(err) 
Aug 11 05:05:20 RuntimeError: test_quantization failed! 
Aug 11 05:05:20 + cleanup 
Aug 11 05:05:20 + retcode=1 
Aug 11 05:05:20 + set +x 
Aug 11 05:05:20 =================== sccache compilation log =================== 
Aug 11 05:05:20 ERROR:sccache::server: Compilation failed: Output { status: ExitStatus(ExitStatus(256)), stdout: "", stderr: "/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp: In function \'int main()\':\n/var/lib/jenkins/.cache/torch_extensions/test_compilation_error_formatting/main.cpp:2:23: error: expected \';\' before \'}\' token\n int main() { return 0 }\n                       ^\n" } 
Aug 11 05:05:20  
Aug 11 05:05:20 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Aug 11 05:05:20 Compile requests                 65 
Aug 11 05:05:20 Compile requests executed        35 
Aug 11 05:05:20 Cache hits                       27 
Aug 11 05:05:20 Cache misses                      7 
Aug 11 05:05:20 Cache timeouts                    0 
Aug 11 05:05:20 Cache read errors                 0 
Aug 11 05:05:20 Forced recaches                   0 
Aug 11 05:05:20 Cache write errors                0 

ci.pytorch.org: 2 failed


This comment was automatically generated by Dr. CI.

This comment has been revised 21 times.

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Aug 5, 2020
@nickgg nickgg requested a review from bertmaher August 5, 2020 21:38
zheng-xq (Contributor) left a comment

Great change! A few minor changes.

Contributor:

Minor: this type has enough behavior to be made into a class, with its members marked private.

Contributor Author:

I simplified this definition a bit, and would like to keep it as a record rather than a class.

Contributor:

Minor: this might get quite expensive for a fairly large program. Maybe add a TODO to remind our future selves.

Contributor Author:

Which part gets expensive, comparing indices?

Contributor:

Non-Blocking: I am perfectly fine with this approach in general. But I would like to point out that there are many more cases to handle before it is functionally correct in slightly more general settings.

From a certain perspective, the Registerizer moves a global memory access to thread-local memory. This could change the semantics if the memory access has a cross-thread dependency. For example: if a global read needs to observe another thread's atomic global write, that access really needs to go through global memory.

This is not likely a problem because we cannot generate that complex a program yet. But we should keep reminding ourselves of the memory semantics change.
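
The concern above can be illustrated with a small single-threaded C++ sketch. The `other_thread` callback is a hypothetical stand-in that simulates a second thread writing the shared location between iterations; this is an assumption made to keep the example deterministic, not how the pass or CUDA actually schedules work.

```cpp
#include <functional>

// `shared` stands in for a global-memory location. The loop re-reads it on
// every iteration, so it observes writes made by the simulated other thread.
int poll_global(int& shared, const std::function<void(int)>& other_thread) {
  int seen = 0;
  for (int i = 0; i < 4; i++) {
    other_thread(i);          // another thread may write `shared` here
    if (shared != 0) seen++;  // reads memory each time
  }
  return seen;
}

// The same loop after naive scalar replacement: the value is read once.
int poll_registerized(int& shared,
                      const std::function<void(int)>& other_thread) {
  int shared_ = shared;       // cached copy, never refreshed
  int seen = 0;
  for (int i = 0; i < 4; i++) {
    other_thread(i);
    if (shared_ != 0) seen++; // cannot observe the other thread's write
  }
  return seen;
}
```

If the simulated thread sets the location partway through the loop, the first version notices and the second never does, which is the semantic change described above.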

Contributor Author:

Right, this only addresses a subset of cases where you can push accesses to a scalar. I believe it's currently pessimistic, however: if any access anywhere in the program may overlap a registerization candidate, we won't do it. So we should always be correct, but we'll leave some perf on the table.

The next step is to divide the program into sub-sections which can have distinct registerizations and write them back at the boundaries of those subsections. I have some ideas on that, but it gets more complicated.
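
A rough sketch of what such a pessimistic, program-wide overlap rule could look like is below. The `Access` record and `safeCandidates` function are hypothetical simplifications for illustration and are not the data structures used in this PR; an access with a non-constant index is simply treated as possibly overlapping everything in the same buffer.

```cpp
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified access record: a buffer name plus an index that is
// either a known constant or unknown (for example A[x] for a loop variable x).
struct Access {
  std::string buf;
  bool hasConstIndex;
  int index;  // only meaningful when hasConstIndex is true
};

using Candidate = std::pair<std::string, int>;

// Returns the (buffer, constant index) pairs that are safe to replace with a
// scalar under a deliberately pessimistic rule: any access with an unknown
// index disqualifies every candidate on the same buffer.
std::set<Candidate> safeCandidates(const std::vector<Access>& accesses) {
  std::set<std::string> tainted;
  for (const Access& a : accesses) {
    if (!a.hasConstIndex) tainted.insert(a.buf);
  }
  std::set<Candidate> result;
  for (const Access& a : accesses) {
    if (a.hasConstIndex && tainted.count(a.buf) == 0) {
      result.emplace(a.buf, a.index);
    }
  }
  return result;
}
```

A rule like this never registerizes an access that might alias another one, at the cost of skipping buffers where a finer analysis could prove the accesses are disjoint.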

Contributor:

Change the comments to reflect the test.

Contributor Author:

Pardon, what do you mean here? How does this comment not reflect it?

Contributor:

I thought the code refers to A[x], not A[0]. No?

Contributor Author:

Yes thanks, good catch.

Contributor:

Even in this case, please make sure the access is not registerized if "x" is marked with threadIdx/blockIdx or, in the future, "parallelize".

Contributor Author:

I am assuming that this pass occurs after block/thread axes are flattened down, but I'll add a check to make this explicit.
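
One way to make that assumption explicit is a guard that walks the loop nest before the pass runs. The sketch below uses a hypothetical `Loop` tree and `Axis` enum invented for this example; these are not the tensorexpr IR types, and the real check may look quite different.

```cpp
#include <vector>

// Hypothetical loop-nest description, for illustration only.
enum class Axis { Serial, GpuBlockIdx, GpuThreadIdx };

struct Loop {
  Axis axis;
  std::vector<Loop> body;  // nested loops
};

// True if any loop in the nest is still bound to a GPU block or thread index.
bool hasGpuAxis(const Loop& loop) {
  if (loop.axis != Axis::Serial) return true;
  for (const Loop& inner : loop.body) {
    if (hasGpuAxis(inner)) return true;
  }
  return false;
}

// A pass guarded this way would return its input unchanged when this is false.
bool canRegisterize(const Loop& root) {
  return !hasGpuAxis(root);
}
```

Bailing out on the whole tree is the conservative choice: it costs some optimization opportunities but avoids turning a per-thread access into a shared scalar by mistake.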

bertmaher (Contributor) left a comment

In compiler lingo I think this is often called "scalar replacement", which might be worth mentioning somewhere in the block comment describing this optimization pass :-)

facebook-github-bot (Contributor) left a comment

@bertmaher has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

nickgg commented Aug 6, 2020

I ran into cases where this did the wrong thing; it needs the Let Stmt PR as well as a few other changes, which I'll add when this lands.

nickgg commented Aug 6, 2020

@bertmaher thanks! I didn't know the name of this pattern but I figured it was common. Would ScalarReplacer be a better name than Registerizer?

bertmaher commented

> @bertmaher thanks! I didn't know the name of this pattern but I figured it was common. Would ScalarReplacer be a better name than Registerizer?

Oh, idk, I kinda like the name "Registerizer" :). But ScalarReplacer would maybe be more typical. Up to you!

nickgg commented Aug 7, 2020

Since I needed another diff to land for this, I ended up rolling the next set of improvements into this change. Some changes in the last push:

  • Moved helpers into header files.
  • Now inserts definition and final store in the closest space to the first and last usage of the access, preventing bad ordering of vars (testRegisterizerAllocs covers this).
  • Now supports registerizing accesses which are only made up of Loads and not Stores, and does not attempt to write the value of the scalar back to the buffer (change to testRegisterizerNoLoads covers this).
  • Now correctly handles cases where the buffer is not initialized in the kernel, and will initialize the scalar by reading the buffer. (testRegisterizerNoInit and testRegisterizerLoadThenStore cover this)
  • Now hoists the definition of the scalar to the highest loop axis that it is not dependent on, meaning we now correctly cover cases where an access appears only inside an inner loop but does not depend on the loop var. (testRegisterizerNoInit and testRegisterizerLoadThenStore cover this too)
  • Now bails out early if there are still GPU Block Idx or Thread Idx loop options present in the tree (testRegisterizerParallelized covers this).
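
To make the "initialize from the buffer" and "hoist the definition" items above concrete, here is a hedged C++ sketch of the kind of rewrite they describe. The function names and the `A[x]` access pattern are assumptions chosen for illustration; this is not output produced by the pass.

```cpp
#include <vector>

// Original form: A[x] is loaded and stored on every inner iteration, and the
// kernel never initializes A itself.
void accumulate_direct(std::vector<int>& A) {
  for (int x = 0; x < static_cast<int>(A.size()); x++) {
    for (int y = 0; y < 10; y++) {
      A[x] = A[x] + y;
    }
  }
}

// Sketch of the transformed form: the access depends on x but not on y, so
// the scalar is defined just outside the y loop, initialized by loading the
// buffer, and written back right after the y loop finishes.
void accumulate_registerized(std::vector<int>& A) {
  for (int x = 0; x < static_cast<int>(A.size()); x++) {
    int A_ = A[x];
    for (int y = 0; y < 10; y++) {
      A_ = A_ + y;
    }
    A[x] = A_;
  }
}
```

Placing the definition and the final store as close as possible to the first and last use keeps the scalar's lifetime short while still removing the per-iteration buffer traffic.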

Contributor:

Is the find() function used somewhere?

Contributor Author:

I use it in a follow up (currently). I'd like to keep it for now.

Contributor:

Does this member have to be public?

Contributor Author:

No, guess not.

Contributor:

Minor: since we are at the stage of moving passes around and trying different orders, it would be good to list some of the ordering requirements with CudaCodeGen here. For example: this must be invoked after threadIdx flattening, but must happen before pass xyz. It doesn't have to be complete, just enough to remind us where not to move this.

Contributor:

This seems useful in general. Why not make it accept both Stmt*?

@nickgg nickgg force-pushed the registerizer branch 2 times, most recently from 768d9ed to 9a30c03 Compare August 10, 2020 21:55
facebook-github-bot (Contributor) left a comment

@nickgg has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot

@nickgg merged this pull request in aabdef5.

Labels

Merged oncall: jit Add this issue/PR to JIT oncall triage queue

5 participants