Skip to content

Commit a61fd3b

Browse files
committed
Add example based on stripped down version of main.cpp from llama.cpp
1 parent da9b71c commit a61fd3b

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

examples/llama_cpp_main.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
"""Minimal text-generation loop using the low-level llama_cpp bindings.

A stripped-down port of main.cpp from llama.cpp: load a model, tokenize a
prompt, feed it to the model in batches, then sample one token at a time
until the prediction budget runs out or an end-of-sequence token appears.
"""
import multiprocessing

import llama_cpp

# Use every available core for model evaluation.
N_THREADS = multiprocessing.cpu_count()

prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"

lparams = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"models/ggml-alpaca-7b-q4.bin", lparams)

# Warm-up eval: determine the required inference memory per token.
tmp = [0, 1, 2, 3]
llama_cpp.llama_eval(ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, N_THREADS)

n_past = 0

# llama.cpp convention: prepend a space so the first word tokenizes correctly.
prompt = b" " + prompt

# Tokenize the prompt; the buffer is sized generously (one slot per byte + 1)
# and then trimmed to the actual token count returned.
embd_inp = (llama_cpp.llama_token * (len(prompt) + 1))()
n_of_tok = llama_cpp.llama_tokenize(ctx, prompt, embd_inp, len(embd_inp), True)
embd_inp = embd_inp[:n_of_tok]

n_ctx = llama_cpp.llama_n_ctx(ctx)

# Cap generation so prompt + output never exceeds the context window.
n_predict = 20
n_predict = min(n_predict, n_ctx - len(embd_inp))

input_consumed = 0
input_noecho = False

remaining_tokens = n_predict

embd = []
last_n_size = 64                     # repetition-penalty window
last_n_tokens = [0] * last_n_size
n_batch = 24                         # max prompt tokens evaluated per step

while remaining_tokens > 0:
    if embd:
        llama_cpp.llama_eval(
            ctx, (llama_cpp.c_int * len(embd))(*embd), len(embd), n_past, N_THREADS
        )

    n_past += len(embd)
    embd = []
    if len(embd_inp) <= input_consumed:
        # Prompt fully consumed: sample the next token, penalizing repeats
        # seen in the last_n_tokens window.
        token_id = llama_cpp.llama_sample_top_p_top_k(
            ctx,
            (llama_cpp.c_int * len(last_n_tokens))(*last_n_tokens),
            len(last_n_tokens),
            40,          # top_k
            0.8,         # top_p
            0.2,         # temperature
            1.0 / 0.85,  # repeat penalty
        )
        last_n_tokens = last_n_tokens[1:] + [token_id]
        embd.append(token_id)
        input_noecho = False
        remaining_tokens -= 1
    else:
        # Still feeding the prompt: move up to n_batch tokens into embd.
        while len(embd_inp) > input_consumed:
            embd.append(embd_inp[input_consumed])
            last_n_tokens = last_n_tokens[1:] + [embd_inp[input_consumed]]
            input_consumed += 1
            if len(embd) >= n_batch:
                break
    if not input_noecho:
        for token_id in embd:
            print(
                llama_cpp.llama_token_to_str(ctx, token_id).decode("utf-8"),
                end="",
                flush=True,
            )

    # Stop early if the model emitted an end-of-sequence token.
    if embd and embd[-1] == llama_cpp.llama_token_eos():
        break

print()

llama_cpp.llama_print_timings(ctx)

llama_cpp.llama_free(ctx)

0 commit comments

Comments
 (0)