Make saved state more compact on-disk #1296

Merged: abetlen merged 2 commits into abetlen:main, Apr 17, 2024

Conversation
- Only store up to `n_tokens` logits instead of full `(n_ctx, n_vocab)`
sized array.
- Difference between ~350MB and ~1500MB for example prompt with ~300
tokens (makes sense lol)
- Auto-formatting changes
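The size reduction comes from slicing the logits array down to the rows that were actually evaluated before saving. A minimal sketch of the idea, using hypothetical `n_ctx`/`n_vocab`/`n_tokens` values (not taken from the PR) and float32 logits:

```python
import numpy as np

# Hypothetical sizes for illustration only (not the PR's actual model/config)
n_ctx, n_vocab, n_tokens = 2048, 151936, 300

full = np.zeros((n_ctx, n_vocab), dtype=np.float32)   # what was stored before
trimmed = full[:n_tokens]                             # only evaluated rows

print(full.nbytes // 2**20, "MB vs", trimmed.nbytes // 2**20, "MB")
# prints: 1187 MB vs 173 MB
```

On reload, the remaining `(n_ctx - n_tokens, n_vocab)` rows can simply be zero-filled, since they were never populated by an eval.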
Owner

Hey @tc-wolf, thanks for the contribution. First, do you mind reverting the formatting changes? Thank you. Next, yes, we should only store the

Contributor (Author)
I'll work on putting together a better reprex that I can share, but I was able to reload with:

```python
import ctypes

import numpy as np

import llama_cpp
from llama_cpp import Llama, llama_load_session_file


def reload_from_low_level_state_file(model: Llama, state_file: str):
    n_ctx = model.n_ctx()
    session_tokens = (llama_cpp.llama_token * n_ctx)()
    n_tok = ctypes.c_size_t()
    # Load context from file
    status = llama_load_session_file(
        model.ctx,
        state_file.encode("utf-8"),
        session_tokens,
        n_ctx,
        ctypes.byref(n_tok),
    )
    if status != 1:
        raise ValueError(f"Failed to load context from '{state_file}'")
    n_tok_int = int(n_tok.value)
    print(f"n_tok: {n_tok_int}")
    # `state` is the LlamaState saved earlier via `model.save_state()`
    assert (
        session_tokens[:n_tok_int] == state.input_ids[:n_tok_int].tolist()
    ), "Tokens should match"
    # Set various objects like in `load_state`
    n_vocab = model.n_vocab()
    logits_from_low_level = np.ctypeslib.as_array(
        model._ctx.get_logits(), (n_tok_int, n_vocab)
    )
    model.scores[:n_tok_int, :] = logits_from_low_level.copy()
    model.scores[n_tok_int:, :] = 0.0
    for i in range(0, n_tok_int):
        model.input_ids[i] = session_tokens[i]
    model.n_tokens = n_tok_int
```

and then predict with the same prompt, and got the same logits for the predicted tokens:

```python
# Loads and resets to state created by doing `save_state` after `model.eval(prompt_tokens)`
model = get_nekomata7b()
model.load_state(state)
reload_from_low_level_state_file(model, "nekomata-7b-ctx.bin")
out3 = model.create_chat_completion(
    [
        {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ],
    temperature=0.0,
    max_tokens=300,
    seed=1432,
)
out3
pred_logits3 = model.scores[n_tok_prompt:model.n_tokens, :]
assert (pred_logits == pred_logits3).all(), "All prediction logits should match"
```

This is without any logit bias and with no
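The key trick in the reload function is `np.ctypeslib.as_array`, which wraps a C pointer as a NumPy view without copying (hence the explicit `.copy()` before the buffer can be invalidated). A standalone sketch of that pattern on a toy buffer, not llama.cpp's actual logits:

```python
import ctypes

import numpy as np

# A small C float buffer standing in for the logits returned by get_logits()
buf = (ctypes.c_float * 6)(*range(6))
ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))

# Wrap the pointer as a (2, 3) array: this is a view, not a copy
arr = np.ctypeslib.as_array(ptr, shape=(2, 3))
arr[0, 0] = 42.0

print(buf[0])  # the underlying ctypes buffer sees the write
```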
xhedit pushed a commit to xhedit/llama-cpp-conv that referenced this pull request on Apr 30, 2024:
* State load/save changes
- Only store up to `n_tokens` logits instead of full `(n_ctx, n_vocab)`
sized array.
- Difference between ~350MB and ~1500MB for example prompt with ~300
tokens (makes sense lol)
- Auto-formatting changes
* Back out formatting changes
Was looking at making the llama context more compact so that I could store a cache for a large # of documents.

Changes

- Only store up to `n_tokens` logits instead of the full `(n_ctx, n_vocab)` sized array.

Q: Do we need to store `Llama.scores` when logits are already saved on the `llama_context` struct? This ~doubles the size of `LlamaState`, but I tried reloading directly using `get_logits` and the two were not the same (not sure why; think it could be repetition penalty applied to `Llama.scores`? There wasn't a `logits_bias` or `logits_processor` being applied AFAICT).
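On the open question above: one way `Llama.scores` and `get_logits` could diverge is if some sampling-time processing mutates the stored rows while the context's raw logits stay untouched. Purely as an illustration of that kind of mutation, here is the standard CTRL-style repetition penalty — this is illustrative, not code from llama-cpp-python:

```python
import numpy as np


def repetition_penalty(logits: np.ndarray, prev_tokens: list, penalty: float = 1.1) -> np.ndarray:
    """CTRL-style penalty: shrink positive logits, amplify negative ones,
    for tokens that already appeared. Illustrative only."""
    out = logits.copy()
    for t in set(prev_tokens):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out


logits = np.array([2.0, -1.0, 0.5])
print(repetition_penalty(logits, [0, 1]))
```

If anything in the sampling path wrote a penalized row like this back into `Llama.scores`, a later reload from `get_logits` would read the unpenalized values and the comparison would fail, which would be consistent with the mismatch observed.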