Skip to content
This repository was archived by the owner on Dec 18, 2024. It is now read-only.

Commit 9d063b1

Browse files
authored
Update utils.py
1 parent f06427e commit 9d063b1

File tree

1 file changed

+208
-3
lines changed

1 file changed

+208
-3
lines changed

utils.py

Lines changed: 208 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
import pytorch_lightning as pl
99
import numpy as np
10-
#import cv2
10+
import cv2
1111
import random
1212
import math
1313
from torchvision import transforms
@@ -33,9 +33,14 @@ def do_training(hparams, model_constructor):
3333

3434
hparams.sync_batchnorm = True
3535

36+
ttlogger = pl.loggers.TestTubeLogger(
37+
"checkpoints", name=hparams.exp_name, version=hparams.version
38+
)
39+
3640
hparams.callbacks = make_checkpoint_callbacks(hparams.exp_name, hparams.version)
3741

38-
hparams.logger = pl.loggers.TensorBoardLogger("logs/")
42+
wblogger = get_wandb_logger(hparams)
43+
hparams.logger = [wblogger, ttlogger]
3944

4045
trainer = pl.Trainer.from_argparse_args(hparams)
4146
trainer.fit(model)
@@ -160,4 +165,204 @@ def set_resume_parameters(hparams):
160165
else:
161166
version = 0
162167

163-
return hparams
168+
return hparams
169+
170+
171+
def get_wandb_logger(hparams):
172+
exp_dir = f"checkpoints/{hparams.exp_name}/version_{hparams.version}/"
173+
id_file = f"{exp_dir}/wandb_id"
174+
175+
if os.path.exists(id_file):
176+
with open(id_file) as f:
177+
hparams.wandb_id = f.read()
178+
else:
179+
hparams.wandb_id = None
180+
181+
logger = pl.loggers.WandbLogger(
182+
save_dir="checkpoints",
183+
project=hparams.project_name,
184+
name=hparams.exp_name,
185+
id=hparams.wandb_id,
186+
)
187+
188+
if hparams.wandb_id is None:
189+
_ = logger.experiment
190+
191+
if not os.path.exists(exp_dir):
192+
os.makedirs(exp_dir)
193+
194+
with open(id_file, "w") as f:
195+
f.write(logger.version)
196+
197+
return logger
198+
199+
200+
class Resize(object):
201+
"""Resize sample to given size (width, height)."""
202+
203+
def __init__(
204+
self,
205+
width,
206+
height,
207+
resize_target=True,
208+
keep_aspect_ratio=False,
209+
ensure_multiple_of=1,
210+
resize_method="lower_bound",
211+
image_interpolation_method=cv2.INTER_AREA,
212+
letter_box=False,
213+
):
214+
"""Init.
215+
216+
Args:
217+
width (int): desired output width
218+
height (int): desired output height
219+
resize_target (bool, optional):
220+
True: Resize the full sample (image, mask, target).
221+
False: Resize image only.
222+
Defaults to True.
223+
keep_aspect_ratio (bool, optional):
224+
True: Keep the aspect ratio of the input sample.
225+
Output sample might not have the given width and height, and
226+
resize behaviour depends on the parameter 'resize_method'.
227+
Defaults to False.
228+
ensure_multiple_of (int, optional):
229+
Output width and height is constrained to be multiple of this parameter.
230+
Defaults to 1.
231+
resize_method (str, optional):
232+
"lower_bound": Output will be at least as large as the given size.
233+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
234+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
235+
Defaults to "lower_bound".
236+
"""
237+
self.__width = width
238+
self.__height = height
239+
240+
self.__resize_target = resize_target
241+
self.__keep_aspect_ratio = keep_aspect_ratio
242+
self.__multiple_of = ensure_multiple_of
243+
self.__resize_method = resize_method
244+
self.__image_interpolation_method = image_interpolation_method
245+
self.__letter_box = letter_box
246+
247+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
248+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
249+
250+
if max_val is not None and y > max_val:
251+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
252+
253+
if y < min_val:
254+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
255+
256+
return y
257+
258+
def get_size(self, width, height):
259+
# determine new height and width
260+
scale_height = self.__height / height
261+
scale_width = self.__width / width
262+
263+
if self.__keep_aspect_ratio:
264+
if self.__resize_method == "lower_bound":
265+
# scale such that output size is lower bound
266+
if scale_width > scale_height:
267+
# fit width
268+
scale_height = scale_width
269+
else:
270+
# fit height
271+
scale_width = scale_height
272+
elif self.__resize_method == "upper_bound":
273+
# scale such that output size is upper bound
274+
if scale_width < scale_height:
275+
# fit width
276+
scale_height = scale_width
277+
else:
278+
# fit height
279+
scale_width = scale_height
280+
elif self.__resize_method == "minimal":
281+
# scale as least as possbile
282+
if abs(1 - scale_width) < abs(1 - scale_height):
283+
# fit width
284+
scale_height = scale_width
285+
else:
286+
# fit height
287+
scale_width = scale_height
288+
else:
289+
raise ValueError(
290+
f"resize_method {self.__resize_method} not implemented"
291+
)
292+
293+
if self.__resize_method == "lower_bound":
294+
new_height = self.constrain_to_multiple_of(
295+
scale_height * height, min_val=self.__height
296+
)
297+
new_width = self.constrain_to_multiple_of(
298+
scale_width * width, min_val=self.__width
299+
)
300+
elif self.__resize_method == "upper_bound":
301+
new_height = self.constrain_to_multiple_of(
302+
scale_height * height, max_val=self.__height
303+
)
304+
new_width = self.constrain_to_multiple_of(
305+
scale_width * width, max_val=self.__width
306+
)
307+
elif self.__resize_method == "minimal":
308+
new_height = self.constrain_to_multiple_of(scale_height * height)
309+
new_width = self.constrain_to_multiple_of(scale_width * width)
310+
else:
311+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
312+
313+
return (new_width, new_height)
314+
315+
def make_letter_box(self, sample):
316+
top = bottom = (self.__height - sample.shape[0]) // 2
317+
left = right = (self.__width - sample.shape[1]) // 2
318+
sample = cv2.copyMakeBorder(
319+
sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0
320+
)
321+
return sample
322+
323+
def __call__(self, sample):
324+
width, height = self.get_size(
325+
sample["image"].shape[1], sample["image"].shape[0]
326+
)
327+
328+
# resize sample
329+
sample["image"] = cv2.resize(
330+
sample["image"],
331+
(width, height),
332+
interpolation=self.__image_interpolation_method,
333+
)
334+
335+
if self.__letter_box:
336+
sample["image"] = self.make_letter_box(sample["image"])
337+
338+
if self.__resize_target:
339+
if "disparity" in sample:
340+
sample["disparity"] = cv2.resize(
341+
sample["disparity"],
342+
(width, height),
343+
interpolation=cv2.INTER_NEAREST,
344+
)
345+
346+
if self.__letter_box:
347+
sample["disparity"] = self.make_letter_box(sample["disparity"])
348+
349+
if "depth" in sample:
350+
sample["depth"] = cv2.resize(
351+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
352+
)
353+
354+
if self.__letter_box:
355+
sample["depth"] = self.make_letter_box(sample["depth"])
356+
357+
sample["mask"] = cv2.resize(
358+
sample["mask"].astype(np.float32),
359+
(width, height),
360+
interpolation=cv2.INTER_NEAREST,
361+
)
362+
363+
if self.__letter_box:
364+
sample["mask"] = self.make_letter_box(sample["mask"])
365+
366+
sample["mask"] = sample["mask"].astype(bool)
367+
368+
return sample

0 commit comments

Comments
 (0)