|
# Low-level llama.cpp example: tokenize a prompt and generate a completion
# using the raw C-API bindings exposed by llama_cpp.
#
# Fix: `import llama_cpp` was imported twice; imports are now deduplicated
# and grouped stdlib-first per PEP 8.
import multiprocessing

import llama_cpp

# Use one worker thread per CPU core for llama.cpp evaluation.
N_THREADS = multiprocessing.cpu_count()
# Alpaca-style instruction prompt; the C API works on bytes, not str.
prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"

# Build a context with default parameters and load the quantized model.
lparams = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"models/ggml-alpaca-7b-q4.bin", lparams)

# Warm-up eval on a few dummy tokens so llama.cpp can determine the
# required inference memory per token.
tmp = [0, 1, 2, 3]
llama_cpp.llama_eval(ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, N_THREADS)

n_past = 0

# llama.cpp expects a leading space in front of the prompt text.
prompt = b" " + prompt

# Tokenize the prompt into a C array sized generously (one slot per byte
# plus one for BOS), then trim it to the actual token count returned.
embd_inp = (llama_cpp.llama_token * (len(prompt) + 1))()
n_of_tok = llama_cpp.llama_tokenize(ctx, prompt, embd_inp, len(embd_inp), True)
embd_inp = embd_inp[:n_of_tok]

n_ctx = llama_cpp.llama_n_ctx(ctx)

# Cap generation so prompt + completion never exceed the context window.
n_predict = min(20, n_ctx - len(embd_inp))

# Generation state: how much of the prompt has been fed to the model,
# whether freshly queued tokens should be echoed, and the batch buffer.
input_consumed = 0
input_noecho = False

remaining_tokens = n_predict

embd = []
# Sliding window of recent tokens used for the repeat penalty.
last_n_size = 64
last_n_tokens = [0] * last_n_size
n_batch = 24
| 40 | + |
# Main generation loop: feed pending tokens, then either keep consuming
# the prompt (in batches of n_batch) or sample the next token.
while remaining_tokens > 0:
    # Evaluate whatever tokens are queued from the previous iteration.
    if embd:
        llama_cpp.llama_eval(
            ctx, (llama_cpp.c_int * len(embd))(*embd), len(embd), n_past, N_THREADS
        )

    n_past += len(embd)
    embd = []

    if input_consumed >= len(embd_inp):
        # Prompt is fully consumed: sample one new token with top-k/top-p
        # (top_k=40, top_p=0.8, temp=0.2, repeat_penalty=1/0.85).
        token_id = llama_cpp.llama_sample_top_p_top_k(
            ctx,
            (llama_cpp.c_int * len(last_n_tokens))(*last_n_tokens),
            len(last_n_tokens),
            40,
            0.8,
            0.2,
            1.0 / 0.85,
        )
        # Slide the repeat-penalty window forward by one token.
        last_n_tokens = last_n_tokens[1:] + [token_id]
        embd.append(token_id)
        input_noecho = False
        remaining_tokens -= 1
    else:
        # Still consuming the prompt: queue up to n_batch prompt tokens.
        while input_consumed < len(embd_inp):
            tok = embd_inp[input_consumed]
            embd.append(tok)
            last_n_tokens = last_n_tokens[1:] + [tok]
            input_consumed += 1
            if len(embd) >= n_batch:
                break

    # Stream the queued tokens to stdout as they are produced.
    if not input_noecho:
        for tok in embd:
            print(
                llama_cpp.llama_token_to_str(ctx, tok).decode("utf-8"),
                end="",
                flush=True,
            )

    # Stop as soon as the model emits an end-of-sequence token.
    if embd and embd[-1] == llama_cpp.llama_token_eos():
        break
| 80 | + |
print()

# Report llama.cpp's internal timing statistics, then release the context.
llama_cpp.llama_print_timings(ctx)
llama_cpp.llama_free(ctx)