MQA Implementation for 2B models #114
austinvhuang left a comment
Thanks very much! MQA is one of the most important pieces of low-hanging fruit to implement right now. Looks pretty good overall; have a look at the comment about avoiding branching.
Tagging @pculliton to check the model exporting + vocab size change and @jan-wassenberg on any perf suggestions.
jan-wassenberg left a comment
Nice, thank you :) Some small suggestions:
austinvhuang left a comment
This LGTM. If the performance looks as good or better (I'm curious how much), generation looks correct, and @jan-wassenberg LGTMs, we can probably move forward with merging to dev.
I tested the weights converted from gemma_pytorch (2b-it and 7b-it) and the generation looks fine.
jan-wassenberg left a comment
Very nice use of lambdas! Thanks for making the change.
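As context for the reviewers' remarks about lambdas and avoiding branching, here is a minimal illustrative sketch of the general idea. It is not the PR's actual code, and all identifiers and head counts here are hypothetical:

```cpp
// Illustrative sketch only, not the PR's code: a lambda maps each query
// head to its K/V head. With kKVHeads == 1 (MQA, as in the 2B model) every
// query head reads the same K/V slice; with kKVHeads == kHeads (MHA, as in
// the 7B model) it is the identity mapping, so the attention loop needs no
// if/else between the two model types.
#include <cstddef>

template <size_t kHeads, size_t kKVHeads>
void AttendSketch() {
  const auto kv_head = [](size_t head) -> size_t {
    return head % kKVHeads;  // 0 for MQA, `head` for MHA
  };
  for (size_t head = 0; head < kHeads; ++head) {
    const size_t kv = kv_head(head);
    // ... dot Q[head] with cached K[kv], softmax, weight cached V[kv] ...
    (void)kv;  // placeholder so the sketch compiles standalone
  }
}

// Example instantiations (head counts are illustrative):
template void AttendSketch<8, 1>();    // MQA-style
template void AttendSketch<16, 16>();  // MHA-style
```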
This PR implements Multi-Query Attention (MQA) for the 2B models and changes the vocabulary size to match gemma_pytorch (mentioned in #103). It works fine with weights converted from gemma_pytorch, but it makes the original gemma.cpp weights unusable.
It needs more testing, and I'll use it to test the fine-tuned weights.
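For context on why MQA is attractive for the 2B model, here is a rough back-of-the-envelope sketch of the KV-cache saving; the constants are illustrative assumptions, not values taken from this PR or from gemma.cpp:

```cpp
// Back-of-the-envelope sketch of the KV-cache saving from MQA. All
// constants are illustrative assumptions, not values read from this PR.
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kQueryHeads = 8;  // assumed 2B-style query head count
  constexpr size_t kKVHeads = 1;     // MQA: one K/V head shared by all queries
  constexpr size_t kHeadDim = 256;   // assumed head dimension
  constexpr size_t kSeqLen = 4096;   // example sequence length

  // Floats cached per token: K plus V for each K/V head.
  constexpr size_t kMQAPerToken = 2 * kKVHeads * kHeadDim;
  constexpr size_t kMHAPerToken = 2 * kQueryHeads * kHeadDim;

  std::printf("KV-cache floats at %zu tokens: MQA %zu vs. MHA %zu (%zux smaller)\n",
              kSeqLen, kSeqLen * kMQAPerToken, kSeqLen * kMHAPerToken,
              kMHAPerToken / kMQAPerToken);
  return 0;
}
```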