forked from docarray/docarray
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimage_bytes.py
More file actions
145 lines (109 loc) · 4.29 KB
/
image_bytes.py
File metadata and controls
145 lines (109 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from io import BytesIO
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
import numpy as np
from pydantic import parse_obj_as
from pydantic.validators import bytes_validator
from docarray.typing.abstract_type import AbstractType
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.image.image_ndarray import ImageNdArray
from docarray.utils._internal.misc import import_library
if TYPE_CHECKING:
    # Imported only for static type checking. At runtime, PIL is imported
    # lazily inside `ImageBytes.load_pil` and `NodeProto` inside
    # `ImageBytes._to_node_protobuf`.
    from PIL import Image as PILImage
    from pydantic.fields import BaseConfig, ModelField

    from docarray.proto import NodeProto

# Type variable bound to ImageBytes so that classmethods annotated with
# `cls: Type[T]` are typed as returning the concrete (sub)class they are
# called on.
T = TypeVar('T', bound='ImageBytes')
@_register_proto(proto_type_name='image_bytes')
class ImageBytes(bytes, AbstractType):
    """
    Bytes that store an encoded image and that can be loaded into an image
    tensor (`load`) or a Pillow image (`load_pil`).
    """

    @classmethod
    def validate(
        cls: Type[T],
        value: Any,
        field: 'ModelField',
        config: 'BaseConfig',
    ) -> T:
        """
        Pydantic validator: coerce `value` with pydantic's `bytes_validator`
        and wrap the result in this class.

        :param value: the raw value to validate.
        :param field: the pydantic field being validated (not used here).
        :param config: the pydantic model config (not used here).
        :return: an instance of this class holding the validated bytes.
        """
        value = bytes_validator(value)
        return cls(value)

    @classmethod
    def from_protobuf(cls: Type[T], pb_msg: T) -> T:
        """
        Build an instance from a protobuf payload by running it through
        pydantic validation (`parse_obj_as`).

        :param pb_msg: the payload read from the protobuf message.
        :return: an instance of this class.
        """
        return parse_obj_as(cls, pb_msg)

    def _to_node_protobuf(self: T) -> 'NodeProto':
        """
        Serialize into a `NodeProto`, storing the raw bytes in its `blob`
        field and tagging it with `self._proto_type_name` (presumably set by
        the `_register_proto` decorator).

        :return: the `NodeProto` wrapping these bytes.
        """
        # Local import: keeps the proto module out of this module's import-time
        # dependencies (it is only imported under TYPE_CHECKING at file level).
        from docarray.proto import NodeProto

        return NodeProto(blob=self, type=self._proto_type_name)

    def load_pil(
        self,
    ) -> 'PILImage.Image':
        """
        Load the image from the bytes into a `PIL.Image.Image` instance

        ---

        ```python
        from io import BytesIO

        from PIL import Image
        from pydantic import parse_obj_as

        from docarray.typing import ImageBytes

        buffer = BytesIO()
        Image.new('RGB', (32, 16)).save(buffer, format='PNG')
        img_bytes = parse_obj_as(ImageBytes, buffer.getvalue())

        img = img_bytes.load_pil()
        assert isinstance(img, Image.Image)
        assert img.size == (32, 16)
        ```

        ---

        :return: a Pillow image. `PIL.Image.open` is lazy, so decoding errors
            may only surface when the pixel data is first accessed.
        """
        # Fail with an informative error (via import_library) when Pillow is
        # not installed; the module-level PIL import is TYPE_CHECKING-only.
        PIL = import_library('PIL', raise_error=True)  # noqa: F841
        from PIL import Image as PILImage

        return PILImage.open(BytesIO(self))

    def load(
        self,
        width: Optional[int] = None,
        height: Optional[int] = None,
        axis_layout: Tuple[str, str, str] = ('H', 'W', 'C'),
    ) -> ImageNdArray:
        """
        Load the image from the [`ImageBytes`][docarray.typing.ImageBytes] into an
        [`ImageNdArray`][docarray.typing.ImageNdArray].

        ---

        ```python
        from io import BytesIO

        from PIL import Image
        from pydantic import parse_obj_as

        from docarray.typing import ImageBytes, ImageNdArray

        buffer = BytesIO()
        Image.new('RGB', (32, 16)).save(buffer, format='PNG')
        img_bytes = parse_obj_as(ImageBytes, buffer.getvalue())

        img_tensor = img_bytes.load()
        assert isinstance(img_tensor, ImageNdArray)
        assert img_tensor.shape == (16, 32, 3)

        img_tensor = img_bytes.load(height=224, width=224)
        assert img_tensor.shape == (224, 224, 3)

        layout = ('C', 'W', 'H')
        img_tensor = img_bytes.load(height=100, width=200, axis_layout=layout)
        assert img_tensor.shape == (3, 200, 100)
        ```

        ---

        :param width: width of the image tensor. If only one of `width` /
            `height` is given, the other dimension keeps the image's
            original size.
        :param height: height of the image tensor.
        :param axis_layout: ordering of the different image axes.
            'H' = height, 'W' = width, 'C' = color channel
        :return: [`ImageNdArray`][docarray.typing.ImageNdArray] representing the image as RGB values
        """
        raw_img = self.load_pil()

        # Resize only when at least one target dimension is requested; a
        # missing (or 0) dimension falls back to the image's own size.
        if width or height:
            new_width = width or raw_img.width
            new_height = height or raw_img.height
            raw_img = raw_img.resize((new_width, new_height))

        try:
            # Normalize to 3-channel RGB so the result is always (H, W, 3).
            tensor = np.array(raw_img.convert('RGB'))
        except Exception:
            # Best effort: keep the image's native mode when it cannot be
            # converted to RGB. NOTE(review): if this yields a non-3-D array,
            # the 3-axis permutation in `_move_channel_axis` will fail.
            tensor = np.array(raw_img)

        img = self._move_channel_axis(tensor, axis_layout=axis_layout)
        return parse_obj_as(ImageNdArray, img)

    @staticmethod
    def _move_channel_axis(
        tensor: np.ndarray, axis_layout: Tuple[str, str, str] = ('H', 'W', 'C')
    ) -> np.ndarray:
        """
        Reorder the axes of an (H, W, C) array into the order given by
        `axis_layout`.

        :param tensor: array laid out as (height, width, channel).
        :param axis_layout: target axis order, a permutation of 'H', 'W', 'C'.
        :return: the transposed array (a view of `tensor`).
        :raises KeyError: if `axis_layout` contains a letter other than
            'H', 'W' or 'C'. numpy's `transpose` raises if an axis is
            repeated or the layout length does not match `tensor.ndim`.
        """
        channel_to_offset = {'H': 0, 'W': 1, 'C': 2}
        permutation = tuple(channel_to_offset[axis] for axis in axis_layout)
        return np.transpose(tensor, permutation)