Skip to content

Commit 192ab46

Browse files
authored
hparams: allow setting trial ID (#2442)
Summary:
Resolves #2440. See #1998 for discussion.

Test Plan:
The hparams demo still does not specify trial IDs (intentionally, as
this is the usual path). But apply the following patch—

```diff
diff --git a/tensorboard/plugins/hparams/hparams_demo.py b/tensorboard/plugins/hparams/hparams_demo.py
index ac4e762..38b2b122 100644
--- a/tensorboard/plugins/hparams/hparams_demo.py
+++ b/tensorboard/plugins/hparams/hparams_demo.py
@@ -160,7 +160,7 @@ def model_fn(hparams, seed):
   return model


-def run(data, base_logdir, session_id, hparams):
+def run(data, base_logdir, session_id, trial_id, hparams):
   """Run a training/validation session.

   Flags must have been parsed for this function to behave.
@@ -179,7 +179,7 @@ def run(data, base_logdir, session_id, hparams):
       update_freq=flags.FLAGS.summary_freq,
       profile_batch=0,  # workaround for issue #2084
   )
-  hparams_callback = hp.KerasCallback(logdir, hparams)
+  hparams_callback = hp.KerasCallback(logdir, hparams, trial_id=trial_id)
   ((x_train, y_train), (x_test, y_test)) = data
   result = model.fit(
       x=x_train,
@@ -235,6 +235,7 @@ def run_all(logdir, verbose=False):
           data=data,
           base_logdir=logdir,
           session_id=session_id,
+          trial_id="trial-%d" % group_index,
           hparams=hparams,
       )
```

—and then run `//tensorboard/plugins/hparams:hparams_demo`, and observe
that the HParams dashboard renders a “Trial ID” column with the
specified IDs:

![Screenshot of new version of HParams dashboard][1]

[1]: https://user-images.githubusercontent.com/4317806/61491024-1fb01280-a963-11e9-8a47-35e0a01f3691.png

wchargin-branch: hparams-trial-id
1 parent ae5353a commit 192ab46

File tree

5 files changed

+59
-20
lines changed

5 files changed

+59
-20
lines changed

tensorboard/plugins/hparams/keras.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ class Callback(tf.keras.callbacks.Callback):
3636
NOTE: This callback only works in TensorFlow eager mode.
3737
"""
3838

39-
def __init__(self, writer, hparams):
39+
def __init__(self, writer, hparams, trial_id=None):
4040
"""Create a callback for logging hyperparameters to TensorBoard.
4141
4242
As with the standard `tf.keras.callbacks.TensorBoard` class, each
@@ -51,6 +51,9 @@ def __init__(self, writer, hparams):
5151
in an experiment, or the `HParam` objects themselves. Values
5252
should be Python `bool`, `int`, `float`, or `string` values,
5353
depending on the type of the hyperparameter.
54+
trial_id: An optional `str` ID for the set of hyperparameter
55+
values used in this trial. Defaults to a hash of the
56+
hyperparameters.
5457
5558
Raises:
5659
ValueError: If two entries in `hparams` share the same
@@ -60,7 +63,8 @@ def __init__(self, writer, hparams):
6063
# timestamp is correct. But create a "dry-run" first to fail fast in
6164
# case the `hparams` are invalid.
6265
self._hparams = dict(hparams)
63-
summary_v2.hparams_pb(self._hparams)
66+
self._trial_id = trial_id
67+
summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id)
6468
if writer is None:
6569
raise TypeError("writer must be a `SummaryWriter` or `str`, not None")
6670
elif isinstance(writer, str):
@@ -82,7 +86,7 @@ def _get_writer(self):
8286
def on_train_begin(self, logs=None):
8387
del logs # unused
8488
with self._get_writer().as_default():
85-
summary_v2.hparams(self._hparams)
89+
summary_v2.hparams(self._hparams, trial_id=self._trial_id)
8690

8791
def on_train_end(self, logs=None):
8892
del logs # unused

tensorboard/plugins/hparams/keras_test.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,8 @@ def _initialize_model(self, writer):
5656
tf.keras.layers.Dense(1, activation="sigmoid"),
5757
])
5858
self.model.compile(loss="mse", optimizer=self.hparams["optimizer"])
59-
self.callback = keras.Callback(writer, self.hparams)
59+
self.trial_id = "my_trial"
60+
self.callback = keras.Callback(writer, self.hparams, trial_id=self.trial_id)
6061

