-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
61 lines (56 loc) · 1.71 KB
/
__init__.py
File metadata and controls
61 lines (56 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Optional
import numpy as np
from onnx import TensorProto
from onnx.numpy_helper import from_array as onnx_from_array
try:
from onnx.reference.ops.op_cast import (
bfloat16,
float8e4m3fn,
float8e4m3fnuz,
float8e5m2,
float8e5m2fnuz,
)
except ImportError:
bfloat16 = None
try:
from onnx.reference.op_run import to_array_extended
except ImportError:
from onnx.numpy_helper import to_array as to_array_extended
from .evaluator import ExtendedReferenceEvaluator
from .evaluator_yield import (
DistanceExecution,
ResultExecution,
ResultType,
YieldEvaluator,
compare_onnx_execution,
)
def from_array_extended(tensor: np.ndarray, name: Optional[str] = None) -> TensorProto:
    """
    Converts an array into a TensorProto, supporting the custom dtypes used by
    ONNX's reference implementation for element types numpy lacks natively.

    When ``tensor`` has one of the dtypes ``float8e4m3fn``, ``float8e4m3fnuz``,
    ``float8e5m2``, ``float8e5m2fnuz`` or ``bfloat16`` (imported from
    ``onnx.reference.ops.op_cast``), it is first converted to its unsigned
    integer storage type (``np.uint8`` for the float 8 types, ``np.uint16`` for
    bfloat16). The resulting proto's ``data_type`` is then overwritten with the
    matching ``TensorProto`` element type. Any other dtype is forwarded
    unchanged to ``onnx.numpy_helper.from_array``.

    :param tensor: numpy array
    :param name: name of the tensor, forwarded to ``from_array``
    :return: TensorProto
    """
    if bfloat16 is None:
        # The custom dtypes could not be imported from onnx, so there is
        # nothing special to map: defer entirely to onnx's own converter.
        return onnx_from_array(tensor, name)

    # Ordered table of (custom dtype, expected field name, ONNX element type,
    # storage dtype). The order matches the original if/elif chain; the first
    # matching entry wins.
    special_types = (
        (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN, np.uint8),
        (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ, np.uint8),
        (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2, np.uint8),
        (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ, np.uint8),
        (bfloat16, "bfloat16", TensorProto.BFLOAT16, np.uint16),
    )

    dt = tensor.dtype
    for custom_dtype, field_name, onnx_type, storage_dtype in special_types:
        # The field name is checked in addition to dtype equality, presumably
        # so that custom dtypes sharing the same storage type are told apart
        # by name. ``descr`` is only read once equality holds, as before.
        if dt == custom_dtype and dt.descr[0][0] == field_name:
            proto = onnx_from_array(tensor.astype(storage_dtype), name)
            # from_array labelled the proto with the storage type (uint8 or
            # uint16); relabel it with the real ONNX element type.
            proto.data_type = onnx_type
            return proto

    return onnx_from_array(tensor, name)