-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathcompile_wrapper.py
More file actions
208 lines (115 loc) · 4.53 KB
/
compile_wrapper.py
File metadata and controls
208 lines (115 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
from __future__ import annotations
import torch
from inspect import signature
from operator import attrgetter
from types import NoneType
from typing import Any, Callable, Dict, Annotated, List, TypeVar, Union, Tuple
from msgspec import Struct
__all__ = ['Dynamic', 'BaseKernel']
T = TypeVar('T')
NAMESPACE: str = "teal"
TENSOR: str = "Tensor"
NONE: str = "None"
DOT_OP: str = "."
TYPE_SEP: str = " | "
COMMA_SPACE: str = ", "
TYPE_MAPPINGS: Dict[str, str] = (
{
'dtype': 'ScalarType',
'device': 'Device',
}
)
class Dynamic:
Int = Annotated[int, 'SymInt']
Num = Annotated[int | float, 'Scalar']
def check_origin(annotation: Any, target: Any) -> bool:
return hasattr(annotation, '__origin__') and annotation.__origin__ is target
def unpack_optional(annotation: Union[T, NoneType] | T) -> Tuple[T, bool]:
if check_origin(annotation, Union):
target, maybe_none, *maybe_more = annotation.__args__
if maybe_more or maybe_none is not NoneType:
raise TypeError(f'unsupported multi-union type: {annotation}')
else:
return target, True
return annotation, False
def unpack_list(annotation: List[T] | T) -> Tuple[T, bool]:
if check_origin(annotation, list):
return annotation.__args__[0], True
return annotation, False
def resolve_type(annotation: Any) -> str:
target, is_optional = unpack_optional(annotation)
target, is_iterable = unpack_list(target)
if hasattr(target, '__metadata__'):
resolved = target.__metadata__[0]
elif hasattr(target, '__name__'):
if (str_value := target.__name__).endswith(TENSOR):
resolved = TENSOR
else:
resolved = TYPE_MAPPINGS.get(str_value, str_value)
else:
raise TypeError(f'unable to infer type from given {target}')
if is_iterable:
resolved = f'{resolved}[]'
if is_optional:
resolved = f'{resolved}?'
return resolved
def resolve_return(annotation: Any) -> str:
target = annotation
if check_origin(annotation, tuple):
target = annotation.__args__
if isinstance(target, tuple):
return f'({COMMA_SPACE.join(map(resolve_type, target))})'
else:
return resolve_type(target)
class BaseKernel(Struct):
"""
Base helper dataclass for wrapping custom kernels for registration with `torch.library`.
Note(s):
• The underlying `schema` is automatically inferred from the type annotation of the
`forward` method.
• Should this fail to recognize non-Tensor(s), feel free to override the `schematize`
class-method.
"""
# Library name and corresponding target device.
name: str
target: str
# The auto-generated signature.
schema: str
@classmethod
def initialize(cls, name: str, target: str, **kwargs) -> BaseKernel:
return cls(name, target, cls.schematize())
@classmethod
def schematize(cls) -> str:
params = dict((forward_signature := signature(cls.forward, eval_str=True)).parameters)
_ = params.pop('self')
arguments = (f"{resolve_type(p.annotation)} {name}" for name, p in params.items())
out = resolve_return(forward_signature.return_annotation)
return f"({', '.join(arguments)}) -> {out}"
@property
def is_registered(self) -> bool:
return hasattr(getattr(torch.ops, NAMESPACE), self.name)
def operator(self, compiled: bool = False) -> Callable:
if compiled:
self.register()
return attrgetter(f'{NAMESPACE}.{self.name}')(torch.ops)
else:
return self.forward
def meta(self, *args, **kwargs) -> Any:
raise NotImplementedError(
f'abstract implementation `meta` in {self.__class__.__name__} required for '
f'registration'
)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError(
f'concrete implementation `forward` in {self.__class__.__name__} required for '
f'registration'
)
def register(self) -> None:
if not self.is_registered:
qualname = f'{NAMESPACE}::{self.name}'
# Define the library namespace and signature.
torch.library.define(qualname, self.schema)
# Register the abstract implementation i.e. with meta tensors.
torch.library.impl_abstract(qualname, func=self.meta)
# Register the concrete implementation.
torch.library.impl(qualname, self.target, func=self.forward)