23 changes: 23 additions & 0 deletions deeplabcut/benchmark/base.py
@@ -23,6 +23,7 @@
 
 import abc
 import dataclasses
+import warnings
 from typing import Iterable
 from typing import Tuple
 
@@ -89,6 +90,7 @@ def evaluate(self, name: str, on_error="raise"):
         root_mean_squared_error = float("nan")
         try:
             predictions = self.get_predictions(name)
+            predictions = self._validate_predictions(name, predictions)
             mean_avg_precision = self.compute_pose_map(predictions)
             root_mean_squared_error = self.compute_pose_rmse(predictions)
         except Exception as exception:
@@ -114,6 +116,27 @@ def evaluate(self, name: str, on_error="raise"):
             root_mean_squared_error=root_mean_squared_error,
         )
 
+    def _validate_predictions(self, name: str, predictions: dict) -> dict:
+        """Validates the submitted predictions object.
+        Checks that there is a prediction for each test image, and raises a warning if
+        that is not the case. Returns only predictions made for test images.
+        """
+        test_images = deeplabcut.benchmark.metrics.load_test_images(
+            self.ground_truth, self.metadata
+        )
+        missing_images = set(test_images) - set(predictions.keys())
+        if len(missing_images) > 0:
+            warnings.warn(
+                f"Missing {len(missing_images)} test images in the predictions for "
+                f"{name}: {list(missing_images)}. Metrics will be computed as if no "
+                "individuals were detected in those images."
+            )
+
+        return {
+            img: predictions.get(img, tuple())
+            for img in test_images
+        }
+
 
 @dataclasses.dataclass
 class Result:
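The padding behavior of the new `_validate_predictions` hook can be seen in isolation in the sketch below: missing test images trigger a warning and are then filled with empty tuples, so downstream metrics score them as frames with no detections. The helper name, toy image paths, and prediction payload format are illustrative assumptions, not part of this PR.

```python
import warnings

def validate_predictions_sketch(name: str, predictions: dict, test_images: list) -> dict:
    """Standalone sketch of the validation: warn about missing test images,
    then pad them with empty tuples so they count as frames with no detections."""
    missing = set(test_images) - set(predictions.keys())
    if missing:
        warnings.warn(
            f"Missing {len(missing)} test images in the predictions for "
            f"{name}: {sorted(missing)}."
        )
    return {img: predictions.get(img, tuple()) for img in test_images}

# One of two test images has no prediction; it is padded, not dropped.
preds = {"labeled-data/vid/img001.png": ((10.0, 12.0, 0.9),)}
images = ["labeled-data/vid/img001.png", "labeled-data/vid/img002.png"]
padded = validate_predictions_sketch("my-model", preds, images)
assert padded["labeled-data/vid/img002.png"] == tuple()
```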
76 changes: 60 additions & 16 deletions deeplabcut/benchmark/metrics.py
@@ -23,6 +23,7 @@
 import os
 import pickle
 from collections import defaultdict
+from typing import List, Optional
 
 import numpy as np
 import pandas as pd
@@ -33,7 +34,7 @@
 from deeplabcut.utils.conversioncode import guarantee_multiindex_rows
 
 
-def _format_gt_data(h5file):
+def _format_gt_data(h5file: str, test_indices: Optional[List[int]] = None):
     df = pd.read_hdf(h5file)
 
     animals = _get_unique_level_values(df.columns, "individuals")
@@ -54,6 +55,10 @@ def _format_gt_data(h5file):
         .reindex(kpts, level="bodyparts", axis=1)
     )
     data = temp.to_numpy().reshape((len(file_paths), len(animals), -1, 2))
+    if test_indices is not None:
+        file_paths = [file_paths[i] for i in test_indices]
+        data = [data[i] for i in test_indices]
+
     meta = {"animals": animals, "keypoints": kpts, "n_unique": n_unique}
     return {
         "annotations": dict(zip(file_paths, data)),
@@ -167,16 +172,24 @@ def calc_map_from_obj(
         pass
     n_animals = len(df.columns.get_level_values("individuals").unique())
     kpts = list(df.columns.get_level_values("bodyparts").unique())
-    image_paths = list(eval_results_obj)
-    ground_truth = (
-        df.loc[image_paths].to_numpy().reshape((len(image_paths), n_animals, -1, 2))
-    )
+
+    test_indices = _load_test_indices(metadata_file)
+    df_test = df.iloc[test_indices]
+    test_images = load_test_images(h5_file, metadata_file)
+    missing_images = set(test_images) - set(eval_results_obj.keys())
+    if len(missing_images) > 0:
+        raise ValueError(
+            "Failed to compute the test mAP: there are test images missing from the "
+            f"prediction object: {missing_images}"
+        )
+
+    ground_truth = df_test.to_numpy().reshape((len(test_images), n_animals, -1, 2))
     temp = np.ones((*ground_truth.shape[:3], 3))
     temp[..., :2] = ground_truth
-    assemblies_gt = inferenceutils._parse_ground_truth_data(temp)
-    with open(metadata_file, "rb") as f:
-        inds_test = set(pickle.load(f)[2])
-    assemblies_gt_test = {k: v for k, v in assemblies_gt.items() if k in inds_test}
+    assemblies_gt_test = {
+        test_images[i]: assembly
+        for i, assembly in inferenceutils._parse_ground_truth_data(temp).items()
+    }
 
     # TODO(stes): remove/rewrite
     if drop_kpts is not None:
@@ -192,9 +205,7 @@ def calc_map_from_obj(
         for ind in sorted(drop_kpts, reverse=True):
             kpts.pop(ind)
 
-    assemblies_pred_ = conv_obj_to_assemblies(eval_results_obj, kpts)
-    assemblies_pred = dict(enumerate(assemblies_pred_.values()))
-
+    assemblies_pred = conv_obj_to_assemblies(eval_results_obj, kpts)
    with deeplabcut.benchmark.utils.DisableOutput():
         oks = inferenceutils.evaluate_assembly(
             assemblies_pred,
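The switch away from `dict(enumerate(...))` matters because dense renumbering silently misaligns frames whenever a prediction is missing. A toy reconstruction of that failure mode, with made-up image names and payloads:

```python
# Predictions for three test frames, with img001 missing.
preds = {"img000.png": "pose-0", "img002.png": "pose-2"}

# Old behavior: renumber densely by insertion order.
old_keys = dict(enumerate(preds.values()))
# old_keys == {0: "pose-0", 1: "pose-2"} -- "pose-2" now pairs with frame 1's
# ground truth, shifting every later frame by one.
assert old_keys[1] == "pose-2"

# New behavior: both sides keep image paths as keys, so a missing frame can
# only surface as a missing key (caught by the ValueError above), never as a
# silent shift.
```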
@@ -213,18 +224,28 @@ def calc_rmse_from_obj(
     drop_kpts=None,
 ):
     """Calc prediction errors for submissions."""
-    gt = _format_gt_data(h5_file)
+    test_indices = _load_test_indices(metadata_file)
+    gt = _format_gt_data(h5_file, test_indices=test_indices)
     kpts = gt["metadata"]["keypoints"]
     if drop_kpts:
         for k, v in gt["annotations"].items():
             gt["annotations"][k] = np.delete(v, drop_kpts, axis=1)
         for ind in sorted(drop_kpts, reverse=True):
             kpts.pop(ind)
-    with open(metadata_file, "rb") as f:
-        inds_test = set(pickle.load(f)[2])
+
     test_objects = {
-        k: v for i, (k, v) in enumerate(eval_results_obj.items()) if i in inds_test
+        k: v
+        for k, v in eval_results_obj.items()
+        if k in gt["annotations"].keys()
     }
+    if len(gt["annotations"]) != len(test_objects):
+        gt_images = list(gt["annotations"].keys())
+        missing_images = [img for img in gt_images if img not in test_objects]
+        raise ValueError(
+            "Failed to compute the test RMSE: there are test images missing from the "
+            f"prediction object: {missing_images}"
+        )
+
     assemblies_pred = conv_obj_to_assemblies(test_objects, kpts)
     preds = defaultdict(dict)
     preds["metadata"]["keypoints"] = kpts
@@ -240,3 +261,26 @@ def calc_rmse_from_obj(
     with deeplabcut.benchmark.utils.DisableOutput():
         errors = calc_prediction_errors(preds, gt)
     return np.nanmean(errors[..., 0])
+
+
+def load_test_images(h5file: str, metadata: str) -> List[str]:
+    """
+    Returns the names of the test images for the benchmark, in the order corresponding
+    to the test indices.
+    """
+    df = pd.read_hdf(h5file)
+    test_indices = _load_test_indices(metadata)
+    df_test = df.iloc[test_indices]
+    test_images = []
+    for img_path in df_test.index:
+        if not isinstance(img_path, str):
+            img_path = os.path.join(*img_path)
+        test_images.append(img_path)
+    return test_images
+
+
+def _load_test_indices(shuffle_metadata_path: str) -> List[int]:
+    """Returns the indices of test images in the training dataset dataframe."""
+    with open(shuffle_metadata_path, "rb") as f:
+        test_indices = set(int(i) for i in pickle.load(f)[2])
+    return sorted(test_indices)
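`_load_test_indices` relies only on the third element of the shuffle metadata pickle holding the test indices. Below is a self-contained sketch of that contract; the file name and the placeholder slots are assumptions for illustration (the real files produced by DeepLabCut also carry the training data and train indices):

```python
import os
import pickle
import tempfile

# Toy metadata file: a tuple whose element at index 2 holds the test indices.
meta = ([], [0, 1, 3], [4, 2, 2])  # (placeholder, train indices, test indices)
path = os.path.join(tempfile.mkdtemp(), "Documentation_data.pickle")
with open(path, "wb") as f:
    pickle.dump(meta, f)

# Same logic as _load_test_indices: cast to int, deduplicate, sort.
with open(path, "rb") as f:
    test_indices = sorted(set(int(i) for i in pickle.load(f)[2]))
assert test_indices == [2, 4]
```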
13 changes: 11 additions & 2 deletions deeplabcut/pose_estimation_tensorflow/lib/crossvalutils.py
@@ -250,11 +250,20 @@ def _benchmark_paf_graphs(
         all_assemblies.append((ass.assemblies, ass.unique, ass.metadata["imnames"]))
         if split_inds is not None:
             oks = []
+
+            # get the indices of the images in the training set
+            dataset_idx = [data[image_name]["index"] for image_name in image_paths]
             for inds in split_inds:
-                ass_gt = {k: v for k, v in ass_true_dict.items() if k in inds}
+                ass_gt = {
+                    k: v for k, v in ass_true_dict.items() if dataset_idx[k] in inds
+                }
+                ass_pred = {
+                    k: v for k, v in ass.assemblies.items() if dataset_idx[k] in inds
+                }
+
                 oks.append(
                     evaluate_assembly(
-                        ass.assemblies,
+                        ass_pred,
                         ass_gt,
                         oks_sigma,
                         margin=margin,
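The crossvalutils change fixes a mismatch between assembly keys (positions within `image_paths`) and `split_inds` (indices into the full training dataset) by translating positions through a lookup list before filtering. A minimal sketch with invented paths and indices:

```python
# Assemblies are keyed by position in image_paths; split_inds holds indices
# into the full training dataset, so positions must be translated first.
image_paths = ["imgA.png", "imgB.png", "imgC.png"]
data = {
    "imgA.png": {"index": 10},
    "imgB.png": {"index": 4},
    "imgC.png": {"index": 7},
}
dataset_idx = [data[name]["index"] for name in image_paths]  # [10, 4, 7]

assemblies = {0: "assembly-A", 1: "assembly-B", 2: "assembly-C"}
split = {4, 7}  # dataset indices belonging to one evaluation split
subset = {k: v for k, v in assemblies.items() if dataset_idx[k] in split}
assert subset == {1: "assembly-B", 2: "assembly-C"}
```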