optimize scatter_add performance for gnn usage on CPU #82703
Conversation
### Motivation of this PR

This PR targets improving the performance of `scatter_add` for GNN usage scenarios on PyG. Currently only the CPU path is optimized.

`Message Passing` is the major step in GNN learning: it exchanges/aggregates information between nodes. From a performance point of view, when the `EdgeIndex` is stored as `[2, num_edges]`, `scatter_reduce` is a major hotspot in the current PyTorch implementation.

To be more specific, in message passing `scatter_add` is used in a very similar way to `index_select`, except that the `self` tensor is written to while `index_select` only reads. The `index` tensor passed to `scatter_add` is therefore a tensor expanded on dim 0, which means that along all the remaining dims it holds the same value.

### Algorithm

For this case the current scatter implementation parallelizes over the inner dims, which leads to bad performance: non-contiguous memory access and no vectorization. This PR sorts the `index` to resolve the write conflicts that would otherwise occur when parallelizing directly over dim 0. The algorithm is equivalent to:

* convert the memory format from `COO` to `CSR`
* do an spmm-style reduce

### Perf improvement

The benchmark comes from https://github.com/pyg-team/pytorch_geometric/tree/master/examples: `python reddit.py`, which runs the SAGE model on the Reddit dataset.

CPU type: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz

`aten::scatter_add_` has been reduced from **37.797s** to **5.989s**:

* breakdown before

```
Name                  Self CPU %   Self CPU   CPU total %   CPU total   CPU time avg   # of Calls
--------------------  ----------   --------   -----------   ---------   ------------   ----------
aten::scatter_add_    49.00%       37.797s    49.00%        37.797s     41.445ms       912
aten::index_select    19.74%       15.223s    19.74%        15.227s     6.678ms        2280
aten::linear          0.01%        5.706ms    15.04%        11.602s     12.721ms       912
aten::addmm           6.62%        5.108s     7.92%         6.112s      13.403ms       456
aten::matmul          0.00%        2.339ms    7.10%         5.475s      12.006ms       456
```

* breakdown after

```
Name                  Self CPU %   Self CPU   CPU total %   CPU total   CPU time avg   # of Calls
--------------------  ----------   --------   -----------   ---------   ------------   ----------
aten::index_select    32.41%       14.677s    32.42%        14.681s     6.439ms        2280
aten::linear          0.01%        6.665ms    26.43%        11.968s     13.123ms       912
aten::addmm           11.76%       5.328s     13.76%        6.232s      13.667ms       456
aten::scatter_add_    13.22%       5.989s     13.22%        5.989s      6.566ms        912
aten::matmul          0.01%        2.303ms    12.63%        5.720s      12.543ms       456
```
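To make the targeted pattern concrete, here is a minimal sketch in plain PyTorch (not the kernel added by this PR; all sizes and variable names are illustrative) of how `scatter_add_` shows up in message passing with an index expanded on dim 0, and of the sorted-index / CSR-style reduction the algorithm above is equivalent to:

```python
import torch

# Illustrative only: toy sizes standing in for a real graph.
num_nodes, num_edges, feat_dim = 1000, 20000, 64
x = torch.randn(num_nodes, feat_dim)                       # node features
edge_index = torch.randint(0, num_nodes, (2, num_edges))   # COO edges, shape [2, num_edges]
src, dst = edge_index[0], edge_index[1]

# Gather messages from source nodes (read-only, analogous to index_select).
messages = x.index_select(0, src)                          # [num_edges, feat_dim]

# Scatter-add the messages into destination nodes. The index is expanded on
# dim 0, so every column of a given row carries the same node id -- exactly
# the case this PR special-cases.
index = dst.view(-1, 1).expand(-1, feat_dim)
out = torch.zeros(num_nodes, feat_dim)
out.scatter_add_(0, index, messages)

# Equivalent formulation the PR's algorithm boils down to: sort the index
# (COO -> CSR-style row pointers), then reduce each contiguous segment.
# Each output row is now independent, so parallelizing over dim 0 has no
# write conflicts and the inner reduction is contiguous and vectorizable.
sorted_dst, perm = torch.sort(dst)
sorted_messages = messages[perm]
row_ptr = torch.zeros(num_nodes + 1, dtype=torch.long)
row_ptr[1:] = torch.bincount(sorted_dst, minlength=num_nodes).cumsum(0)

out_ref = torch.zeros(num_nodes, feat_dim)
for row in range(num_nodes):
    beg, end = row_ptr[row], row_ptr[row + 1]
    if end > beg:
        out_ref[row] = sorted_messages[beg:end].sum(0)

# Results match up to floating-point accumulation order.
print(torch.allclose(out, out_ref, atol=1e-4))
```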
@rusty1s could you please help review this one?

Yes, I can take a look ASAP.
cc @mananshah99 |
Thanks for this PR. This is really cool :)

The overall implementation makes sense to me. I guess there might be some threshold on `K` and `index.numel()` beyond which spmm is actually preferred.

In addition, we can also think about wrapping this into a higher-level `segment_add` implementation which assumes sorted indices as input.
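For reference, a rough sketch of what such a `segment_add` wrapper could look like (hypothetical name and signature, not an existing ATen/PyTorch op; the reduction is written as a plain loop only to show the contract of pre-sorted indices):

```python
import torch

def segment_add(src: torch.Tensor, sorted_index: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Sum rows of `src` that share the same value in the pre-sorted `sorted_index`.
    Hypothetical helper, sketched only to illustrate the suggested API."""
    counts = torch.bincount(sorted_index, minlength=num_segments)
    row_ptr = torch.zeros(num_segments + 1, dtype=torch.long)
    row_ptr[1:] = counts.cumsum(0)                      # CSR-style row pointers

    out = src.new_zeros(num_segments, src.size(1))
    for seg in range(num_segments):                     # segments are independent -> parallel-friendly
        beg, end = row_ptr[seg], row_ptr[seg + 1]
        if end > beg:
            out[seg] = src[beg:end].sum(0)
    return out

# Usage: the index must already be sorted, e.g. by sorting edge destinations.
src = torch.randn(8, 4)
idx, perm = torch.sort(torch.tensor([2, 0, 1, 2, 0, 1, 2, 2]))
print(segment_add(src[perm], idx, num_segments=3).shape)   # torch.Size([3, 4])
```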
/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.
ezyang left a comment:
just reviewed for the framework side of things, did not review algos
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: this PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR:
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack:

### Motivation of this PR

This PR targets improving the performance of `scatter_add` for GNN usage scenarios on PyG. Currently only the CPU path is optimized.

`Message Passing` is the major step in GNN learning: it exchanges/aggregates information between nodes. From a performance point of view, when the `EdgeIndex` is stored as `[2, num_edges]`, `scatter_reduce` is a major hotspot in the current PyTorch implementation.

To be more specific, in message passing `scatter_add` is used in a very similar way to `index_select`, except that the `self` tensor is written to while `index_select` only reads. The `index` tensor passed to `scatter_add` is therefore a tensor expanded on dim 0, which means that along all the remaining dims it holds the same value.

### Algorithm

For this case the current scatter implementation parallelizes over the inner dims, which leads to bad performance: non-contiguous memory access and no vectorization. This PR sorts the `index` to resolve the write conflicts that would otherwise occur when parallelizing directly over dim 0. The algorithm is equivalent to:

* convert the memory format from `COO` to `CSR`
* do an spmm-style reduce

### Perf improvement

The benchmark comes from https://github.com/pyg-team/pytorch_geometric/tree/master/examples: `python reddit.py`, which runs the SAGE model on the Reddit dataset.

CPU type: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz

`aten::scatter_add_` has been reduced from **37.797s** to **5.989s**.

cc @VitalyFedyunin @jgong5 @XiaobingSuper @sanchitintel @ashokei @jingxu10
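The operator breakdowns quoted above come from profiling PyG's `reddit.py`; they are not reproduced by the snippet below. For completeness, here is a minimal sketch of how such an operator-level CPU breakdown can be collected with `torch.profiler`, using a toy `scatter_add_` workload in place of the real model (all sizes are made up for illustration):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Toy workload standing in for the SAGE/Reddit run; sizes are arbitrary.
num_nodes, num_edges, feat_dim = 100_000, 1_000_000, 128
dst = torch.randint(0, num_nodes, (num_edges,))
index = dst.view(-1, 1).expand(-1, feat_dim)
messages = torch.randn(num_edges, feat_dim)
out = torch.zeros(num_nodes, feat_dim)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        out.scatter_add_(0, index, messages)

# Prints an operator table similar in shape to the breakdowns above.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
```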