6162
def test_eager(self):
6263
def mock_time():
@@ -99,13 +100,11 @@ def mock_time():
99100
start_pb.start_time_secs = 1234.5
100101
end_pb.end_time_secs = 6789.0
101102

102-
start_pb.group_name = "do_not_care"
103-
104103
expected_start_pb = plugin_data_pb2.SessionStartInfo()
105104
text_format.Merge(
106105
"""
107106
start_time_secs: 1234.5
108-
group_name: "do_not_care"
107+
group_name: "my_trial"
109108
hparams {
110109
key: "optimizer"
111110
value {
@@ -186,6 +185,11 @@ def test_duplicate_hparam_names_from_two_objects(self):
186185
self, ValueError, "multiple values specified for hparam 'foo'"):
187186
keras.Callback(self.get_temp_dir(), hparams)
188187

188+
def test_invalid_trial_id(self):
189+
with six.assertRaisesRegex(
190+
self, TypeError, "`trial_id` should be a `str`, but got: 12"):
191+
keras.Callback(self.get_temp_dir(), {}, trial_id=12)
192+
189193

190194
if __name__ == "__main__":
191195
tf.test.main()

tensorboard/plugins/hparams/summary_v2.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
from tensorboard.plugins.hparams import plugin_data_pb2
3737

3838

39-
def hparams(hparams, start_time_secs=None):
39+
def hparams(hparams, trial_id=None, start_time_secs=None):
4040
# NOTE: Keep docs in sync with `hparams_pb` below.
4141
"""Write hyperparameter values for a single trial.
4242
@@ -46,6 +46,8 @@ def hparams(hparams, start_time_secs=None):
4646
experiment, or the `HParam` objects themselves. Values should be
4747
Python `bool`, `int`, `float`, or `string` values, depending on
4848
the type of the hyperparameter.
49+
trial_id: An optional `str` ID for the set of hyperparameter values
50+
used in this trial. Defaults to a hash of the hyperparameters.
4951
start_time_secs: The time that this trial started training, as
5052
seconds since epoch. Defaults to the current time.
5153
@@ -55,12 +57,13 @@ def hparams(hparams, start_time_secs=None):
5557
"""
5658
pb = hparams_pb(
5759
hparams=hparams,
60+
trial_id=trial_id,
5861
start_time_secs=start_time_secs,
5962
)
6063
return _write_summary("hparams", pb)
6164

6265

63-
def hparams_pb(hparams, start_time_secs=None):
66+
def hparams_pb(hparams, trial_id=None, start_time_secs=None):
6467
# NOTE: Keep docs in sync with `hparams` above.
6568
"""Create a summary encoding hyperparameter values for a single trial.
6669
@@ -70,6 +73,8 @@ def hparams_pb(hparams, start_time_secs=None):
7073
experiment, or the `HParam` objects themselves. Values should be
7174
Python `bool`, `int`, `float`, or `string` values, depending on
7275
the type of the hyperparameter.
76+
trial_id: An optional `str` ID for the set of hyperparameter values
77+
used in this trial. Defaults to a hash of the hyperparameters.
7378
start_time_secs: The time that this trial started training, as
7479
seconds since epoch. Defaults to the current time.
7580
@@ -79,7 +84,7 @@ def hparams_pb(hparams, start_time_secs=None):
7984
if start_time_secs is None:
8085
start_time_secs = time.time()
8186
hparams = _normalize_hparams(hparams)
82-
group_name = _derive_session_group_name(hparams)
87+
group_name = _derive_session_group_name(trial_id, hparams)
8388

8489
session_start_info = plugin_data_pb2.SessionStartInfo(
8590
group_name=group_name,
@@ -199,7 +204,11 @@ def _normalize_hparams(hparams):
199204
return result
200205

201206

202-
def _derive_session_group_name(hparams):
207+
def _derive_session_group_name(trial_id, hparams):
208+
if trial_id is not None:
209+
if not isinstance(trial_id, six.string_types):
210+
raise TypeError("`trial_id` should be a `str`, but got: %r" % (trial_id,))
211+
return trial_id
203212
# Use `json.dumps` rather than `str` to ensure invariance under string
204213
# type (incl. across Python versions) and dict iteration order.
205214
jparams = json.dumps(hparams, sort_keys=True, separators=(",", ":"))

tensorboard/plugins/hparams/summary_v2_test.py

Lines changed: 30 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def setUp(self):
8282
"dropout": 0.3,
8383
}
8484
self.start_time_secs = 123.45
85-
self.group_name = "big_sha"
85+
self.trial_id = "psl27"
8686

8787
self.expected_session_start_pb = plugin_data_pb2.SessionStartInfo()
8888
text_format.Merge(
@@ -93,13 +93,13 @@ def setUp(self):
9393
hparams { key: "who_knows_what" value { string_value: "???" } }
9494
hparams { key: "magic" value { bool_value: true } }
9595
hparams { key: "dropout" value { number_value: 0.3 } }
96-
group_name: "big_sha" # we'll ignore this field when asserting equality
9796
""",
9897
self.expected_session_start_pb,
9998
)
99+
self.expected_session_start_pb.group_name = self.trial_id
100100
self.expected_session_start_pb.start_time_secs = self.start_time_secs
101101

102-
def _check_summary(self, summary_pb):
102+
def _check_summary(self, summary_pb, check_group_name=False):
103103
"""Test that a summary contains exactly the expected hparams PB."""
104104
values = summary_pb.value
105105
self.assertEqual(len(values), 1, values)
@@ -110,18 +110,27 @@ def _check_summary(self, summary_pb):
110110
)
111111
plugin_content = actual_value.metadata.plugin_data.content
112112
info_pb = metadata.parse_session_start_info_plugin_data(plugin_content)
113-
# Ignore the `group_name` field; its properties are checked separately.
114-
info_pb.group_name = self.expected_session_start_pb.group_name
113+
# Usually ignore the `group_name` field; its properties are checked
114+
# separately.
115+
if not check_group_name:
116+
info_pb.group_name = self.expected_session_start_pb.group_name
115117
self.assertEqual(info_pb, self.expected_session_start_pb)
116118

117-
def _check_logdir(self, logdir):
119+
def _check_logdir(self, logdir, check_group_name=False):
118120
"""Test that the hparams summary was written to `logdir`."""
119-
self._check_summary(_get_unique_summary(self, logdir))
121+
self._check_summary(
122+
_get_unique_summary(self, logdir),
123+
check_group_name=check_group_name,
124+
)
120125

121126
@requires_tf
122127
def test_eager(self):
123128
with tf.compat.v2.summary.create_file_writer(self.logdir).as_default():
124-
result = hp.hparams(self.hparams, start_time_secs=self.start_time_secs)
129+
result = hp.hparams(
130+
self.hparams,
131+
trial_id=self.trial_id,
132+
start_time_secs=self.start_time_secs,
133+
)
125134
self.assertTrue(result)
126135
self._check_logdir(self.logdir)
127136

@@ -152,6 +161,19 @@ def test_pb_is_tensorboard_copy_of_proto(self):
152161
if tf is not None:
153162
self.assertNotIsInstance(result, tf.compat.v1.Summary)
154163

164+
def test_pb_explicit_trial_id(self):
165+
result = hp.hparams_pb(
166+
self.hparams,
167+
trial_id=self.trial_id,
168+
start_time_secs=self.start_time_secs,
169+
)
170+
self._check_summary(result, check_group_name=True)
171+
172+
def test_pb_invalid_trial_id(self):
173+
with six.assertRaisesRegex(
174+
self, TypeError, "`trial_id` should be a `str`, but got: 12"):
175+
hp.hparams_pb(self.hparams, trial_id=12)
176+
155177
def assert_hparams_summaries_equal(self, summary_1, summary_2):
156178
def canonical(summary):
157179
"""Return a canonical form for `summary`.

tensorboard/plugins/hparams/tf_hparams_table_view/tf-hparams-table-view.html

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
items="[[sessionGroups]]">
3737
<vaadin-grid-column flex-grow="0" width="10em" resizable>
3838
<template class="header">
39-
<div class="table-header table-cell">Session Group Name.</div>
39+
<div class="table-header table-cell">Trial ID</div>
4040
</template>
4141
<template>
4242
<div class="table-cell">[[item.name]]</div>

0 commit comments

Comments
 (0)