Conversation


@mingfeima mingfeima commented Sep 1, 2022

Stack from ghstack:


cc @VitalyFedyunin @jgong5 @XiaobingSuper @sanchitintel @ashokei @jingxu10

Originally, `cpu/moments_utils.h` used the `at::native::utils` namespace. This file contains `Vectorized<>`; to have it vectorized properly for each arch, it needs to use an anonymous namespace or an inline namespace. Otherwise it gets linked to the scalar version of the code.

This PR fixes a vectorization issue in `RowwiseMoments`, which is used to calculate `mean` and `rstd` in norm layers. Benchmark data is attached below; generally fp32 gets a 2-3x speedup, and bf16 gets an even larger one.

This patch improves layer_norm (input size 32x128x1024) float32 inference:
* avx512 single socket: 2.1x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.439 ms; bf16: 2.479 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.210 ms; bf16: 0.770 ms
```
* avx512 single core: 3.2x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 6.308 ms; bf16: 39.765 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.661 ms; bf16: 12.267 ms
```
* avx2 single socket: 2.3x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 1.248 ms; bf16: 8.487 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.540 ms; bf16: 2.030 ms
```
* avx2 single core: 2.5x
```bash
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 10.792 ms; bf16: 66.366 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 4.349 ms; bf16: 19.252 ms
```

Some VTune profiling results from the original code are attached here to further illustrate the issue:

1. original bottlenecks
![master_bottleneck](https://user-images.githubusercontent.com/20233731/180125611-deed41b7-dd2e-4437-a7d9-6ad0096e5850.png)

We can see `RowwiseMomentsImpl<>` takes the majority of the runtime here.

2. Instruction level breakdown of `RowwiseMomentsImpl<>`
![rowwise_momentum_impl](https://user-images.githubusercontent.com/20233731/180125759-a3b48bc4-8e54-4219-92b4-defde5e86046.png)

We can see it is all **scalar** instructions here.

3. after the fix, the bottlenecks
![fixed_bottleneck](https://user-images.githubusercontent.com/20233731/180125880-8d08eb1b-af09-4f80-ae58-80215365d407.png)

getting better.

4. after the fix, Instruction level breakdown of `RowwiseMomentsImpl<>`
![fixed_rowwsie_momentum_impl](https://user-images.githubusercontent.com/20233731/180125989-b45db4ad-e6ed-460a-8d51-74fbeecf8b02.png)

Now it is all **vectorized** instructions.

[ghstack-poisoned]

facebook-github-bot commented Sep 1, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 4ba0629 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.


mingfeima commented Sep 1, 2022

Replacement of #81849.
Need to fix this performance regression first: pytorch/benchmark#1099


pytorch-bot bot commented Sep 8, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84404

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 229cf48:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


@mingfeima

Since pytorch/benchmark#1099 has been identified as a false alarm, shall we proceed with reviewing this PR again? @malfet, @frank-wei

@mingfeima mingfeima marked this pull request as ready for review September 21, 2022 05:58
@malfet malfet left a comment


Hmm, the inline namespace concept seems dangerous to me, as I'm not sure I understand how it guarantees that symbols from, say, the avx512 namespace will not get included from avx2-only code. Perhaps you just need to add a regular namespace and call `utils::CPU_CAPABILITY::RowwiseMoments`?


mingfeima commented Sep 23, 2022

> Hmm, the inline namespace concept seems dangerous to me, as I'm not sure I understand how it guarantees that symbols from, say, the avx512 namespace will not get included from avx2-only code. Perhaps you just need to add a regular namespace and call `utils::CPU_CAPABILITY::RowwiseMoments`?

Initially, all the CPU kernels under aten/src/ATen/native/cpu that require vectorization used anonymous namespaces, which make the functions static so they get linked to different assembly for scalar/avx2/avx512. For example, see the CatKernel here.

Later on, some kernels were changed to use an inline namespace, for example CopyKernel. Sure, this will also do the job, but honestly I'm not sure why it was introduced in the first place ...

@malfet Is it OK if I change this file back to anonymous namespaces? Right now, most of the CPU kernels are still written this way.

[Edit]: I have verified that both an inline namespace and an anonymous namespace can properly vectorize the code.

@mingfeima mingfeima requested a review from malfet September 23, 2022 02:40
@facebook-github-bot

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.


linux-foundation-easycla bot commented Oct 4, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

CaoE added a commit that referenced this pull request Nov 17, 2022
…and GroupNorm"

This PR is cherry-picked from #84404 ~ #81852.

[ghstack-poisoned]
@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Nov 28, 2022
@mingfeima mingfeima added the topic: not user facing topic category label Nov 28, 2022
@mingfeima mingfeima requested a review from jgong5 November 28, 2022 04:51
@jgong5 jgong5 left a comment


> Hmm, the inline namespace concept seems dangerous to me, as I'm not sure I understand how it guarantees that symbols from, say, the avx512 namespace will not get included from avx2-only code. Perhaps you just need to add a regular namespace and call `utils::CPU_CAPABILITY::RowwiseMoments`?

> Later on, some kernels were changed to use an inline namespace, for example CopyKernel. Sure, this will also do the job, but honestly I'm not sure why it was introduced in the first place ...

My understanding is that an inline namespace is preferred for functions defined in header files (e.g., moments_utils.h in this PR). With it, there won't be duplicated definitions in the source files that include it. For functions defined in source files, using an anonymous namespace should be fine, and most of the PyTorch source files seem to follow this. CopyKernel seems like an exception, since its functions (e.g., direct_copy_kernel) are also exposed directly in the header file and used by other kernels. I would suggest we use an inline namespace for moments_utils.h.

@mingfeima mingfeima requested a review from jgong5 November 29, 2022 05:19
@mingfeima

@jgong5 Updated, changed back to the inline namespace.

@mingfeima

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 30, 2022
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#84404
Approved by: https://github.com/jgong5
@facebook-github-bot facebook-github-bot deleted the gh/mingfeima/86/head branch June 8, 2023 18:01