@@ -275,14 +275,23 @@ def __init__(self, params: GptParams) -> None:
275275 presence_penalty = { self .params .presence_penalty } ,\
276276 frequency_penalty = { self .params .frequency_penalty } ,\
277277 top_k = { self .params .top_k } ,\
278- tfs_z = { self .params .tfs_z } ,\
278+ top_n_sigma = { self .params .top_n_sigma } ,\
279279 top_p = { self .params .top_p } ,\
280280 typical_p = { self .params .typical_p } ,\
281281 temp = { self .params .temp } ,\
282282 mirostat = { self .params .mirostat } ,\
283283 mirostat_lr = { self .params .mirostat_eta } ,\
284284 mirostat_ent = { self .params .mirostat_tau } ,\
285285
286+ xtc_threshold = { self .params .xtc_threshold } ,\
287+ xtc_probability = { self .params .xtc_probability } ,\
288+
289+ dry_multiplier = { self .params .dry_multiplier } ,\
290+ dry_base = { self .params .dry_base } ,\
291+ dry_allowed_length = { self .params .dry_allowed_length } ,\
292+ dry_penalty_last_n = { self .params .dry_penalty_last_n } ,\
293+ dry_seq_breakers = { self .params .dry_seq_breakers } ,\
294+
286295generate: n_ctx = { self .n_ctx } ,\
287296 n_batch = { self .params .n_batch } ,\
288297 n_predict = { self .params .n_predict } ,\
@@ -454,7 +463,7 @@ def generate(self):
454463 _arr = (llama_cpp .llama_token * last_n_repeat )(
455464 * self .last_n_tokens [len (self .last_n_tokens ) - last_n_repeat :]
456465 )
457- llama_cpp .llama_sample_repetition_penalties (
466+ llama_cpp .llama_sampler_init_penalties (
458467 ctx = self .ctx ,
459468 candidates = candidates_p ,
460469 last_tokens_data = _arr ,
@@ -474,15 +483,15 @@ def generate(self):
474483
475484 if self .params .temp <= 0 :
476485 # Greedy sampling
477- id = llama_cpp .llama_sample_token_greedy (self .ctx , candidates_p )
486+ id = llama_cpp .llama_sampler_init_greedy (self .ctx , candidates_p )
478487 else :
479488 if self .params .mirostat == 1 :
480489 mirostat_mu = 2.0 * self .params .mirostat_tau
481490 mirostat_m = 100
482- llama_cpp .llama_sample_temperature (
491+ llama_cpp .llama_sampler_init_temp (
483492 self .ctx , candidates_p , llama_cpp .c_float (self .params .temp )
484493 )
485- id = llama_cpp .llama_sample_token_mirostat (
494+ id = llama_cpp .llama_sampler_init_mirostat (
486495 self .ctx ,
487496 candidates_p ,
488497 llama_cpp .c_float (self .params .mirostat_tau ),
@@ -495,7 +504,7 @@ def generate(self):
495504 llama_cpp .llama_sample_temperature (
496505 self .ctx , candidates_p , llama_cpp .c_float (self .params .temp )
497506 )
498- id = llama_cpp .llama_sample_token_mirostat_v2 (
507+ id = llama_cpp .llama_sampler_init_mirostat_v2 (
499508 self .ctx ,
500509 candidates_p ,
501510 llama_cpp .c_float (self .params .mirostat_tau ),
@@ -504,31 +513,31 @@ def generate(self):
504513 )
505514 else :
506515 # Temperature sampling
507- llama_cpp .llama_sample_top_k (
516+ llama_cpp .llama_sampler_init_top_k (
508517 self .ctx ,
509518 candidates_p ,
510519 top_k ,
511520 min_keep = llama_cpp .c_size_t (1 ),
512521 )
513- llama_cpp .llama_sample_tail_free (
522+ llama_cpp .llama_sampler_init_top_n_sigma (
514523 self .ctx ,
515524 candidates_p ,
516- llama_cpp .c_float (self .params .tfs_z ),
525+ llama_cpp .c_float (self .params .top_n_sigma ),
517526 min_keep = llama_cpp .c_size_t (1 ),
518527 )
519- llama_cpp .llama_sample_typical (
528+ llama_cpp .llama_sampler_init_typical (
520529 self .ctx ,
521530 candidates_p ,
522531 llama_cpp .c_float (self .params .typical_p ),
523532 min_keep = llama_cpp .c_size_t (1 ),
524533 )
525- llama_cpp .llama_sample_top_p (
534+ llama_cpp .llama_sampler_init_top_p (
526535 self .ctx ,
527536 candidates_p ,
528537 llama_cpp .c_float (self .params .top_p ),
529538 min_keep = llama_cpp .c_size_t (1 ),
530539 )
531- llama_cpp .llama_sample_temperature (
540+ llama_cpp .llama_sampler_init_temp (
532541 self .ctx , candidates_p , llama_cpp .c_float (self .params .temp )
533542 )
534543 id = llama_cpp .llama_sample_token (self .ctx , candidates_p )
0 commit comments