forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvideo_processor.py
More file actions
180 lines (152 loc) · 8.16 KB
/
video_processor.py
File metadata and controls
180 lines (152 loc) · 8.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
class VideoProcessor(VaeImageProcessor):
    r"""Simple video processor.

    Videos are processed one at a time through the image-level `preprocess` / `postprocess` methods inherited from
    `VaeImageProcessor`, and the channel and frame axes are swapped so that batched videos use the channels-first
    layout `(batch_size, num_channels, num_frames, height, width)`.
    """

    def preprocess_video(self, video, height: int | None = None, width: int | None = None, **kwargs) -> torch.Tensor:
        r"""
        Preprocesses input video(s). Keyword arguments will be forwarded to `VaeImageProcessor.preprocess`.

        Args:
            video (`list[PIL.Image]`, `list[list[PIL.Image]]`, `torch.Tensor`, `np.array`, `list[torch.Tensor]`, `list[np.array]`):
                The input video. It can be one of the following:
                    * list of the PIL images.
                    * list of list of PIL images.
                    * 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`).
                    * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
                    * list of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height,
                      width)`).
                    * list of 4D NumPy arrays (expected shape for each array `(num_frames, height, width,
                      num_channels)`).
                    * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width,
                      num_channels)`.
                    * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height,
                      width)`.
                A list of 5D arrays/tensors is still accepted but deprecated; it is concatenated along the batch
                dimension with a `FutureWarning`.
            height (`int`, *optional*, defaults to `None`):
                The height in preprocessed frames of the video. Forwarded to `preprocess`; if `None`, the default
                height is used (per the original docs, via `get_default_height_width()`).
            width (`int`, *optional*, defaults to `None`):
                The width in preprocessed frames of the video. Forwarded to `preprocess`; if `None`, the default
                width is used (per the original docs, via `get_default_height_width()`).

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`:
                A 5D tensor holding the batched channels-first video(s).

        Raises:
            `ValueError`: If `video` is an empty list or is not in one of the supported formats.
        """
        # Every branch below inspects `video[0]`; reject an empty list with a clear message instead of an IndexError.
        if isinstance(video, list) and len(video) == 0:
            raise ValueError("Input is empty. Please pass at least one video or frame.")

        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d np.ndarray is deprecated. "
                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
                FutureWarning,
            )
            video = np.concatenate(video, axis=0)
        if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d torch.Tensor is deprecated. "
                "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
                FutureWarning,
            )
            video = torch.cat(video, dim=0)

        # Normalize the input to a list of videos:
        # - a batch of videos (5d torch.Tensor or np.ndarray) is split into a list of 4d videos;
        # - a single video (a list of frames, or anything `is_valid_image_imagelist` accepts) is wrapped in a list;
        # - a list whose first element is itself a valid video is used as-is.
        if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
            video = list(video)
        elif (isinstance(video, list) and is_valid_image(video[0])) or is_valid_image_imagelist(video):
            video = [video]
        elif not (isinstance(video, list) and is_valid_image_imagelist(video[0])):
            raise ValueError(
                "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
            )

        # Each video is preprocessed as a batch of frames, then the videos are stacked along a new batch axis.
        video = torch.stack([self.preprocess(clip, height=height, width=width, **kwargs) for clip in video], dim=0)
        # move the number of channels before the number of frames.
        video = video.permute(0, 2, 1, 3, 4)
        return video

    def postprocess_video(
        self, video: torch.Tensor, output_type: str = "np", **kwargs
    ) -> np.ndarray | torch.Tensor | list[PIL.Image.Image]:
        r"""
        Converts a video tensor to a list of frames for export. Keyword arguments will be forwarded to
        `VaeImageProcessor.postprocess`.

        Args:
            video (`torch.Tensor`):
                The batched video tensor. Indexing the first axis must yield a 4D tensor whose first two axes are
                (channels, frames), i.e. the `(batch_size, num_channels, num_frames, height, width)` layout returned
                by `preprocess_video`.
            output_type (`str`, defaults to `"np"`):
                Output type of the postprocessed `video` tensor. One of `"np"`, `"pt"` or `"pil"`.

        Returns:
            `np.ndarray`, `torch.Tensor` or `list`:
                For `"np"` / `"pt"`, the per-video `postprocess` outputs stacked along a new leading batch axis. For
                `"pil"`, a list holding one `postprocess` output per video.

        Raises:
            `ValueError`: If `output_type` is not one of `"np"`, `"pt"` or `"pil"`.
        """
        # Validate up front so an unsupported type fails before any per-video postprocessing is done.
        if output_type not in ("np", "pt", "pil"):
            raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

        outputs = []
        for batch_idx in range(video.shape[0]):
            # Swap (channels, frames) -> (frames, channels) so the frames act as the image batch for `postprocess`.
            batch_vid = video[batch_idx].permute(1, 0, 2, 3)
            outputs.append(self.postprocess(batch_vid, output_type, **kwargs))

        if output_type == "np":
            return np.stack(outputs)
        if output_type == "pt":
            return torch.stack(outputs)
        return outputs

    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> tuple[int, int]:
        r"""
        Returns the binned height and width based on the aspect ratio.

        The bin whose key is numerically closest to `height / width` is selected.

        Args:
            height (`int`): The height of the image.
            width (`int`): The width of the image.
            ratios (`dict`):
                A dictionary where keys are aspect ratios (height / width; any value `float()` accepts, e.g. numeric
                strings) and values are `(height, width)` pairs.

        Returns:
            `tuple[int, int]`: The closest binned height and width, converted to `int`.
        """
        ar = float(height / width)
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
        default_hw = ratios[closest_ratio]
        return int(default_hw[0]), int(default_hw[1])

    @staticmethod
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        r"""
        Resizes and crops a tensor of videos to the specified dimensions.

        The frames are scaled (bilinear) by the larger of the two target/original ratios so that both sides cover the
        target size, then center-cropped to exactly `(new_height, new_width)`. If the spatial size already matches,
        `samples` is returned unchanged.

        Args:
            samples (`torch.Tensor`):
                A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the
                number of frames, H is the height, and W is the width.
            new_width (`int`): The desired width of the output videos.
            new_height (`int`): The desired height of the output videos.

        Returns:
            `torch.Tensor`: A tensor of shape (N, C, T, new_height, new_width) containing the resized and cropped
            videos.
        """
        orig_height, orig_width = samples.shape[3], samples.shape[4]

        # Check if resizing is needed
        if orig_height != new_height or orig_width != new_width:
            ratio = max(new_height / orig_height, new_width / orig_width)
            # `orig * (new / orig)` can round to just below `new` in floating point (e.g. 49 * (1 / 49) < 1), and
            # `int()` truncates; clamp so the resized frame is never smaller than the crop window.
            resized_width = max(new_width, int(orig_width * ratio))
            resized_height = max(new_height, int(orig_height * ratio))

            # Reshape to (N*T, C, H, W) for interpolation
            n, c, t, h, w = samples.shape
            samples = samples.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w)

            # Resize
            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )

            # Center Crop
            start_x = (resized_width - new_width) // 2
            end_x = start_x + new_width
            start_y = (resized_height - new_height) // 2
            end_y = start_y + new_height
            samples = samples[:, :, start_y:end_y, start_x:end_x]

            # Reshape back to (N, C, T, H, W)
            samples = samples.reshape(n, t, c, new_height, new_width).permute(0, 2, 1, 3, 4)

        return samples