Fix logits_to_logprobs for 2-D and 3-D logits #1002

Merged: abetlen merged 3 commits into abetlen:main from kddubey:kddubey/fix-logits-to-logprobs on Dec 16, 2023

@kddubey (Contributor) commented Dec 12, 2023

The implementation in main (introduced in an earlier PR) only works for 1-D logits. It silently returns wrong values for 2-D or 3-D logits. The implementation in this PR works out-of-the-box for 1-D, 2-D, and 3-D logits. (3-D is possible in the future w/ batch inference and logits_all=True.) This feature might be useful b/c there are some places in the code where we can save time by vectorizing / not converting data to lists. I'll do that in a future PR.
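
To illustrate the kind of silent failure described here, the sketch below uses a hypothetical `naive_logprobs` function (not the actual code in main) that reduces over the whole array instead of the last axis. It raises no error for 2-D input, but the rows are normalized jointly rather than individually.

```python
import numpy as np


def naive_logprobs(logits):
    """Hypothetical 1-D-only log-softmax: reduces over ALL elements."""
    logits = np.asarray(logits, dtype=np.single)
    shifted = logits - np.max(logits)  # global max, no axis
    return shifted - np.log(np.sum(np.exp(shifted)))  # global sum, no axis


logits_1d = np.array([1.0, 2.0, 3.0], dtype=np.single)
logits_2d = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=np.single)

# 1-D: the probabilities sum to 1, as expected.
probs_1d_sum = np.exp(naive_logprobs(logits_1d)).sum()

# 2-D: each row should sum to 1, but the two identical rows share one
# normalizer, so each row sums to about 0.5. No exception is raised.
row_sums = np.exp(naive_logprobs(logits_2d)).sum(axis=-1)
print(probs_1d_sum, row_sums)
```

Because nothing fails loudly, a caller passing a batch of logits would only notice the problem by checking the row sums or comparing against a reference.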

The minimal and sufficient fix is to set axis=-1 in the np.max call, and set keepdims=True in the np.sum call. I decided to instead go with a more robust implementation. It's almost copy-pasted from scipy.special.log_softmax. I decided against adding scipy as a required dependency b/c it's not lightweight—the latest version is ~37 MB.
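
For reference, a log-softmax that reduces along one axis and keeps the reduced dimension can be sketched as follows. This is a minimal illustration in the spirit of scipy.special.log_softmax, not the code merged in this PR; the name `log_softmax_sketch` and its `axis` parameter are assumptions made for the example.

```python
import numpy as np


def log_softmax_sketch(logits, axis=-1):
    """Hypothetical axis-aware log-softmax for 1-D, 2-D, or 3-D logits."""
    logits = np.asarray(logits, dtype=np.single)
    # keepdims=True keeps the reduced axis so the result broadcasts
    # against the original array.
    maxs = np.max(logits, axis=axis, keepdims=True)
    # Avoid subtracting a non-finite maximum (e.g. a row of all -inf).
    maxs = np.where(np.isfinite(maxs), maxs, 0)
    shifted = logits - maxs
    # Suppress the warning NumPy emits for log(0).
    with np.errstate(divide="ignore"):
        log_norm = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
    return shifted - log_norm


rng = np.random.default_rng(0)
logits_1d = -rng.uniform(0, 60, size=3)
logits_2d = -rng.uniform(0, 60, size=(2, 3))
logits_3d = -rng.uniform(0, 60, size=(2, 4, 3))

# In every case the probabilities along the last axis should sum to ~1.
sums_1d = np.exp(log_softmax_sketch(logits_1d)).sum(axis=-1)
sums_2d = np.exp(log_softmax_sketch(logits_2d)).sum(axis=-1)
sums_3d = np.exp(log_softmax_sketch(logits_3d)).sum(axis=-1)
```

The same function handles all three shapes because the max and the sum are both taken along `axis` with `keepdims=True`, so no reshaping is needed before the subtraction.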

How has this been tested?

Script

  1. Install the new test dependency, scipy, which contains a correct implementation

    python -m pip install scipy
  2. Check out main

    git checkout main
  3. Run this script in main to verify that the current implementation is silently wrong for 2-D logits. The final assertion is expected to fail, even with the loose tolerance.

    from __future__ import annotations
    
    import numpy as np
    from scipy.special import log_softmax
    
    from llama_cpp import Llama
    
    atol = 1e-3  # intentionally set to be loose when testing the impl in main
    size = (2, 3)  # 2-D logits: 2 rows of 3 values each
    # Random non-positive logits, passed as a nested Python list
    logits: list = (
        (-np.random.uniform(low=0, high=60, size=size)).astype(np.single).tolist()
    )
    
    logprobs = Llama.logits_to_logprobs(logits)
    # Reference values: scipy's log-softmax over the last axis
    logprobs_correct = log_softmax(logits, axis=-1)
    assert np.allclose(logprobs, logprobs_correct, atol=atol)
  4. Check out this branch

    git checkout kddubey/fix-logits-to-logprobs
  5. Run the same script with atol=1e-6. No error should be raised.

New unit tests

pytest tests/test_llama.py -k test_logits_to_logprobs

test = [
"pytest>=7.4.0",
"httpx>=0.24.1",
"scipy>=1.10",

This is the oldest version compatible with numpy>=1.20.0

source: https://docs.scipy.org/doc/scipy/dev/toolchain.html#numpy

@kddubey changed the title from "Fix logits_to_logprobs" to "Fix logits_to_logprobs for 2-D and 3-D logits" on Dec 12, 2023
@abetlen (Owner) commented Dec 16, 2023

@kddubey thank you, yes that's a good idea wrt vectorizing the logits -> logprobs calculation

@abetlen merged commit 5a89446 into abetlen:main on Dec 16, 2023
@kddubey deleted the kddubey/fix-logits-to-logprobs branch on December 17, 2023 00:13