Make saved state more compact on-disk#1296

Merged
abetlen merged 2 commits into abetlen:main from tc-wolf:compact_save_state
Apr 17, 2024

Conversation

@tc-wolf
Contributor

@tc-wolf tc-wolf commented Mar 21, 2024

I was looking at making the saved llama context more compact so that I could store a cache for a large number of documents.

Changes

  • Only store up to n_tokens logits instead of full (n_ctx, n_vocab) sized array.
    • For an example prompt of ~300 tokens with a context size of 4096, this is the difference between ~350MB and ~1500MB on disk.
  • Auto-formatting changes
    • Can back these out if not desired

Q: Do we need to store Llama.scores when logits are already saved on the llama_context struct? Storing them roughly doubles the size of LlamaState, but when I tried reloading directly using get_logits, the two were not the same. I'm not sure why; it could be a repetition penalty applied to Llama.scores, though there wasn't a logit_bias or logits_processor being applied AFAICT.
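
To make the size difference concrete, here is a back-of-the-envelope sketch (my illustration, not from the PR; `n_vocab=32000` and float32 are assumed values, so the figures won't exactly match the ~350MB/~1500MB quoted above):

```python
import numpy as np

# Assumed illustrative values (not the exact model from this PR):
n_ctx, n_vocab, n_tokens = 4096, 32000, 300

# Full (n_ctx, n_vocab) float32 scores array vs. only the first n_tokens rows
full_bytes = np.zeros((n_ctx, n_vocab), dtype=np.float32).nbytes
compact_bytes = np.zeros((n_tokens, n_vocab), dtype=np.float32).nbytes

print(f"full:    {full_bytes / 1e6:.0f} MB")     # ~524 MB
print(f"compact: {compact_bytes / 1e6:.0f} MB")  # ~38 MB
```

The ratio is the point: the stored logits scale with the number of evaluated tokens rather than with the full context length.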

- Only store up to `n_tokens` logits instead of full `(n_ctx, n_vocab)`
  sized array.
  - Difference between ~350MB and ~1500MB for example prompt with ~300
    tokens (makes sense lol)
- Auto-formatting changes
@abetlen
Owner

abetlen commented Mar 26, 2024

Hey @tc-wolf, thanks for the contribution. First, do you mind reverting the formatting changes? Thank you.

Next, yes, we should only store the (n_tokens, n_vocab) sized logits array. However, AFAIK it's not true that this full array is stored in the llama_context; the llama_context only stores the logits for the last batch that was processed.
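
A toy sketch of that distinction (pure Python with hypothetical names, no llama.cpp involved): a context-level buffer holds only the most recent batch's logits, while a wrapper-level scores array accumulates one row per evaluated token.

```python
import numpy as np

n_ctx, n_vocab = 8, 4
rng = np.random.default_rng(0)

scores = np.zeros((n_ctx, n_vocab), dtype=np.float32)  # wrapper-level, like Llama.scores
ctx_logits = None                                      # context-level: last batch only
n_tokens = 0

def eval_batch(batch_size: int):
    """Simulate evaluating one batch: the context keeps only this batch's logits,
    while the wrapper copies them into the cumulative scores array."""
    global ctx_logits, n_tokens
    ctx_logits = rng.standard_normal((batch_size, n_vocab)).astype(np.float32)
    scores[n_tokens:n_tokens + batch_size] = ctx_logits
    n_tokens += batch_size

eval_batch(3)
eval_batch(2)

# The context-level logits cover only the last batch (2 rows),
# while `scores` has rows for all 5 evaluated tokens.
print(ctx_logits.shape)         # (2, 4)
print(scores[:n_tokens].shape)  # (5, 4)
```

So after a multi-batch prompt eval, reading only the context's logits recovers just the last batch, which is why the wrapper keeps its own scores array.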

@tc-wolf
Contributor Author

tc-wolf commented Mar 26, 2024

I'll work on putting together a better reprex that I can share, but I was able to reload with:

import llama_cpp
import numpy as np
from llama_cpp import Llama


def reload_from_low_level_state_file(model: Llama, state_file: str):
    from llama_cpp import llama_load_session_file

    n_ctx = model.n_ctx()
    session_tokens = (llama_cpp.llama_token * n_ctx)()
    n_tok = llama_cpp.ctypes.c_size_t()

    # Load context from file
    status = llama_load_session_file(
        model.ctx,
        state_file.encode("utf-8"),
        session_tokens,
        n_ctx,
        llama_cpp.ctypes.byref(n_tok),
    )

    if status != 1:
        raise ValueError(f"Failed to load context from '{state_file}'")

    n_tok_int = int(n_tok.value)
    print(f"n_tok: {n_tok_int}")

    # `state` is the LlamaState captured earlier via `model.save_state()`
    # (taken from the enclosing scope); sanity-check the loaded tokens against it.
    assert (
        session_tokens[:n_tok_int] == state.input_ids[:n_tok_int].tolist()
    ), "Tokens should match"

    # Set various objects like in `load_state`
    n_vocab = model.n_vocab()
    logits_from_low_level = np.ctypeslib.as_array(
        model._ctx.get_logits(), (n_tok_int, n_vocab)
    )

    # Copy the logits for the evaluated tokens and zero out the remainder
    model.scores[:n_tok_int, :] = logits_from_low_level.copy()
    model.scores[n_tok_int:, :] = 0.0

    for i in range(n_tok_int):
        model.input_ids[i] = session_tokens[i]

    model.n_tokens = n_tok_int

and then predicted with the same prompt and got the same logits for the predicted tokens:

# Loads and resets to state created by doing `save_state` after
# `model.eval(prompt_tokens)`. `get_nekomata7b`, `state`, `prompt`,
# `n_tok_prompt`, and `pred_logits` were defined earlier in the session.
model = get_nekomata7b()
model.load_state(state)

reload_from_low_level_state_file(model, "nekomata-7b-ctx.bin")
out3 = model.create_chat_completion(
    [
        {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ],
    temperature=0.0,
    max_tokens=300,
    seed=1432,
)
out3
pred_logits3 = model.scores[n_tok_prompt:model.n_tokens, :]
assert (pred_logits == pred_logits3).all(), "All prediction logits should match"

This is without any logit bias and with no logit_processors.
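
One caveat on the exact-equality assert above (my addition, not part of the PR): bitwise equality only holds when both paths produce identical float32 values; if anything in either path differs numerically, a tolerance-based comparison is more robust:

```python
import numpy as np

a = np.array([0.1]) + np.array([0.2])
b = np.array([0.3])

# Exact equality can fail on floats even when the values are "the same"
print((a == b).all())     # False: 0.1 + 0.2 != 0.3 exactly in float64
print(np.allclose(a, b))  # True within default tolerances
```

Here the exact check passing is actually useful evidence that both code paths read back the very same buffer.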

@abetlen abetlen merged commit 4924455 into abetlen:main Apr 17, 2024
@tc-wolf tc-wolf deleted the compact_save_state branch April 17, 2024 19:26
xhedit pushed a commit to xhedit/llama-cpp-conv that referenced this pull request Apr 30, 2024
* State load/save changes

- Only store up to `n_tokens` logits instead of full `(n_ctx, n_vocab)`
  sized array.
  - Difference between ~350MB and ~1500MB for example prompt with ~300
    tokens (makes sense lol)
- Auto-formatting changes

* Back out formatting changes