-
Notifications
You must be signed in to change notification settings - Fork 263
Expand file tree
/
Copy pathinclude.py
More file actions
200 lines (148 loc) · 5.05 KB
/
include.py
File metadata and controls
200 lines (148 loc) · 5.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import sys
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, Optional, Tuple
import replicate
from .exceptions import ModelError
from .model import Model
from .prediction import Prediction
from .run import _has_output_iterator_array_type
from .version import Version
# Public API of this module, controlling what `import *` exposes.
__all__ = ["get_run_state", "get_run_token", "include", "run_state", "run_token"]
# Current run state ("load", "setup", or "run") and API token for this process.
_run_state: Optional[Literal["load", "setup", "run"]] = None
_run_token: Optional[str] = None

# Saved previous values so the context managers below can nest and restore.
_state_stack = []
_token_stack = []

# Guard the value+stack updates; reentrant so nested managers cannot deadlock.
_state_lock = threading.RLock()
_token_lock = threading.RLock()


def _require_main_thread(what: str) -> None:
    """Raise ``RuntimeError`` unless invoked from the main thread."""
    if threading.current_thread() is not threading.main_thread():
        raise RuntimeError(f"Only the main thread can modify {what}")


def get_run_state() -> Optional[Literal["load", "setup", "run"]]:
    """Return the run state currently in effect, or ``None`` if unset."""
    return _run_state


def get_run_token() -> Optional[str]:
    """Return the API token currently in effect, or ``None`` if unset."""
    return _run_token


@contextmanager
def run_state(state: Literal["load", "setup", "run"]) -> Any:
    """
    Context manager that sets the current run state for the duration of
    the ``with`` block and restores the previous state on exit.
    """
    global _run_state
    _require_main_thread("run state")
    with _state_lock:
        _state_stack.append(_run_state)
        _run_state = state
    try:
        yield
    finally:
        with _state_lock:
            _run_state = _state_stack.pop()


@contextmanager
def run_token(token: str) -> Any:
    """
    Context manager that sets the current API token for the duration of
    the ``with`` block and restores the previous token on exit.
    """
    global _run_token
    _require_main_thread("API token")
    with _token_lock:
        _token_stack.append(_run_token)
        _run_token = token
    try:
        yield
    finally:
        with _token_lock:
            _run_token = _token_stack.pop()
def _find_api_token() -> str:
    """
    Resolve the Replicate API token: the ``REPLICATE_API_TOKEN``
    environment variable wins, falling back to the token installed
    via the ``run_token`` context manager.

    Raises:
        ValueError: if neither source provides a token.
    """
    env_token = os.environ.get("REPLICATE_API_TOKEN")
    if env_token:
        # Status message goes to stderr so stdout stays clean.
        print("Using Replicate API token from environment", file=sys.stderr)
        return env_token
    token = get_run_token()
    if token is None:
        raise ValueError("No run token found")
    return token
@dataclass
class Run:
    """
    Represents a running prediction with access to its version.
    """

    # The in-flight prediction created by the Replicate client.
    prediction: Prediction
    # The model version the prediction was created against.
    version: Version

    def wait(self) -> Any:
        """
        Block until the prediction completes and return its output.

        Raises:
            ModelError: if the prediction finished with status "failed".
        """
        self.prediction.wait()
        failed = self.prediction.status == "failed"
        if failed:
            raise ModelError(self.prediction)
        # Iterator-of-strings outputs are concatenated into one string.
        joins_output = _has_output_iterator_array_type(self.version)
        return (
            "".join(self.prediction.output) if joins_output else self.prediction.output
        )

    def logs(self) -> Optional[str]:
        """
        Refresh the prediction from the API and return its logs.
        """
        self.prediction.reload()
        return self.prediction.logs
@dataclass
class Function:
    """
    A wrapper for a Replicate model that can be called as a function.
    """

    # Model reference of the form "owner/name" or "owner/name:version".
    function_ref: str

    def _client(self) -> "replicate.Client":
        """Build a Replicate client authenticated with the current API token."""
        return replicate.Client(api_token=_find_api_token())

    def _split_function_ref(self) -> Tuple[str, str, Optional[str]]:
        """
        Split ``function_ref`` into ``(owner, name, version)``.

        ``version`` is ``None`` when the reference carries no ``:version``
        suffix.

        Raises:
            ValueError: if the reference is not of the form
                "owner/name" or "owner/name:version".
        """
        try:
            owner, name = self.function_ref.split("/")
            name, version = name.split(":") if ":" in name else (name, None)
        except ValueError:
            # Replace the cryptic tuple-unpacking error with a clear message.
            raise ValueError(
                f"Invalid model reference {self.function_ref!r}: "
                'expected "owner/name" or "owner/name:version"'
            ) from None
        return owner, name, version

    def _model(self) -> "Model":
        """Fetch the model object for this reference."""
        model_owner, model_name, _ = self._split_function_ref()
        return self._client().models.get(f"{model_owner}/{model_name}")

    def _version(self) -> "Version":
        """
        Fetch the version pinned by the reference, or the model's latest
        version when no version is pinned.
        """
        # Reuse _model() instead of duplicating the model lookup.
        _, _, model_version = self._split_function_ref()
        model = self._model()
        return (
            model.versions.get(model_version) if model_version else model.latest_version
        )

    def __call__(self, **inputs: Any) -> Any:
        """Run the model synchronously and return its output."""
        run = self.start(**inputs)
        return run.wait()

    def start(self, **inputs: Any) -> "Run":
        """
        Start a prediction with the specified inputs and return a Run
        handle without waiting for completion.
        """
        version = self._version()
        prediction = self._client().predictions.create(version=version, input=inputs)
        # Progress goes to stderr — consistent with _find_api_token's logging —
        # so the model's stdout output is never polluted.
        print(
            f"Running {self.function_ref}: https://replicate.com/p/{prediction.id}",
            file=sys.stderr,
        )
        return Run(prediction, version)

    @property
    def default_example(self) -> Optional["Prediction"]:
        """
        Get the default example for this model.
        """
        return self._model().default_example

    @property
    def openapi_schema(self) -> Dict[Any, Any]:
        """
        Get the OpenAPI schema for this model version.
        """
        return self._version().openapi_schema
def include(function_ref: str) -> Callable[..., Any]:
    """
    Include a Replicate model as a function.

    This function can only be called at the top level (i.e. while the
    run state is "load").
    """
    current_state = get_run_state()
    if current_state != "load":
        raise RuntimeError("You may only call replicate.include at the top level.")
    return Function(function_ref)