@@ -324,7 +324,7 @@ def __init__(
324324 self ._candidates = candidates
325325 self ._token_nl = Llama .token_nl ()
326326 self ._token_eos = Llama .token_eos ()
327- self ._candidates_data_id = np .arange (self ._n_vocab , dtype = np .intc ) # type: ignore
327+ self ._candidates_data_id = np .arange (self ._n_vocab , dtype = np .intc ) # type: ignore
328328 self ._candidates_data_p = np .zeros (self ._n_vocab , dtype = np .single )
329329
330330 self .n_tokens = 0
@@ -445,8 +445,12 @@ def eval(self, tokens: Sequence[int]):
445445 # Save logits
446446 rows = n_tokens if self .params .logits_all else 1
447447 cols = self ._n_vocab
448- offset = 0 if self .params .logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
449- self .scores [self .n_tokens + offset : self .n_tokens + n_tokens , :].reshape (- 1 )[:] = llama_cpp .llama_get_logits (self .ctx )[:rows * cols ]
448+ offset = (
449+ 0 if self .params .logits_all else n_tokens - 1
450+ ) # NOTE: Only save the last token logits if logits_all is False
451+ self .scores [self .n_tokens + offset : self .n_tokens + n_tokens , :].reshape (
452+ - 1
453+ )[:] = llama_cpp .llama_get_logits (self .ctx )[: rows * cols ]
450454 # Update n_tokens
451455 self .n_tokens += n_tokens
452456
@@ -491,7 +495,7 @@ def _sample(
491495 candidates_data = self ._candidates_data
492496 candidates_data ["id" ][:] = self ._candidates_data_id # type: ignore
493497 candidates_data ["logit" ][:] = logits
494- candidates_data ["p" ][:] = self ._candidates_data_p # type: ignore
498+ candidates_data ["p" ][:] = self ._candidates_data_p # type: ignore
495499 candidates .data = candidates_data .ctypes .data_as (llama_cpp .llama_token_data_p )
496500 candidates .sorted = llama_cpp .c_bool (False )
497501 candidates .size = llama_cpp .c_size_t (n_vocab )
@@ -537,7 +541,7 @@ def _sample(
537541 mirostat_mu = llama_cpp .c_float (2.0 * mirostat_tau .value )
538542 llama_cpp .llama_sample_temperature (
539543 ctx = self .ctx ,
540- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
544+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
541545 temp = temp ,
542546 )
543547 return llama_cpp .llama_sample_token_mirostat_v2 (
0 commit comments