forked from JamePeng/llama-cpp-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllama_cache.py
More file actions
564 lines (462 loc) · 20.2 KB
/
llama_cache.py
File metadata and controls
564 lines (462 loc) · 20.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
from abc import ABC, abstractmethod
import array
from collections import OrderedDict
import ctypes
from dataclasses import dataclass
import diskcache
import hashlib
import sys
from typing import (
Any,
List,
Optional,
Sequence,
Tuple,
)
import llama_cpp.llama as llama_core
import llama_cpp.llama_cpp as llama_cpp_lib
from .llama_types import *
class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model."""
def __init__(self, capacity_bytes: int = (2 << 30)):
self.capacity_bytes = capacity_bytes
@property
@abstractmethod
def cache_size(self) -> int:
raise NotImplementedError
def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
pass
@abstractmethod
def __getitem__(self, key: Sequence[int]) -> "llama_core.LlamaState":
raise NotImplementedError
@abstractmethod
def __contains__(self, key: Sequence[int]) -> bool:
raise NotImplementedError
@abstractmethod
def __setitem__(
self, key: Sequence[int], value: "llama_core.LlamaState"
) -> None:
raise NotImplementedError
class LlamaDiskCache(BaseLlamaCache):
"""
Disk cache for a llama.cpp model.
Delegates LRU and size management natively to the SQLite-backed `diskcache` library.
"""
def __init__(
self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30), verbose: bool = False
):
super().__init__(capacity_bytes)
self.cache_dir = cache_dir
# Native SQLite size limit and LRU eviction
self.cache = diskcache.Cache(cache_dir, size_limit=capacity_bytes)
self.verbose = verbose
@property
def cache_size(self):
# Native O(1) volume calculation
return self.cache.volume() # type: ignore
def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
# Early exit if cache is empty
if len(self.cache) == 0:
return None
min_len = 0
min_key: Optional[Tuple[int, ...]] = None
target_len = len(key)
for k in self.cache.iterkeys(): # type: ignore
prefix_len = llama_core.Llama.longest_token_prefix(k, key, self.verbose)
if prefix_len > min_len:
min_len = prefix_len
min_key = k # type: ignore
# Perfect match found, break to prevent full-table disk scan
if min_len == target_len:
break
return min_key
def __getitem__(self, key: Sequence[int]) -> "llama_core.LlamaState":
print("LlamaDiskCache.__getitem__: called", file=sys.stderr)
if len(self.cache) == 0:
raise KeyError("Cache is empty")
key = tuple(key)
_key = self._find_longest_prefix_key(key)
if _key is None:
raise KeyError("Key not found")
# Non-destructive read: automatically updates access time for LRU
value: "llama_core.LlamaState" = self.cache[_key] # type: ignore
return value
def __contains__(self, key: Sequence[int]) -> bool:
if len(self.cache) == 0:
return False
return self._find_longest_prefix_key(tuple(key)) is not None
def __setitem__(self, key: Sequence[int], value: "llama_core.LlamaState"):
print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
# diskcache natively handles capacity check and eviction upon assignment
self.cache[tuple(key)] = value
class LlamaRAMCache(BaseLlamaCache):
"""
RAM cache for a llama.cpp model.
Maintains an LRU eviction policy with O(1) size tracking.
"""
def __init__(self, capacity_bytes: int = (2 << 30), verbose: bool = False):
super().__init__(capacity_bytes)
self.capacity_bytes = capacity_bytes
self.cache_state: OrderedDict[
Tuple[int, ...], "llama_core.LlamaState"
] = OrderedDict()
self._current_size = 0
self.verbose = verbose
@property
def cache_size(self):
return self._current_size
def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
min_len = 0
min_key = None
keys = (
(k, llama_core.Llama.longest_token_prefix(k, key, self.verbose))
for k in self.cache_state.keys()
)
for k, prefix_len in keys:
if prefix_len > min_len:
min_len = prefix_len
min_key = k
return min_key
def __getitem__(self, key: Sequence[int]) -> "llama_core.LlamaState":
if not self.cache_state:
raise KeyError("Cache is empty")
key = tuple(key)
_key = self._find_longest_prefix_key(key)
if _key is None:
raise KeyError("Key not found")
value = self.cache_state[_key]
self.cache_state.move_to_end(_key)
return value
def __contains__(self, key: Sequence[int]) -> bool:
if not self.cache_state:
return False
return self._find_longest_prefix_key(tuple(key)) is not None
def __setitem__(self, key: Sequence[int], value: "llama_core.LlamaState"):
key = tuple(key)
if key in self.cache_state:
del self.cache_state[key]
self.cache_state[key] = value
self._current_size += value.llama_state_size
while self._current_size > self.capacity_bytes and len(self.cache_state) > 0:
_, popped_state = self.cache_state.popitem(last=False)
self._current_size -= popped_state.llama_state_size
self._current_size = max(0, self._current_size)
if len(self.cache_state) == 0:
self._current_size = 0
class TrieNode:
"""A node in the prefix tree (Trie)."""
def __init__(self):
# Child nodes: {token_id: TrieNode}
self.children: Dict[int, "TrieNode"] = {}
# Stores the LlamaState if this node marks the end of a cached sequence.
self.state: Optional["llama_core.LlamaState"] = None
class LlamaTrieCache(BaseLlamaCache):
"""
A Llama cache implementation using a Trie for O(K) prefix lookup
and an OrderedDict for O(1) LRU eviction.
- K = length of the query key (number of tokens)
- N = total number of items in the cache
This solves the O(N*K) lookup bottleneck of the linear scan cache.
"""
def __init__(self, capacity_bytes: int = (2 << 30)):
super().__init__(capacity_bytes)
self.root = TrieNode() # The root node of the Trie
self._current_size = 0 # O(1) tracking of cache size in bytes
# LRU Tracker:
# Key: Cached token sequence (Tuple[int, ...])
# Value: The *terminal* TrieNode for that key
self.lru_tracker: OrderedDict[
Tuple[int, ...], TrieNode
] = OrderedDict()
@property
def cache_size(self) -> int:
"""Returns the current total size of the cache in bytes (O(1))."""
return self._current_size
def _find_longest_prefix_node(
self, key: Tuple[int, ...]
) -> Tuple[Optional[TrieNode], Optional[Tuple[int, ...]]]:
"""
Finds the longest cached prefix for a given key in O(K) time.
Returns: (The matching TrieNode, The matching key)
"""
node = self.root
longest_prefix_node: Optional[TrieNode] = None
longest_prefix_key: Optional[Tuple[int, ...]] = None
current_prefix: List[int] = []
# Check if the empty prefix (root) is cached
if node.state is not None:
longest_prefix_node = node
longest_prefix_key = tuple(current_prefix)
for token in key:
if token not in node.children:
# Path ends, no further prefix matches
break
node = node.children[token]
current_prefix.append(token)
if node.state is not None:
# Found a valid, longer prefix; update our best match
longest_prefix_node = node
longest_prefix_key = tuple(current_prefix)
return longest_prefix_node, longest_prefix_key
def __getitem__(self, key: Sequence[int]) -> "llama_core.LlamaState":
"""
Retrieves the state for the longest matching prefix in O(K) time.
Updates the LRU status.
"""
key_tuple = tuple(key)
node, prefix_key = self._find_longest_prefix_node(key_tuple)
if node is None or node.state is None or prefix_key is None:
raise KeyError(f"Key prefix not found in cache for: {key_tuple}")
# Move the accessed key to the end (most recently used) in O(1)
self.lru_tracker.move_to_end(prefix_key)
return node.state
def __contains__(self, key: Sequence[int]) -> bool:
"""Checks if any prefix of the key is cached in O(K) time."""
node, _ = self._find_longest_prefix_node(tuple(key))
return node is not None
def _prune(self, key: Tuple[int, ...]):
"""
(Helper) Removes a key and its state from the Trie.
Also removes empty parent nodes (branch pruning).
"""
path: List[Tuple[TrieNode, int]] = [] # Stores (parent_node, token)
node = self.root
# 1. Find the node and record the path
for token in key:
if token not in node.children:
return # Key not found
path.append((node, token))
node = node.children[token]
# 2. Remove the state
if node.state is None:
return # Node has no state
self._current_size -= node.state.llama_state_size
node.state = None
# 3. Prune empty parent nodes backward
for parent, token in reversed(path):
child = parent.children[token]
# If the child node is now empty (no children, no state), delete it
if not child.children and child.state is None:
del parent.children[token]
else:
# Node is still in use, stop pruning
break
def __setitem__(self, key: Sequence[int], value: "llama_core.LlamaState"):
"""
Adds a (key, state) pair to the cache in O(K) time.
Handles LRU updates and eviction.
"""
key_tuple = tuple(key)
# 1. Find or create nodes for the key (O(K))
node = self.root
for token in key_tuple:
node = node.children.setdefault(token, TrieNode())
# 2. Check if updating an existing item
if node.state is not None:
self._current_size -= node.state.llama_state_size
# 3. Set new state and update O(1) size
node.state = value
self._current_size += value.llama_state_size
# 4. Update LRU tracker (O(1))
if key_tuple in self.lru_tracker:
self.lru_tracker.move_to_end(key_tuple)
else:
self.lru_tracker[key_tuple] = node
# 5. Eviction logic
while self._current_size > self.capacity_bytes and self.lru_tracker:
# Get the least recently used item in O(1)
evicted_key, _ = self.lru_tracker.popitem(last=False)
# Remove the evicted item from the Trie
self._prune(evicted_key)
# Alias for backwards compatibility
LlamaCache = LlamaTrieCache
@dataclass
class HybridCheckpoint:
"""Represents a single snapshot of the RNN/Hybrid model's hidden state."""
pos: int # The token position (cursor) where this snapshot was taken
data: bytes # The raw binary RNN state data
hash_val: str # SHA-256 hash of the token prefix to ensure exact sequence matching
size: int # Size of the state data in bytes
seq_id: int # Sequence ID this checkpoint belongs to
class HybridCheckpointCache(BaseLlamaCache):
"""
Manager for RNN state snapshots (Checkpoints) tailored for Hybrid/Recurrent models.
Provides rollback capabilities for models that cannot physically truncate KV cache.
"""
def __init__(self, ctx: llama_cpp_lib.llama_context_p, max_checkpoints: int = 16, verbose: bool = False):
if ctx is None:
raise ValueError("HybridCheckpointCache(__init__): Failed to create HybridCheckpointCache with model context")
self._ctx = ctx
self.max_checkpoints = max_checkpoints
self.checkpoints: list[HybridCheckpoint] = []
self._current_size = 0
# Cache C-type API function pointers for performance
self._get_size_ext = llama_cpp_lib.llama_state_seq_get_size_ext
self._get_data_ext = llama_cpp_lib.llama_state_seq_get_data_ext
self._set_data_ext = llama_cpp_lib.llama_state_seq_set_data_ext
self._flag_partial = llama_cpp_lib.LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY
self.verbose = verbose
if self.max_checkpoints <= 0:
if self.verbose:
import sys
print("HybridCheckpointCache(__init__): Cache is DISABLED (max_checkpoints <= 0). "
"Rollback capabilities are turned off. This is optimal for single-turn workflows.",
file=sys.stderr)
@property
def cache_size(self) -> int:
"""Returns the total memory used by all stored checkpoints in bytes."""
return self._current_size
def clear(self):
"""Clears all stored checkpoints and resets memory tracking."""
if not self.checkpoints:
# Empty Checkpoint: Return immediately, no need to clear.
return
self.checkpoints.clear()
self._current_size = 0
if self.verbose:
print("HybridCheckpointCache: cleared")
def close(self):
self.checkpoints = None
self._ctx = None
self._get_size_ext = None
self._get_data_ext = None
self._set_data_ext = None
def __del__(self) -> None:
self.close()
# Helper tools
def _hash_prefix(self, tokens: List[int], length: int) -> str:
"""
Computes a SHA-256 hash for a sequence of tokens up to the specified length.
This ensures that checkpoints are only restored for the EXACT same conversation history.
"""
if length <= 0:
return "empty"
tokens_size = len(tokens)
if length > tokens_size:
length = tokens_size
data = array.array('i', tokens[:length]).tobytes()
return hashlib.sha256(data).hexdigest()[:32]
def find_best_checkpoint(self, tokens: List[int], seq_id: int = 0) -> Optional[HybridCheckpoint]:
"""
Finds the longest valid checkpoint that perfectly matches the provided token prefix.
Returns None if no matching checkpoint is found.
"""
# Empty Checkpoint: Instant return, no hash calculation needed.
if self.max_checkpoints <= 0 or len(self.checkpoints) == 0:
return None
best_cp = None
best_pos = -1
for cp in self.checkpoints:
if cp.seq_id != seq_id or cp.pos > len(tokens):
# Skip if sequence ID mismatches or checkpoint is longer than the current prompt
continue
# Verify cryptographic integrity of the prompt history
if self._hash_prefix(tokens, cp.pos) == cp.hash_val:
if cp.pos > best_pos:
# Keep the checkpoint with the longest matching prefix (highest pos)
best_pos = cp.pos
best_cp = cp
return best_cp
def save_checkpoint(
self,
current_pos: int,
tokens: List[int],
seq_id: int = 0
) -> bool:
"""
Extracts the RNN hidden state from the C++ backend and saves it as a checkpoint.
Manages eviction (FIFO) if the maximum number of checkpoints is exceeded.
"""
# 0. Early Exit / Feature Toggle
# If the user disables checkpoints (max_checkpoints <= 0), we immediately return.
# This absolutely critical bypass prevents massive (e.g., 150MB+) synchronous
# VRAM-to-RAM copies over the PCIe bus, eliminating multi-second delays at the
# end of generation for single-turn workflows.
# This is more friendly to the single-call ComfyUI ecosystem. :)
if self.max_checkpoints <= 0:
if self.verbose:
print("HybridCheckpointCache(save_checkpoint): Cache is DISABLED (max_checkpoints <= 0). "
"Operating in single-turn conversation mode. Skipping state extraction to optimize generation latency.",
file=sys.stderr)
return False
flags = self._flag_partial
# 1. Query the required buffer size from the underlying C++ context
size = self._get_size_ext(self._ctx, seq_id, flags)
if size == 0:
if self.verbose:
print("HybridCheckpointCache(save_checkpoint): size=0, skip")
return False
# 2. Allocate buffer and extract raw state data
buffer = (ctypes.c_uint8 * size)()
n_written = self._get_data_ext(self._ctx, buffer, size, seq_id, flags)
if n_written != size:
if self.verbose:
print(f"HybridCheckpointCache(save_checkpoint): get failed {n_written}/{size}")
return False
# Note: This deep copy isolates the state from subsequent C++ backend mutations
data_bytes = bytes(buffer[:n_written])
hash_val = self._hash_prefix(tokens, current_pos)
# 3. Store the newly extracted checkpoint
self.checkpoints.append(HybridCheckpoint(
pos=current_pos,
data=data_bytes,
hash_val=hash_val,
size=n_written,
seq_id=seq_id)
)
self._current_size += n_written
# 4. Enforce capacity limits (FIFO eviction)
while len(self.checkpoints) > self.max_checkpoints:
if not self.checkpoints:
break
old_cp = self.checkpoints.pop(0)
self._current_size -= old_cp.size
if self.verbose:
print(f"HybridCheckpointCache(save_checkpoint): evicted pos={old_cp.pos}")
if self.verbose:
print(f"HybridCheckpointCache(save_checkpoint): Saved checkpoint at pos {current_pos} ({size / 1024 / 1024:.2f} MiB) "
f"total={len(self.checkpoints)} used={self._current_size / 1024 / 1024:.2f} MiB",
file=sys.stderr)
return True
def restore_checkpoint(self, cp: HybridCheckpoint, seq_id: int = 0) -> bool:
"""
Injects a previously saved RNN state checkpoint back into the C++ backend memory.
"""
# 1. Verify sequence ID matches to prevent cross-sequence contamination
if cp.seq_id != seq_id:
if self.verbose:
print(f"HybridCheckpointCache(restore_checkpoint): [Error] Sequence ID mismatch: checkpoint has {cp.seq_id}, requested {seq_id}", file=sys.stderr)
return False
flags = self._flag_partial
# 2. Verify the underlying C++ context still expects the exact same state size.
# This prevents buffer overflows if the backend context was unexpectedly altered or reallocated.
current_size = self._get_size_ext(self._ctx, seq_id, flags)
if current_size != cp.size:
if self.verbose:
print(f"HybridCheckpointCache(restore_checkpoint): [Warning] State size mismatch before restore: expected {cp.size}, got {current_size} -> possible invalidation")
return False
# 3. Copy data back to a ctypes buffer and push to the C++ backend
buffer = (ctypes.c_uint8 * cp.size).from_buffer_copy(cp.data)
ret = self._set_data_ext(
self._ctx, buffer, cp.size, seq_id, flags
)
success = (ret == cp.size)
if self.verbose:
print(f"HybridCheckpointCache(restore_checkpoint): restore {'OK' if success else 'FAIL'} pos={cp.pos}")
return success
# Disable BaseLlamaCache Dictionary Interfaces
def __getitem__(self, key):
raise NotImplementedError("HybridCheckpointCache: pls use save_checkpoint or restore_checkpoint method")
def __setitem__(self, key, value):
raise NotImplementedError("HybridCheckpointCache: pls use save_checkpoint or restore_checkpoint method")
def __contains__(self, key):
raise NotImplementedError("HybridCheckpointCache: pls use save_checkpoint or restore_checkpoint method")