77import torch
88import pytorch_lightning as pl
99import numpy as np
10- # import cv2
10+ import cv2
1111import random
1212import math
1313from torchvision import transforms
@@ -33,9 +33,14 @@ def do_training(hparams, model_constructor):
3333
3434 hparams .sync_batchnorm = True
3535
36+ ttlogger = pl .loggers .TestTubeLogger (
37+ "checkpoints" , name = hparams .exp_name , version = hparams .version
38+ )
39+
3640 hparams .callbacks = make_checkpoint_callbacks (hparams .exp_name , hparams .version )
3741
38- hparams .logger = pl .loggers .TensorBoardLogger ("logs/" )
42+ wblogger = get_wandb_logger (hparams )
43+ hparams .logger = [wblogger , ttlogger ]
3944
4045 trainer = pl .Trainer .from_argparse_args (hparams )
4146 trainer .fit (model )
@@ -160,4 +165,204 @@ def set_resume_parameters(hparams):
160165 else :
161166 version = 0
162167
163- return hparams
168+ return hparams
169+
170+
171+ def get_wandb_logger (hparams ):
172+ exp_dir = f"checkpoints/{ hparams .exp_name } /version_{ hparams .version } /"
173+ id_file = f"{ exp_dir } /wandb_id"
174+
175+ if os .path .exists (id_file ):
176+ with open (id_file ) as f :
177+ hparams .wandb_id = f .read ()
178+ else :
179+ hparams .wandb_id = None
180+
181+ logger = pl .loggers .WandbLogger (
182+ save_dir = "checkpoints" ,
183+ project = hparams .project_name ,
184+ name = hparams .exp_name ,
185+ id = hparams .wandb_id ,
186+ )
187+
188+ if hparams .wandb_id is None :
189+ _ = logger .experiment
190+
191+ if not os .path .exists (exp_dir ):
192+ os .makedirs (exp_dir )
193+
194+ with open (id_file , "w" ) as f :
195+ f .write (logger .version )
196+
197+ return logger
198+
199+
200+ class Resize (object ):
201+ """Resize sample to given size (width, height)."""
202+
203+ def __init__ (
204+ self ,
205+ width ,
206+ height ,
207+ resize_target = True ,
208+ keep_aspect_ratio = False ,
209+ ensure_multiple_of = 1 ,
210+ resize_method = "lower_bound" ,
211+ image_interpolation_method = cv2 .INTER_AREA ,
212+ letter_box = False ,
213+ ):
214+ """Init.
215+
216+ Args:
217+ width (int): desired output width
218+ height (int): desired output height
219+ resize_target (bool, optional):
220+ True: Resize the full sample (image, mask, target).
221+ False: Resize image only.
222+ Defaults to True.
223+ keep_aspect_ratio (bool, optional):
224+ True: Keep the aspect ratio of the input sample.
225+ Output sample might not have the given width and height, and
226+ resize behaviour depends on the parameter 'resize_method'.
227+ Defaults to False.
228+ ensure_multiple_of (int, optional):
229+ Output width and height is constrained to be multiple of this parameter.
230+ Defaults to 1.
231+ resize_method (str, optional):
232+ "lower_bound": Output will be at least as large as the given size.
233+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
234+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
235+ Defaults to "lower_bound".
236+ """
237+ self .__width = width
238+ self .__height = height
239+
240+ self .__resize_target = resize_target
241+ self .__keep_aspect_ratio = keep_aspect_ratio
242+ self .__multiple_of = ensure_multiple_of
243+ self .__resize_method = resize_method
244+ self .__image_interpolation_method = image_interpolation_method
245+ self .__letter_box = letter_box
246+
247+ def constrain_to_multiple_of (self , x , min_val = 0 , max_val = None ):
248+ y = (np .round (x / self .__multiple_of ) * self .__multiple_of ).astype (int )
249+
250+ if max_val is not None and y > max_val :
251+ y = (np .floor (x / self .__multiple_of ) * self .__multiple_of ).astype (int )
252+
253+ if y < min_val :
254+ y = (np .ceil (x / self .__multiple_of ) * self .__multiple_of ).astype (int )
255+
256+ return y
257+
258+ def get_size (self , width , height ):
259+ # determine new height and width
260+ scale_height = self .__height / height
261+ scale_width = self .__width / width
262+
263+ if self .__keep_aspect_ratio :
264+ if self .__resize_method == "lower_bound" :
265+ # scale such that output size is lower bound
266+ if scale_width > scale_height :
267+ # fit width
268+ scale_height = scale_width
269+ else :
270+ # fit height
271+ scale_width = scale_height
272+ elif self .__resize_method == "upper_bound" :
273+ # scale such that output size is upper bound
274+ if scale_width < scale_height :
275+ # fit width
276+ scale_height = scale_width
277+ else :
278+ # fit height
279+ scale_width = scale_height
280+ elif self .__resize_method == "minimal" :
281+ # scale as least as possbile
282+ if abs (1 - scale_width ) < abs (1 - scale_height ):
283+ # fit width
284+ scale_height = scale_width
285+ else :
286+ # fit height
287+ scale_width = scale_height
288+ else :
289+ raise ValueError (
290+ f"resize_method { self .__resize_method } not implemented"
291+ )
292+
293+ if self .__resize_method == "lower_bound" :
294+ new_height = self .constrain_to_multiple_of (
295+ scale_height * height , min_val = self .__height
296+ )
297+ new_width = self .constrain_to_multiple_of (
298+ scale_width * width , min_val = self .__width
299+ )
300+ elif self .__resize_method == "upper_bound" :
301+ new_height = self .constrain_to_multiple_of (
302+ scale_height * height , max_val = self .__height
303+ )
304+ new_width = self .constrain_to_multiple_of (
305+ scale_width * width , max_val = self .__width
306+ )
307+ elif self .__resize_method == "minimal" :
308+ new_height = self .constrain_to_multiple_of (scale_height * height )
309+ new_width = self .constrain_to_multiple_of (scale_width * width )
310+ else :
311+ raise ValueError (f"resize_method { self .__resize_method } not implemented" )
312+
313+ return (new_width , new_height )
314+
315+ def make_letter_box (self , sample ):
316+ top = bottom = (self .__height - sample .shape [0 ]) // 2
317+ left = right = (self .__width - sample .shape [1 ]) // 2
318+ sample = cv2 .copyMakeBorder (
319+ sample , top , bottom , left , right , cv2 .BORDER_CONSTANT , None , 0
320+ )
321+ return sample
322+
323+ def __call__ (self , sample ):
324+ width , height = self .get_size (
325+ sample ["image" ].shape [1 ], sample ["image" ].shape [0 ]
326+ )
327+
328+ # resize sample
329+ sample ["image" ] = cv2 .resize (
330+ sample ["image" ],
331+ (width , height ),
332+ interpolation = self .__image_interpolation_method ,
333+ )
334+
335+ if self .__letter_box :
336+ sample ["image" ] = self .make_letter_box (sample ["image" ])
337+
338+ if self .__resize_target :
339+ if "disparity" in sample :
340+ sample ["disparity" ] = cv2 .resize (
341+ sample ["disparity" ],
342+ (width , height ),
343+ interpolation = cv2 .INTER_NEAREST ,
344+ )
345+
346+ if self .__letter_box :
347+ sample ["disparity" ] = self .make_letter_box (sample ["disparity" ])
348+
349+ if "depth" in sample :
350+ sample ["depth" ] = cv2 .resize (
351+ sample ["depth" ], (width , height ), interpolation = cv2 .INTER_NEAREST
352+ )
353+
354+ if self .__letter_box :
355+ sample ["depth" ] = self .make_letter_box (sample ["depth" ])
356+
357+ sample ["mask" ] = cv2 .resize (
358+ sample ["mask" ].astype (np .float32 ),
359+ (width , height ),
360+ interpolation = cv2 .INTER_NEAREST ,
361+ )
362+
363+ if self .__letter_box :
364+ sample ["mask" ] = self .make_letter_box (sample ["mask" ])
365+
366+ sample ["mask" ] = sample ["mask" ].astype (bool )
367+
368+ return sample
0 commit comments