Factor out deinterleaving of bf16 vectors for MatVecs. #166

copybara-service[bot] merged 9 commits into google:dev
Conversation
jan-wassenberg left a comment

Nice, great to see this coming together :) Thanks for sending the PR. Some suggestions:
gemma/ops.h (outdated):

    const hn::ScalableTag<float> df;
    ...
    const auto vec_dequant = hwy::AllocateAligned<float>(kInner);
Allocation can be quite slow, let's move this into gemma.cc's Activations. That would require plumbing through an extra tmp arg here, and the std::array storage should probably be the largest per-call size * max number of threads (say 128 or 256). Would you prefer if I made this change?
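A minimal sketch of the suggested change, assuming hypothetical sizes and the even_odd name mentioned later in the thread (not the actual gemma.cc declarations):

```cpp
#include <cstddef>

#include "hwy/aligned_allocator.h"

// Assumed bounds: largest per-call kInner times the suggested maximum
// thread count (128 or 256, per the comment above). Both hypothetical.
constexpr size_t kMaxInner = 2048;
constexpr size_t kMaxThreads = 128;

struct Activations {
  // One float strip per thread, allocated once up front so the hot
  // MatVec path never calls hwy::AllocateAligned itself.
  hwy::AlignedFreeUniquePtr<float[]> even_odd =
      hwy::AllocateAligned<float>(kMaxThreads * kMaxInner);
};

// A MatVec running on worker `tid` would then receive a tmp arg like:
//   float* HWY_RESTRICT tmp = activations.even_odd.get() + tid * kMaxInner;
```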
@jan-wassenberg Sure. Thanks for the help!
@jan-wassenberg Thank you for reviewing! I'll branch on native BF16 support and clean up those near-duplicate MatVecAdd implementations, then turn off this PR's draft bit.

@jan-wassenberg One more question: what's the best way to check that the target doesn't have native bf16 dot-product support (e.g., AVX512_BF16)? You previously pointed me at a Highway PR, but it looks like Copybara scrubbed the branch when the PR was dropped.
We can check …
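A sketch of one such guard, assuming Highway's HWY_NATIVE_DOT_BF16 macro, which Highway defines per target when native bf16 dot support exists:

```cpp
#include "hwy/highway.h"

// Defined by Highway only on targets with native bf16 dot/widening
// multiplies (e.g. AVX512_BF16-class hardware); elsewhere we take the
// deinterleave-to-f32 fallback.
#ifdef HWY_NATIVE_DOT_BF16
constexpr bool kNativeBF16 = true;
#else
constexpr bool kNativeBF16 = false;
#endif
```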
gemma/ops.h (outdated):

    // vector to even-odd layout.
    template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
              typename VecT, typename AddT,
              std::enable_if_t<
Consider replacing with HWY_IF_SAME2(VecT, float, hwy::bfloat16_t).
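Applied to the signature quoted above, the suggestion would read roughly like this (a sketch; the function name and parameters are assumed, not quoted from the PR):

```cpp
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
          typename VecT, typename AddT,
          // Replaces the std::enable_if_t clause: per the suggestion, this
          // helper accepts VecT being either float or hwy::bfloat16_t.
          HWY_IF_SAME2(VecT, float, hwy::bfloat16_t)>
HWY_NOINLINE void MatVecAdd(const ArrayT& mat, const VecT* vec,
                            const AddT* add, float* out);
```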
Also inline ProjQ and ProjKV lambdas, add missing includes/deps for ops_test.

PiperOrigin-RevId: 629460608
@jan-wassenberg Done. Native bf16 checks added. Additionally, 59ebecc fixes a bug I introduced in 6a78a23: that commit affected overload resolution such that the specialization was never called. That's now fixed by moving the bulk of MatVecAdd into detail::MatVecAddInner and selecting between the even-odd and linear layouts inside an if constexpr, which ensures the choice is downstream of MatVecAdd's type inference.
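A sketch of that structure, under assumed and simplified signatures:

```cpp
#include "hwy/base.h"

namespace detail {

// Single inner entry point: the layout is chosen inside `if constexpr`,
// i.e. only after VecT has been deduced, so the even-odd path can no
// longer be skipped by overload resolution.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
          typename VecT>
void MatVecAddInner(const ArrayT& mat, const VecT* vec, float* out) {
  if constexpr (hwy::IsSame<VecT, hwy::bfloat16_t>()) {
    // bf16 input: deinterleave into f32 even/odd strips, then accumulate.
  } else {
    // float input: use the linear layout directly.
  }
}

}  // namespace detail
```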
I see even_odd storage is merged to dev. I'll merge.
jan-wassenberg left a comment
Nice, looks good to me, thanks for updating! Can you give it a quick sanity check with sfp weights as well (e.g. those prefixed 1.1 on Kaggle) to make sure that still works?
Already did. Works great.

Thanks for confirming :D

Internal CI caught some unused vars. Please fix :)
samkaufman left a comment
Oops. Hopefully that sorts it.
Remove extra Dot() overload. MatVecAdd always adds; use MatVecT<kAdd> if conditional. Remove unused MatVecAddLoop and MatVecLoop. No longer tsan-verify even_odd.

PiperOrigin-RevId: 631377279
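The MatVecT<kAdd> consolidation this commit describes presumably has roughly this shape (a sketch with simplified, assumed signatures):

```cpp
// One body, parameterized on whether the bias vector is fused in.
template <bool kAdd, typename VecT>
void MatVecT(const float* mat, const VecT* vec, const float* add,
             float* out) {
  // ...shared matvec body; only this step differs per output row r:
  // if constexpr (kAdd) out[r] += add[r];
}

// Thin wrappers replace the near-duplicate implementations:
template <typename VecT>
void MatVecAdd(const float* mat, const VecT* vec, const float* add,
               float* out) {
  MatVecT</*kAdd=*/true>(mat, vec, add, out);
}

template <typename VecT>
void MatVec(const float* mat, const VecT* vec, float* out) {
  MatVecT</*kAdd=*/false>(mat, vec, /*add=*/nullptr, out);
}
```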
Disable it for float32 because there is not enough benefit.

PiperOrigin-RevId: 631788326
This specializes bf16-f32 and bf16-bf16 vector-matrix multiplications to first convert bf16 vectors into f32 buffers of vector-length strips of even- and odd-indexed values.
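The conversion step might look roughly like the following sketch built on Highway's PromoteEvenTo/PromoteOddTo; the helper name and exact shape here are assumptions, not the PR's code:

```cpp
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Widen a bf16 vector of `inner` elements (assumed to be a multiple of
// two f32 vectors) into f32 storage laid out as alternating
// vector-length strips of even- and odd-indexed lanes.
HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec,
                             size_t inner, float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> df;
  const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
  const size_t NF = hn::Lanes(df);
  for (size_t i = 0; i < inner; i += 2 * NF) {
    const auto vbf = hn::LoadU(dbf, vec + i);                // 2*NF bf16 lanes
    hn::Store(hn::PromoteEvenTo(df, vbf), df, out + i);      // even strip
    hn::Store(hn::PromoteOddTo(df, vbf), df, out + i + NF);  // odd strip
  }
}
```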
The 2B, bf16 model running on my Zen 1 machine sees ~10% throughput improvements to single-threaded prefill, single-threaded generation, and multi-threaded prefill, but only a marginal improvement to multi-threaded generation throughput.
This PR does not implement support for SFP.
TODOs: