forked from abetlen/llama-cpp-python
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathllama_jinja_format.py
More file actions
138 lines (107 loc) · 4.08 KB
/
llama_jinja_format.py
File metadata and controls
138 lines (107 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
import jinja2
from jinja2 import Template
# NOTE: We sacrifice readability for usability.
# It will fail to work as expected if we attempt to format it in a readable way.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
class MetaSingleton(type):
    """
    Metaclass that gives every class using it exactly one shared instance.

    The first call to the class constructs and caches the instance; every
    later call returns the cached object.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct lazily on first use; afterwards always hand back the
        # cached instance for this class.
        instance = cls._instances.get(cls)
        if instance is None:
            instance = super(MetaSingleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return instance
class Singleton(object, metaclass=MetaSingleton):
    """
    Convenience base class: any subclass automatically behaves as a
    singleton via ``MetaSingleton``.
    """

    def __init__(self):
        super().__init__()
@dataclasses.dataclass
class ChatFormatterResponse:
    # Fully rendered prompt text, ready to hand to the model.
    prompt: str
    # Optional stop string(s) for generation; None when the format defines none.
    stop: Optional[Union[str, List[str]]] = None
# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
    """
    Structural (duck-typed) interface for chat formatters: constructed with
    an optional template, called with a message list to produce a
    ChatFormatterResponse, and exposing the raw template source.
    """

    def __init__(self, template: Optional[object] = None):
        ...

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> ChatFormatterResponse:
        ...

    @property
    def template(self) -> str:
        ...
class AutoChatFormatter(ChatFormatterInterface):
    """
    Chat formatter that renders a message list through a Jinja2 template.

    Falls back to the module-level ``llama2_template`` when no template
    source is supplied.
    """

    def __init__(
        self,
        template: Optional[str] = None,
        # FIX: jinja2's Environment.from_string takes a Template *subclass*
        # (a class object), not a Template instance, so the annotation is
        # ``type`` rather than ``Template``.
        template_class: Optional[type] = None,
    ):
        """
        :param template: Jinja2 template source; defaults to llama2_template.
        :param template_class: Optional ``jinja2.Template`` subclass to
            forward to ``Environment.from_string``.
        """
        self._template = template if template is not None else llama2_template
        # trim_blocks/lstrip_blocks keep block tags from injecting stray
        # whitespace into the rendered prompt.
        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """
        Render ``messages`` and wrap the result.

        :param messages: List of {"role": ..., "content": ...} dicts.
        :param kwargs: Extra variables forwarded to the template.
        :return: ChatFormatterResponse holding the rendered prompt.
        """
        formatted_sequence = self._environment.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        """Return the raw template source string."""
        return self._template
class FormatterNotFoundException(Exception):
    """Raised when no formatter is registered under a requested name."""
    pass
class ChatFormatterFactory(Singleton):
    """
    Singleton registry mapping chat-format names to formatter factories.
    """

    # Shared class-level registry; fine here because the factory itself is
    # a singleton.
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite=False,
    ):
        """
        Register ``formatter_callable`` under ``name``.

        :raises ValueError: if ``name`` is taken and ``overwrite`` is False.
        """
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        """
        Remove the formatter registered under ``name``.

        :raises ValueError: if nothing is registered under ``name``.
        """
        # EAFP: one dict operation instead of a check-then-delete pair.
        try:
            del self._chat_formatters[name]
        except KeyError:
            raise ValueError(f"No formatter registered under the name '{name}'.") from None

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        """
        Instantiate and return the formatter registered under ``name``.

        :raises FormatterNotFoundException: if ``name`` is unknown.
        """
        try:
            formatter_callable = self._chat_formatters[name]
        except KeyError:
            # FIX: keep only the lookup in the try block — a KeyError raised
            # by formatter_callable() itself must not be misreported as an
            # unknown format. ``from None`` hides the redundant KeyError.
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            ) from None
        return formatter_callable()
# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    """AutoChatFormatter preconfigured with the LLaMA-2 prompt template."""
    def __init__(self):
        super().__init__(llama2_template)
# ChatFormatterFactory is a Singleton, so this registration lands in the one
# shared factory instance no matter where or how many times
# ChatFormatterFactory() is constructed elsewhere in the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)