@@ -67,6 +67,12 @@ def _load_shared_library(lib_base_name):
6767_lib = _load_shared_library (_lib_base_name )
6868
6969# C types
70+ LLAMA_FILE_VERSION = ctypes .c_int (1 )
71+ LLAMA_FILE_MAGIC = b"ggjt"
72+ LLAMA_FILE_MAGIC_UNVERSIONED = b"ggml"
73+ LLAMA_SESSION_MAGIC = b"ggsn"
74+ LLAMA_SESSION_VERSION = ctypes .c_int (0 )
75+
7076llama_context_p = c_void_p
7177
7278
@@ -77,13 +83,24 @@ def _load_shared_library(lib_base_name):
7783class llama_token_data (Structure ):
7884 _fields_ = [
7985 ("id" , llama_token ), # token id
86+ ("logit" , c_float ), # log-odds of the token
8087 ("p" , c_float ), # probability of the token
81- ("plog" , c_float ), # log probability of the token
8288 ]
8389
8490
8591llama_token_data_p = POINTER (llama_token_data )
8692
93+
94+ class llama_token_data_array (Structure ):
95+ _fields_ = [
96+ ("data" , llama_token_data_p ),
97+ ("size" , c_size_t ),
98+ ("sorted" , c_bool ),
99+ ]
100+
101+
102+ llama_token_data_array_p = POINTER (llama_token_data_array )
103+
87104llama_progress_callback = ctypes .CFUNCTYPE (None , c_float , c_void_p )
88105
89106
@@ -118,7 +135,7 @@ class llama_context_params(Structure):
118135 4
119136) # tok_embeddings.weight and output.weight are F16
120137LLAMA_FTYPE_MOSTLY_Q4_2 = ctypes .c_int (5 ) # except 1d tensors
121- LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes .c_int (6 ) # except 1d tensors
138+ # LLAMA_FTYPE_MOSTYL_Q4_3 = ctypes.c_int(6) # except 1d tensors
122139LLAMA_FTYPE_MOSTYL_Q8_0 = ctypes .c_int (7 ) # except 1d tensors
123140LLAMA_FTYPE_MOSTYL_Q5_0 = ctypes .c_int (8 ) # except 1d tensors
124141LLAMA_FTYPE_MOSTYL_Q5_1 = ctypes .c_int (9 ) # except 1d tensors
@@ -401,31 +418,214 @@ def llama_token_eos() -> llama_token:
401418_lib .llama_token_eos .restype = llama_token
402419
403420
404- # TODO: improve the last_n_tokens interface ?
405- def llama_sample_top_p_top_k (
421+ def llama_token_nl () -> llama_token :
422+ return _lib .llama_token_nl ()
423+
424+
425+ _lib .llama_token_nl .argtypes = []
426+ _lib .llama_token_nl .restype = llama_token
427+
428+
429+ # Sampling functions
430+ def llama_sample_repetition_penalty (
431+ ctx : llama_context_p ,
432+ candidates ,
433+ last_tokens_data ,
434+ last_tokens_size : c_int ,
435+ penalty : c_float ,
436+ ) -> llama_token :
437+ return _lib .llama_sample_repetition_penalty (
438+ ctx , candidates , last_tokens_data , last_tokens_size , penalty
439+ )
440+
441+
442+ _lib .llama_sample_repetition_penalty .argtypes = [
443+ llama_context_p ,
444+ llama_token_data_array_p ,
445+ llama_token_p ,
446+ c_int ,
447+ c_float ,
448+ ]
449+ _lib .llama_sample_repetition_penalty .restype = llama_token
450+
451+
452+ # LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
453+ def llama_sample_frequency_and_presence_penalties (
406454 ctx : llama_context_p ,
407- last_n_tokens_data , # type: Array[llama_token]
408- last_n_tokens_size : c_int ,
409- top_k : c_int ,
410- top_p : c_float ,
411- temp : c_float ,
412- repeat_penalty : c_float ,
455+ candidates ,
456+ last_tokens_data ,
457+ last_tokens_size : c_int ,
458+ alpha_frequency : c_float ,
459+ alpha_presence : c_float ,
413460) -> llama_token :
414- return _lib .llama_sample_top_p_top_k (
415- ctx , last_n_tokens_data , last_n_tokens_size , top_k , top_p , temp , repeat_penalty
461+ return _lib .llama_sample_frequency_and_presence_penalties (
462+ ctx ,
463+ candidates ,
464+ last_tokens_data ,
465+ last_tokens_size ,
466+ alpha_frequency ,
467+ alpha_presence ,
416468 )
417469
418470
419- _lib .llama_sample_top_p_top_k .argtypes = [
471+ _lib .llama_sample_frequency_and_presence_penalties .argtypes = [
420472 llama_context_p ,
473+ llama_token_data_array_p ,
421474 llama_token_p ,
422475 c_int ,
476+ c_float ,
477+ c_float ,
478+ ]
479+ _lib .llama_sample_frequency_and_presence_penalties .restype = llama_token
480+
481+
482+ # LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
483+ def llama_sample_softmax (ctx : llama_context_p , candidates ) -> llama_token :
484+ return _lib .llama_sample_softmax (ctx , candidates )
485+
486+
487+ _lib .llama_sample_softmax .argtypes = [
488+ llama_context_p ,
489+ llama_token_data_array_p ,
490+ ]
491+ _lib .llama_sample_softmax .restype = llama_token
492+
493+
494+ # LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
495+ def llama_sample_top_k (
496+ ctx : llama_context_p , candidates , k : c_int , min_keep : c_int
497+ ) -> llama_token :
498+ return _lib .llama_sample_top_k (ctx , candidates , k , min_keep )
499+
500+
501+ _lib .llama_sample_top_k .argtypes = [
502+ llama_context_p ,
503+ llama_token_data_array_p ,
504+ c_int ,
505+ c_int ,
506+ ]
507+
508+
509+ # LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
510+ def llama_sample_top_p (
511+ ctx : llama_context_p , candidates , p : c_float , min_keep : c_int
512+ ) -> llama_token :
513+ return _lib .llama_sample_top_p (ctx , candidates , p , min_keep )
514+
515+
516+ _lib .llama_sample_top_p .argtypes = [
517+ llama_context_p ,
518+ llama_token_data_array_p ,
519+ c_float ,
520+ c_int ,
521+ ]
522+ _lib .llama_sample_top_p .restype = llama_token
523+
524+
525+ # LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
526+ def llama_sample_tail_free (
527+ ctx : llama_context_p , candidates , z : c_float , min_keep : c_int
528+ ) -> llama_token :
529+ return _lib .llama_sample_tail_free (ctx , candidates , z , min_keep )
530+
531+
532+ _lib .llama_sample_tail_free .argtypes = [
533+ llama_context_p ,
534+ llama_token_data_array_p ,
535+ c_float ,
536+ c_int ,
537+ ]
538+ _lib .llama_sample_tail_free .restype = llama_token
539+
540+
541+ # LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
542+ def llama_sample_typical (
543+ ctx : llama_context_p , candidates , p : c_float , min_keep : c_int
544+ ) -> llama_token :
545+ return _lib .llama_sample_typical (ctx , candidates , p , min_keep )
546+
547+
548+ _lib .llama_sample_typical .argtypes = [
549+ llama_context_p ,
550+ llama_token_data_array_p ,
551+ c_float ,
423552 c_int ,
553+ ]
554+ _lib .llama_sample_typical .restype = llama_token
555+
556+
557+ # LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
558+ def llama_sample_temperature (
559+ ctx : llama_context_p , candidates , temp : c_float
560+ ) -> llama_token :
561+ return _lib .llama_sample_temperature (ctx , candidates , temp )
562+
563+
564+ _lib .llama_sample_temperature .argtypes = [
565+ llama_context_p ,
566+ llama_token_data_array_p ,
424567 c_float ,
568+ ]
569+ _lib .llama_sample_temperature .restype = llama_token
570+
571+
572+ # LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
573+ def llama_sample_token_mirostat (
574+ ctx : llama_context_p , candidates , tau : c_float , eta : c_float , m : c_int , mu
575+ ) -> llama_token :
576+ return _lib .llama_sample_token_mirostat (ctx , candidates , tau , eta , m , mu )
577+
578+
579+ _lib .llama_sample_token_mirostat .argtypes = [
580+ llama_context_p ,
581+ llama_token_data_array_p ,
582+ c_float ,
583+ c_float ,
584+ c_int ,
585+ POINTER (c_float ),
586+ ]
587+ _lib .llama_sample_token_mirostat .restype = llama_token
588+
589+
590+ # LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
591+ def llama_sample_token_mirostat_v2 (
592+ ctx : llama_context_p , candidates , tau : c_float , eta : c_float , mu
593+ ) -> llama_token :
594+ return _lib .llama_sample_token_mirostat_v2 (ctx , candidates , tau , eta , mu )
595+
596+
597+ _lib .llama_sample_token_mirostat_v2 .argtypes = [
598+ llama_context_p ,
599+ llama_token_data_array_p ,
425600 c_float ,
426601 c_float ,
602+ POINTER (c_float ),
603+ ]
604+ _lib .llama_sample_token_mirostat_v2 .restype = llama_token
605+
606+
607+ # LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
608+ def llama_sample_token_greedy (ctx : llama_context_p , candidates ) -> llama_token :
609+ return _lib .llama_sample_token_greedy (ctx , candidates )
610+
611+
612+ _lib .llama_sample_token_greedy .argtypes = [
613+ llama_context_p ,
614+ llama_token_data_array_p ,
615+ ]
616+ _lib .llama_sample_token_greedy .restype = llama_token
617+
618+
619+ # LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
620+ def llama_sample_token (ctx : llama_context_p , candidates ) -> llama_token :
621+ return _lib .llama_sample_token (ctx , candidates )
622+
623+
624+ _lib .llama_sample_token .argtypes = [
625+ llama_context_p ,
626+ llama_token_data_array_p ,
427627]
428- _lib .llama_sample_top_p_top_k .restype = llama_token
628+ _lib .llama_sample_token .restype = llama_token
429629
430630
431631# Performance information
0 commit comments