Skip to content

Commit 171fa76

Browse files
committed
Enable rewrite for inference
Signed-off-by: Pablo Ribalta <pribalta@nvidia.com>
1 parent 8a86378 commit 171fa76

File tree

1 file changed

+15
-14
lines changed
  • TensorFlow/Segmentation/VNet

1 file changed

+15
-14
lines changed

TensorFlow/Segmentation/VNet/main.py

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -41,12 +41,20 @@ def main(_):
4141
hvd.init()
4242

4343
FLAGS = PARSER.parse_args()
44-
backends = [StdOutBackend(Verbosity.DEFAULT)]
4544

46-
if FLAGS.log_dir:
47-
backends += [JSONStreamBackend(Verbosity.DEFAULT, FLAGS.log_dir)]
45+
backends = []
46+
47+
if hvd.rank() == 0:
48+
backends += [StdOutBackend(Verbosity.DEFAULT)]
49+
50+
if FLAGS.log_dir:
51+
backends += [JSONStreamBackend(Verbosity.DEFAULT, FLAGS.log_dir)]
4852

4953
DLLogger.init(backends=backends)
54+
55+
for key in vars(FLAGS):
56+
DLLogger.log(step="PARAMETER", data={str(key): vars(FLAGS)[key]})
57+
5058
os.environ['CUDA_CACHE_DISABLE'] = '0'
5159

5260
os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'
@@ -65,9 +73,6 @@ def main(_):
6573
os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'
6674
os.environ['TF_DISABLE_NVTX_RANGES'] = '1'
6775

68-
if hvd.rank() == 0:
69-
DLLogger.log(step=tuple(), data={"mixed_precision": "ENABLED" if FLAGS.use_amp else "DISABLED"})
70-
7176
dataset = MSDDataset(json_path=os.path.join(FLAGS.data_dir, 'dataset.json'),
7277
dst_size=FLAGS.input_shape,
7378
seed=FLAGS.seed,
@@ -85,6 +90,9 @@ def main(_):
8590
config.gpu_options.allow_growth = True
8691
config.gpu_options.visible_device_list = str(hvd.local_rank())
8792

93+
if FLAGS.use_amp:
94+
config.graph_options.rewrite_options.auto_mixed_precision = 1
95+
8896
run_config = tf.estimator.RunConfig(
8997
save_summary_steps=None,
9098
save_checkpoints_steps=None if FLAGS.benchmark else dataset.train_steps * FLAGS.train_epoch,
@@ -112,29 +120,21 @@ def main(_):
112120
if hvd.rank() == 0:
113121
train_hooks += [TrainHook(FLAGS.log_every, DLLogger)]
114122

115-
DLLogger.log(step=tuple(), data={"training": "START"})
116-
117123
estimator.train(
118124
input_fn=lambda: dataset.train_fn(FLAGS.augment),
119125
steps=steps,
120126
hooks=train_hooks)
121127

122-
DLLogger.log(step=tuple(), data={"training": "FINISHED"})
123-
124128
if 'evaluate' in FLAGS.exec_mode:
125129
if hvd.rank() == 0:
126130
if FLAGS.train_split >= 1.0:
127131
raise ValueError("Missing argument: --train_split < 1.0")
128132

129-
DLLogger.log(step=tuple(), data={"evaluating": "START"})
130-
131133
result = estimator.evaluate(
132134
input_fn=dataset.eval_fn,
133135
steps=dataset.eval_steps,
134136
hooks=[])
135137

136-
DLLogger.log(step=tuple(), data={"evaluating": "FINISH"})
137-
138138
DLLogger.log(step=tuple(), data={'background_dice': str(result['background dice'])})
139139
DLLogger.log(step=tuple(), data={'anterior_dice': str(result['Anterior dice'])})
140140
DLLogger.log(step=tuple(), data={'posterior_dice': str(result['Posterior dice'])})
@@ -163,3 +163,4 @@ def main(_):
163163

164164
if __name__ == '__main__':
165165
tf.compat.v1.app.run()
166+

0 commit comments

Comments
 (0)