@@ -41,12 +41,20 @@ def main(_):
4141 hvd .init ()
4242
4343 FLAGS = PARSER .parse_args ()
44- backends = [StdOutBackend (Verbosity .DEFAULT )]
4544
46- if FLAGS .log_dir :
47- backends += [JSONStreamBackend (Verbosity .DEFAULT , FLAGS .log_dir )]
45+ backends = []
46+
47+ if hvd .rank () == 0 :
48+ backends += [StdOutBackend (Verbosity .DEFAULT )]
49+
50+ if FLAGS .log_dir :
51+ backends += [JSONStreamBackend (Verbosity .DEFAULT , FLAGS .log_dir )]
4852
4953 DLLogger .init (backends = backends )
54+
55+ for key in vars (FLAGS ):
56+ DLLogger .log (step = "PARAMETER" , data = {str (key ): vars (FLAGS )[key ]})
57+
5058 os .environ ['CUDA_CACHE_DISABLE' ] = '0'
5159
5260 os .environ ['HOROVOD_GPU_ALLREDUCE' ] = 'NCCL'
@@ -65,9 +73,6 @@ def main(_):
6573 os .environ ['TF_AUTOTUNE_THRESHOLD' ] = '2'
6674 os .environ ['TF_DISABLE_NVTX_RANGES' ] = '1'
6775
68- if hvd .rank () == 0 :
69- DLLogger .log (step = tuple (), data = {"mixed_precision" : "ENABLED" if FLAGS .use_amp else "DISABLED" })
70-
7176 dataset = MSDDataset (json_path = os .path .join (FLAGS .data_dir , 'dataset.json' ),
7277 dst_size = FLAGS .input_shape ,
7378 seed = FLAGS .seed ,
@@ -85,6 +90,9 @@ def main(_):
8590 config .gpu_options .allow_growth = True
8691 config .gpu_options .visible_device_list = str (hvd .local_rank ())
8792
93+ if FLAGS .use_amp :
94+ config .graph_options .rewrite_options .auto_mixed_precision = 1
95+
8896 run_config = tf .estimator .RunConfig (
8997 save_summary_steps = None ,
9098 save_checkpoints_steps = None if FLAGS .benchmark else dataset .train_steps * FLAGS .train_epoch ,
@@ -112,29 +120,21 @@ def main(_):
112120 if hvd .rank () == 0 :
113121 train_hooks += [TrainHook (FLAGS .log_every , DLLogger )]
114122
115- DLLogger .log (step = tuple (), data = {"training" : "START" })
116-
117123 estimator .train (
118124 input_fn = lambda : dataset .train_fn (FLAGS .augment ),
119125 steps = steps ,
120126 hooks = train_hooks )
121127
122- DLLogger .log (step = tuple (), data = {"training" : "FINISHED" })
123-
124128 if 'evaluate' in FLAGS .exec_mode :
125129 if hvd .rank () == 0 :
126130 if FLAGS .train_split >= 1.0 :
127131 raise ValueError ("Missing argument: --train_split < 1.0" )
128132
129- DLLogger .log (step = tuple (), data = {"evaluating" : "START" })
130-
131133 result = estimator .evaluate (
132134 input_fn = dataset .eval_fn ,
133135 steps = dataset .eval_steps ,
134136 hooks = [])
135137
136- DLLogger .log (step = tuple (), data = {"evaluating" : "FINISH" })
137-
138138 DLLogger .log (step = tuple (), data = {'background_dice' : str (result ['background dice' ])})
139139 DLLogger .log (step = tuple (), data = {'anterior_dice' : str (result ['Anterior dice' ])})
140140 DLLogger .log (step = tuple (), data = {'posterior_dice' : str (result ['Posterior dice' ])})
@@ -163,3 +163,4 @@ def main(_):
163163
164164if __name__ == '__main__' :
165165 tf .compat .v1 .app .run ()
166+
0 commit comments