
Add Adam optimizer. #212

Merged
copybara-service[bot] merged 1 commit into google:dev from szabadka:adam2
Jun 7, 2024

@szabadka (Collaborator) commented Jun 6, 2024

Drive-by: Fix compilation errors and tests for backprop functions.
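
For readers unfamiliar with the optimizer named in the title, here is a minimal sketch of one Adam update step in plain C++. It is not the code added by this PR: `AdamConfig`, `AdamState`, and `AdamStep` are hypothetical names, and the hyperparameter defaults are the commonly cited ones rather than values taken from this repository.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical hyperparameter bundle; defaults are the commonly cited ones.
struct AdamConfig {
  float lr = 1e-3f;
  float beta1 = 0.9f;
  float beta2 = 0.999f;
  float eps = 1e-8f;
};

// Hypothetical per-parameter state: first and second moment estimates.
struct AdamState {
  std::vector<float> m;
  std::vector<float> v;
  int step = 0;
};

// Applies one Adam update to `params` in place, given gradients `grad`.
void AdamStep(const std::vector<float>& grad, const AdamConfig& cfg,
              AdamState& state, std::vector<float>& params) {
  assert(grad.size() == params.size());
  if (state.m.empty()) {
    state.m.assign(params.size(), 0.0f);
    state.v.assign(params.size(), 0.0f);
  }
  ++state.step;
  // Bias-correction factors for the zero-initialized moment estimates.
  const float c1 = 1.0f - static_cast<float>(std::pow(cfg.beta1, state.step));
  const float c2 = 1.0f - static_cast<float>(std::pow(cfg.beta2, state.step));
  for (std::size_t i = 0; i < params.size(); ++i) {
    state.m[i] = cfg.beta1 * state.m[i] + (1.0f - cfg.beta1) * grad[i];
    state.v[i] = cfg.beta2 * state.v[i] + (1.0f - cfg.beta2) * grad[i] * grad[i];
    const float m_hat = state.m[i] / c1;
    const float v_hat = state.v[i] / c2;
    params[i] -= cfg.lr * m_hat / (std::sqrt(v_hat) + cfg.eps);
  }
}
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so each parameter with a nonzero gradient moves by roughly `lr`.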

@szabadka requested a review from jan-wassenberg on June 6, 2024 16:29
gemma/gemma.cc Outdated
Member

Isn't the default ctor (in the header) enough to get us a null impl_?

Collaborator Author

There was a compilation error because the unique_ptr did not know the size of the forward-declared class, so I had to add something to the .cc file.

Member

Oh, I see. You are right, all the special functions indeed need to be in the .cc after the definition of Impl. I was confused by the body of the ctor here - that is unnecessary, we can just write GemmaTokenizer::GemmaTokenizer() = default. I'll add to my TODO.
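
The point about special member functions can be sketched in a single file. The class below is a hypothetical `Tokenizer`, not the actual GemmaTokenizer: the header part only forward-declares `Impl`, and every special member that may need to destroy the `std::unique_ptr<Impl>` is declared there but defaulted only after `Impl` is complete, as it would be in the .cc file.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// ---- What would live in the header (hypothetical class) ----
class Tokenizer {
 public:
  Tokenizer();  // Declared only; defining it here would need a complete Impl.
  explicit Tokenizer(std::string name);
  ~Tokenizer();
  Tokenizer(Tokenizer&&) noexcept;
  Tokenizer& operator=(Tokenizer&&) noexcept;

  // Checking for null does not require Impl to be complete.
  bool HasImpl() const { return impl_ != nullptr; }
  const std::string& Name() const;

 private:
  class Impl;  // Forward declaration: incomplete type at this point.
  std::unique_ptr<Impl> impl_;
};

// ---- What would live in the .cc file ----
class Tokenizer::Impl {
 public:
  explicit Impl(std::string name) : name_(std::move(name)) {}
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

// Impl is complete here, so the defaulted members can destroy impl_.
Tokenizer::Tokenizer() = default;  // Leaves impl_ null.
Tokenizer::Tokenizer(std::string name)
    : impl_(std::make_unique<Impl>(std::move(name))) {}
Tokenizer::~Tokenizer() = default;
Tokenizer::Tokenizer(Tokenizer&&) noexcept = default;
Tokenizer& Tokenizer::operator=(Tokenizer&&) noexcept = default;

const std::string& Tokenizer::Name() const {
  assert(impl_ != nullptr);
  return impl_->name();
}
```

Writing `= default` out of line keeps the constructor body empty while still placing its definition where `Impl` is a complete type.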

gemma/gemma.cc Outdated
Member

hm, I'm not sure we actually want to encourage f32 weights. For training, wouldn't it make sense to use bf16 weights? Those are considered compressed, though we'd have to build with -DGEMMA_WEIGHT_T=hwy::bfloat16_t. That should be faster, and let us only have a single function here. And maybe we could even get rid of kWeightsAreCompressed?

Note that 'compressed' can also mean f32. It would be nice to get rid of the duplicated non-compressed code now that we have the separate compress_weights binary.

Collaborator Author

I removed these for now, since training works if I change kWeightsAreCompressed to false.

gemma/gemma.cc Outdated
Member

It would be nice to avoid this duplication. It seems that you want to use f32 (if not bf16, see above) weights for the backprop. What prevents us from doing that with kWeightsAreCompressed=true, and setting GEMMA_WEIGHT_T to float?
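
As a rough illustration of the idea, a single code path can be templated on the weight type and the type chosen at build time. This sketch is an assumption, not the repository's code: `DEMO_WEIGHT_T` is a hypothetical macro standing in for a switch such as GEMMA_WEIGHT_T, and `Dot` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical build-time switch (e.g. -DDEMO_WEIGHT_T=double).
// It only stands in for a macro like GEMMA_WEIGHT_T; it is not the real one.
#ifndef DEMO_WEIGHT_T
#define DEMO_WEIGHT_T float
#endif

using WeightT = DEMO_WEIGHT_T;

// One implementation for every weight type: each weight is converted to
// float on load, so no separate "uncompressed" copy of the function is needed.
template <typename T>
float Dot(const std::vector<T>& weights, const std::vector<float>& x) {
  assert(weights.size() == x.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    sum += static_cast<float>(weights[i]) * x[i];
  }
  return sum;
}

// Call sites use the configured weight type without duplicating logic.
inline float DotWithConfiguredWeights(const std::vector<WeightT>& weights,
                                      const std::vector<float>& x) {
  return Dot<WeightT>(weights, x);
}
```

With this shape, switching the weight type is a build flag change rather than a second function.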

Collaborator Author

For now I reverted this part, but I am still keeping kWeightsAreCompressed.

@jan-wassenberg added the copybara-import label (Trigger Copybara for merging pull requests) on Jun 7, 2024
@copybara-service (bot) merged commit f7ac709 into google:dev on Jun 7, 2024
@szabadka deleted the adam2 branch on June 7, 2024 09:33

2 participants