[Inductor] optimize scalar welford_reduce#162709
[Inductor] optimize scalar welford_reduce#162709jiayisunx wants to merge 12 commits intogh/jiayisunx/77/basefrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162709
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (5 Unrelated Failures)As of commit a0b880b with merge base 8cf0bdd ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
torch/csrc/inductor/cpp_prefix.h
Outdated
| } | ||
|
|
||
| template <typename T, uint64_t kChunkSize> | ||
| template <typename T, typename S, uint64_t kChunkSize> |
There was a problem hiding this comment.
Can we use IsVecType to help get the scalar type and avoid using a new typename S ?
There was a problem hiding this comment.
Done, thanks for your comment!
| # acc helper is not used for scalar welford_reduce | ||
| if reduction_type == "welford_reduce": | ||
| return not use_scalar | ||
| return True |
There was a problem hiding this comment.
Should we also determine this based on the size of welford_reduce?
There was a problem hiding this comment.
Thanks for your comment! WelfordHelper helps two things: 1. Save the reciprocal of weights to avoid redundant divisions. 2. Save the welford stack, which is used to combine welford reduction with cascade summation. Because the first one is beneficial for performance, so WelfordHelper is used by default. Furthermore, this PR only optimizes the scalar welford_reduce implementation; I don't intend to change the behavior of vec welford_reduce.
test/inductor/test_cpu_repro.py
Outdated
| actual = compiled_m(x) | ||
| self.assertEqual(expected, actual) | ||
|
|
||
| # test scalar welford_reduce |
There was a problem hiding this comment.
It's better to reuse the code of this test.
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
ghstack-source-id: 985346a Pull Request resolved: pytorch/pytorch#162709
**Summary:** Optimize scalar welford_reduce implementation, combining Welford algorithm with cascade sum to improve numerical stability. Specifically: 1. Use Welford algorithm to compute mean and variance. 2. Use cascade summation when computing sum over input for both mean and variance. **Example:** Take pytorch#141541 as an example: ``` import torch import torch.nn as nn torch.manual_seed(0) class Model(nn.Module): def __init__(self): super().__init__() self.gn = nn.GroupNorm(num_groups=32, num_channels=32) def forward(self, x): return self.gn(x) model = Model().eval() x = torch.randn(1, 32, 128, 128, 128) with torch.no_grad(): output = model(x) with torch._inductor.config.patch({"cpp.simdlen": 0}): c_model = torch.compile(model) c_output = c_model(x) print(torch.max(torch.abs(output - c_output))) print(torch.allclose(output, c_output, 1.3e-6, 1e-5)) ``` **logs** - before ``` tensor(0.0005) False ``` - After ``` tensor(1.4305e-06) True ``` **Generated code:** - before ``` cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], ''' #include <torch/csrc/inductor/cpp_prefix.h> extern "C" void kernel(float* in_out_ptr0, float* in_out_ptr1, const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr2) { auto out_ptr1 = in_out_ptr0; auto out_ptr0 = in_out_ptr1; { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L)) { { Welford<float> tmp_acc0 = Welford<float>(); Welford<float> tmp_acc0_arr[4]; for (int i = 0; i < 4; i++) { tmp_acc0_arr[i] = Welford<float>(); } #pragma omp parallel num_threads(4) { int tid = omp_get_thread_num(); Welford<float> tmp_acc0_local = Welford<float>(); #pragma omp for for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L)) { { { auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)]; tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0); } } } tmp_acc0_arr[tid] = tmp_acc0_local; } for (int tid = 0; tid < 4; tid++) { tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]); } in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean; in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2; } } } { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L)) { { { auto tmp0 = out_ptr1[static_cast<int64_t>(x0)]; auto tmp6 = in_ptr1[static_cast<int64_t>(x0)]; auto tmp8 = out_ptr0[static_cast<int64_t>(x0)]; auto tmp11 = in_ptr2[static_cast<int64_t>(x0)]; auto tmp1 = static_cast<float>(2097152.0); auto tmp2 = tmp0 / tmp1; auto tmp3 = static_cast<float>(1e-05); auto tmp4 = float(tmp2 + tmp3); auto tmp5 = 1 / std::sqrt(tmp4); auto tmp7 = float(tmp5 * tmp6); auto tmp9 = decltype(tmp8)(-tmp8); auto tmp10 = float(tmp9 * tmp7); auto tmp12 = float(tmp10 + tmp11); in_out_ptr0[static_cast<int64_t>(x0)] = tmp7; in_out_ptr1[static_cast<int64_t>(x0)] = tmp12; } } } } #pragma omp parallel num_threads(4) { int tid = omp_get_thread_num(); { #pragma omp for for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L)) { #pragma GCC ivdep for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L)) { { { auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)]; auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)]; auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)]; auto tmp2 = float(tmp0 * tmp1); auto tmp4 = float(tmp2 + tmp3); out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4; } } } } } } } ''') async_compile.wait(globals()) del async_compile class Runner: def __init__(self, partitions): self.partitions = partitions def recursively_apply_fns(self, fns): new_callables = [] for fn, c in zip(fns, self.partitions): new_callables.append(fn(c)) self.partitions = new_callables def call(self, args): arg0_1, arg1_1, arg2_1 = args args.clear() assert_size_stride(arg0_1, (32, ), (1, )) assert_size_stride(arg1_1, (32, ), (1, )) assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1)) buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32) buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32) buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1 # reuse buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0 # reuse buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32) # [Provenance debug handles] cpp_fused_native_group_norm_0:1 cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5) del arg0_1 del arg1_1 del arg2_1 return (buf5, ) ``` - After ``` cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], ''' #include <torch/csrc/inductor/cpp_prefix.h> extern "C" void kernel(float* in_out_ptr0, float* in_out_ptr1, const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr2) { auto out_ptr1 = in_out_ptr0; auto out_ptr0 = in_out_ptr1; { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L)) { { Welford<float> tmp_acc0 = Welford<float>(); Welford<float> tmp_acc0_arr[4]; for (int i = 0; i < 4; i++) { tmp_acc0_arr[i] = Welford<float>(); } #pragma omp parallel num_threads(4) { int tid = omp_get_thread_num(); WelfordHelper<float, float, 4096> scalar_welford_helper0(static_cast<int64_t>(524288L)); Welford<float> tmp_acc0_local = Welford<float>(); #pragma omp for for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L)) { { { auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)]; tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0, &scalar_welford_helper0); } } } tmp_acc0_local = welford_combine(tmp_acc0_local, &scalar_welford_helper0); tmp_acc0_arr[tid] = tmp_acc0_local; } for (int tid = 0; tid < 4; tid++) { tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]); } in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean; in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2; } } } { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L)) { { { auto tmp0 = out_ptr1[static_cast<int64_t>(x0)]; auto tmp6 = in_ptr1[static_cast<int64_t>(x0)]; auto tmp8 = out_ptr0[static_cast<int64_t>(x0)]; auto tmp11 = in_ptr2[static_cast<int64_t>(x0)]; auto tmp1 = static_cast<float>(2097152.0); auto tmp2 = tmp0 / tmp1; auto tmp3 = static_cast<float>(1e-05); auto tmp4 = float(tmp2 + tmp3); auto tmp5 = 1 / std::sqrt(tmp4); auto tmp7 = float(tmp5 * tmp6); auto tmp9 = decltype(tmp8)(-tmp8); auto tmp10 = float(tmp9 * tmp7); auto tmp12 = float(tmp10 + tmp11); in_out_ptr0[static_cast<int64_t>(x0)] = tmp7; in_out_ptr1[static_cast<int64_t>(x0)] = tmp12; } } } } #pragma omp parallel num_threads(4) { int tid = omp_get_thread_num(); { #pragma omp for for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L)) { #pragma GCC ivdep for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L)) { { { auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)]; auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)]; auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)]; auto tmp2 = float(tmp0 * tmp1); auto tmp4 = float(tmp2 + tmp3); out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4; } } } } } } } ''') async_compile.wait(globals()) del async_compile class Runner: def __init__(self, partitions): self.partitions = partitions def recursively_apply_fns(self, fns): new_callables = [] for fn, c in zip(fns, self.partitions): new_callables.append(fn(c)) self.partitions = new_callables def call(self, args): arg0_1, arg1_1, arg2_1 = args args.clear() assert_size_stride(arg0_1, (32, ), (1, )) assert_size_stride(arg1_1, (32, ), (1, )) assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1)) buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32) buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32) buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1 # reuse buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0 # reuse buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32) # [Provenance debug handles] cpp_fused_native_group_norm_0:1 cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5) del arg0_1 del arg1_1 del arg2_1 return (buf5, ) ``` Pull Request resolved: pytorch#162709 Approved by: https://github.com/CaoE, https://github.com/jansel
Stack from ghstack (oldest at bottom):
Summary:
Optimize scalar welford_reduce implementation, combining Welford algorithm with cascade sum to improve numerical stability. Specifically:
Example:
Take #141541 as an example:
logs
Generated code:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos