7 changes: 7 additions & 0 deletions test/test_tensorboard.py
@@ -465,6 +465,13 @@ def test_hparams_string(self):
mt = {'accuracy': 0.1}
self.assertTrue(compare_proto(summary.hparams(hp, mt), self))

def test_hparams_domain_discrete(self):
hp = {"lr": 0.1, "bool_var": True, "string_var": "hi"}
mt = {"accuracy": 0.1}
hp_domain = {"lr": [0.1], "bool_var": [True], "string_var": ["hi"]}
# Smoke test only, because protobuf map serialization is nondeterministic.
summary.hparams(hp, mt, hp_domain)
Review comment (Contributor): Can you construct the expected proto separately and compare it to what's returned by summary.hparams?
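
One way to address this suggestion (a rough sketch, not part of the PR; it assumes, as in the TensorBoard hparams plugin, that the experiment Summary returned by summary.hparams carries a serialized HParamsPluginData in its plugin_data content):

    # Sketch: build the expected HParamInfo by hand and compare it with the one
    # summary.hparams() emits, sidestepping the nondeterministic map fields.
    from google.protobuf import struct_pb2
    from tensorboard.plugins.hparams.api_pb2 import DataType, HParamInfo
    from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData

    expected_lr = HParamInfo(
        name="lr",
        type=DataType.Value("DATA_TYPE_FLOAT64"),
        domain_discrete=struct_pb2.ListValue(
            values=[struct_pb2.Value(number_value=0.1)]
        ),
    )

    exp, ssi, sei = summary.hparams(hp, mt, hp_domain)
    plugin_content = exp.value[0].metadata.plugin_data.content
    experiment = HParamsPluginData.FromString(plugin_content).experiment
    actual_lr = next(i for i in experiment.hparam_infos if i.name == "lr")
    self.assertEqual(expected_lr, actual_lr)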


def test_mesh(self):
v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
77 changes: 73 additions & 4 deletions torch/utils/tensorboard/summary.py
@@ -10,6 +10,7 @@
# pylint: disable=unused-import
from six.moves import range

from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import Summary
from tensorboard.compat.proto.summary_pb2 import HistogramProto
from tensorboard.compat.proto.summary_pb2 import SummaryMetadata
@@ -50,7 +51,7 @@ def _draw_single_box(image, xmin, ymin, xmax, ymax, display_str, color='black',
return image


def hparams(hparam_dict=None, metric_dict=None):
def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
"""Outputs three `Summary` protocol buffers needed by hparams plugin.
`Experiment` keeps the metadata of an experiment, such as the name of the
hyperparameters and the name of the metrics.
@@ -62,6 +63,8 @@ def hparams(hparam_dict=None, metric_dict=None):
and their values.
metric_dict: A dictionary that contains names of the metrics
and their values.
hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
contains names of the hyperparameters and all discrete values they can hold.

Returns:
The `Summary` protobufs for Experiment, SessionStartInfo and
@@ -99,6 +102,21 @@ def hparams(hparam_dict=None, metric_dict=None):
logging.warning('parameter: metric_dict should be a dictionary, nothing logged.')
raise TypeError('parameter: metric_dict should be a dictionary, nothing logged.')

hparam_domain_discrete = hparam_domain_discrete or {}
if not isinstance(hparam_domain_discrete, dict):
raise TypeError(
"parameter: hparam_domain_discrete should be a dictionary, nothing logged."
)
for k, v in hparam_domain_discrete.items():
if (
k not in hparam_dict
or not isinstance(v, list)
or not all(isinstance(d, type(hparam_dict[k])) for d in v)
):
raise TypeError(
"parameter: hparam_domain_discrete[{}] should be a list of same type as "
"hparam_dict[{}].".format(k, k)
)
hps = []


@@ -108,17 +126,68 @@
continue
if isinstance(v, int) or isinstance(v, float):
ssi.hparams[k].number_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))

if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(number_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None

hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_FLOAT64"),
domain_discrete=domain_discrete,
)
)
continue

if isinstance(v, string_types):
ssi.hparams[k].string_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_STRING")))

if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(string_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None

hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_STRING"),
domain_discrete=domain_discrete,
)
)
continue

if isinstance(v, bool):
ssi.hparams[k].bool_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_BOOL")))

if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(bool_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None

hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_BOOL"),
domain_discrete=domain_discrete,
)
)
continue

if isinstance(v, torch.Tensor):
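
For reference, all three branches above funnel the discrete values into the same well-known proto type; a minimal standalone sketch of that mapping (assumption: only the protobuf runtime is needed, and the helper name here is made up for illustration):

    from google.protobuf import struct_pb2

    def _to_list_value(values):
        # Mirrors the branches above: int/float -> number_value,
        # str -> string_value, bool -> bool_value. Note that the real code
        # picks the field from the hparam's declared type, not per element.
        def to_value(d):
            if isinstance(d, bool):  # check bool first; bool is a subclass of int
                return struct_pb2.Value(bool_value=d)
            if isinstance(d, (int, float)):
                return struct_pb2.Value(number_value=d)
            return struct_pb2.Value(string_value=d)
        return struct_pb2.ListValue(values=[to_value(d) for d in values])

    print(_to_list_value([0.1, 0.01, 0.001]))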
8 changes: 6 additions & 2 deletions torch/utils/tensorboard/writer.py
@@ -268,7 +268,9 @@ def get_logdir(self):
"""Returns the directory where event files will be written."""
return self.log_dir

def add_hparams(self, hparam_dict, metric_dict, run_name=None):
def add_hparams(
self, hparam_dict, metric_dict, hparam_domain_discrete=None, run_name=None
):
"""Add a set of hyperparameters to be compared in TensorBoard.

Args:
@@ -281,6 +283,8 @@ def add_hparams(self, hparam_dict, metric_dict, run_name=None):
here should be unique in the tensorboard record. Otherwise the value
you added by ``add_scalar`` will be displayed in hparam plugin. In most
cases, this is unwanted.
hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
contains names of the hyperparameters and all discrete values they can hold.
run_name (str): Name of the run, to be included as part of the logdir.
If unspecified, will use current timestamp.

@@ -301,7 +305,7 @@ def add_hparams(self, hparam_dict, metric_dict, run_name=None):
torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
if type(hparam_dict) is not dict or type(metric_dict) is not dict:
raise TypeError('hparam_dict and metric_dict should be dictionary.')
exp, ssi, sei = hparams(hparam_dict, metric_dict)
exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete)

if not run_name:
run_name = str(time.time())
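
Taken together, a caller would exercise the new parameter roughly like this (a usage sketch based on the signature above; the hyperparameter names, metric name, and values are illustrative):

    from torch.utils.tensorboard import SummaryWriter

    with SummaryWriter() as w:
        for lr in (0.1, 0.01, 0.001):
            w.add_hparams(
                {"lr": lr, "optimizer": "sgd"},
                {"hparam/accuracy": 0.9},
                # Declare every value each hyperparameter may take so the
                # hparams dashboard can show the full discrete domain even
                # before all runs have been logged.
                hparam_domain_discrete={
                    "lr": [0.1, 0.01, 0.001],
                    "optimizer": ["sgd", "adam"],
                },
            )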