@@ -47,7 +47,6 @@ def test_llama_cpp_tokenization():
4747@pytest .fixture
4848def mock_llama (monkeypatch ):
4949 def setup_mock (llama : llama_cpp .Llama , output_text : str ):
50- llama .reset ()
5150 n_vocab = llama .n_vocab ()
5251 output_tokens = llama .tokenize (
5352 output_text .encode ("utf-8" ), add_bos = True , special = True
@@ -59,28 +58,41 @@ def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
5958 nonlocal n
6059 nonlocal last_n_tokens
6160 # Test some basic invariants of this mocking technique
62- assert ctx == llama ._ctx .ctx
63- assert llama .n_tokens == n
64- assert batch .n_tokens > 0
65- n += batch .n_tokens
61+ assert ctx == llama ._ctx .ctx , "context does not match mock_llama"
62+ assert batch .n_tokens > 0 , "no tokens in batch"
63+ assert all (
64+ batch .n_seq_id [i ] == 1 for i in range (batch .n_tokens )
65+ ), "n_seq >1 not supported by mock_llama"
66+ assert all (
67+ batch .seq_id [i ][0 ] == 0 for i in range (batch .n_tokens )
68+ ), "n_seq >1 not supported by mock_llama"
69+ assert batch .logits [
70+ batch .n_tokens - 1
71+ ], "logits not allocated for last token"
72+ # Update the mock context state
73+ n = max (batch .pos [i ] for i in range (batch .n_tokens )) + 1
6674 last_n_tokens = batch .n_tokens
6775 return 0
6876
6977 def mock_get_logits (* args , ** kwargs ):
70- nonlocal last_n_tokens
71- size = n_vocab * last_n_tokens
72- return (llama_cpp .c_float * size )()
73-
74- def mock_sample (* args , ** kwargs ):
7578 nonlocal n
76- if n < len (output_tokens ):
77- return output_tokens [n ]
78- else :
79- return llama .token_eos ()
79+ nonlocal last_n_tokens
80+ assert n > 0 , "mock_llama_decode not called"
81+ assert last_n_tokens > 0 , "mock_llama_decode not called"
82+ logits = (llama_cpp .c_float * (last_n_tokens * n_vocab ))(- 100.0 )
83+ for logits_idx , output_idx in enumerate (
84+ range (n - last_n_tokens + 1 , n + 1 )
85+ ):
86+ if output_idx < len (output_tokens ):
87+ logits [
88+ logits_idx * last_n_tokens + output_tokens [output_idx ]
89+ ] = 100.0
90+ else :
91+ logits [logits_idx * last_n_tokens + llama .token_eos ()] = 100.0
92+ return logits
8093
8194 monkeypatch .setattr ("llama_cpp.llama_cpp.llama_decode" , mock_decode )
8295 monkeypatch .setattr ("llama_cpp.llama_cpp.llama_get_logits" , mock_get_logits )
83- monkeypatch .setattr ("llama_cpp.llama_cpp.llama_sample_token" , mock_sample )
8496
8597 return setup_mock
8698
0 commit comments