Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 206 additions & 3 deletions deeplabcut/pose_estimation_pytorch/apis/evaluate.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Iterable

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
Expand All @@ -35,8 +36,15 @@
from deeplabcut.pose_estimation_pytorch.runners import InferenceRunner
from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
from deeplabcut.pose_estimation_pytorch.task import Task
from deeplabcut.utils import auxiliaryfunctions
from deeplabcut.utils.visualization import plot_evaluation_results
from deeplabcut.utils import auxfun_videos, auxiliaryfunctions
from deeplabcut.utils.visualization import (
create_minimal_figure,
erase_artists,
get_cmap,
make_multianimal_labeled_image,
plot_evaluation_results,
save_labeled_frame,
)


def predict(
Expand Down Expand Up @@ -167,6 +175,200 @@ def evaluate(
return results, predictions


def visualize_predictions(
predictions: dict,
ground_truth: dict,
output_dir: str | Path | None = None,
num_samples: int | None = None,
random_select: bool = False,
show_ground_truth: bool = True,
) -> None:
"""Visualize model predictions alongside ground truth keypoints.

This function processes keypoint predictions and ground truth data, applies visibility
masks, and generates visualization plots. It supports random or sequential sampling
of images for visualization.

Args:
predictions: Dictionary mapping image paths to prediction data.
Each prediction contains:
- bodyparts: array of shape [N, num_keypoints, 3] where 3 represents (x, y, confidence)
- bboxes: array of shape [N, 4] for bounding boxes (optional)
- bbox_scores: array of shape [N,] for bbox confidences (optional)

ground_truth: Dictionary mapping image paths to ground truth keypoints.
Each value has shape [N, num_keypoints, 3] where 3 represents (x, y, visibility)

output_dir: Path to save visualization outputs.
Defaults to "predictions_visualizations"

num_samples: Number of images to visualize. If None, processes all images

random_select: If True, randomly samples images; if False, uses first N images

show_ground_truth: If True, displays ground truth poses alongside predictions.
If False, only shows predictions but uses GT visibility mask
"""
# Setup output directory
output_dir = Path(output_dir or "predictions_visualizations")
output_dir.mkdir(exist_ok=True)

# Select images to process
image_paths = list(predictions.keys())
if num_samples and num_samples < len(image_paths):
if random_select:
image_paths = np.random.choice(
image_paths, num_samples, replace=False
).tolist()
else:
image_paths = image_paths[:num_samples]

# Process each selected image
for image_path in image_paths:
# Get prediction and ground truth data
pred_data = predictions[image_path]
gt_keypoints = ground_truth[image_path] # Shape: [N, num_keypoints, 3]

# Create visibility mask from first GT sample. This mask will be applied to all samples for consistency
vis_mask = gt_keypoints[0, :, 2] > 0

# Process ground truth keypoints if showing GT
if show_ground_truth:
visible_gt = []
for gt in gt_keypoints:
visible_points = gt[vis_mask, :2] # Keep only x,y for visible joints
visible_gt.append(visible_points)
visible_gt = np.stack(visible_gt) # Shape: [N, num_visible_joints, 2]
else:
visible_gt = None

# Process predicted keypoints
pred_keypoints = pred_data["bodyparts"] # Shape: [N, num_keypoints, 3]
visible_pred = []
for pred in pred_keypoints:
visible_points = pred[vis_mask] # Keep only visible joint predictions
visible_pred.append(visible_points)
visible_pred = np.stack(visible_pred) # Shape: [N, num_visible_joints, 3]

# Generate and save visualization
try:
plot_gt_and_predictions(
image_path=image_path,
output_dir=output_dir,
gt_bodyparts=visible_gt,
pred_bodyparts=visible_pred,
)
print(f"Successfully plotted predictions for {image_path}")
except Exception as e:
print(f"Error plotting predictions for {image_path}: {str(e)}")


def plot_gt_and_predictions(
image_path: str | Path,
output_dir: str | Path,
gt_bodyparts: np.ndarray,
pred_bodyparts: np.ndarray,
gt_unique_bodyparts: np.ndarray | None = None,
pred_unique_bodyparts: np.ndarray | None = None,
mode: str = "bodypart",
colormap: str = "rainbow",
dot_size: int = 12,
alpha_value: float = 0.7,
p_cutoff: float = 0.6,
):
"""Plot ground truth and predictions on an image.

Args:
image_path: Path to the image
gt_bodyparts: Ground truth keypoints array (num_animals, num_keypoints, 3)
pred_bodyparts: Predicted keypoints array (num_animals, num_keypoints, 3)
output_dir: Directory where labeled images will be saved
gt_unique_bodyparts: Ground truth unique bodyparts if any
pred_unique_bodyparts: Predicted unique bodyparts if any
mode: How to color the points ("bodypart" or "individual")
colormap: Matplotlib colormap name
dot_size: Size of the plotted points
alpha_value: Transparency of the points
p_cutoff: Confidence threshold for showing predictions
"""
# Ensure output directory exists
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# Read the image
frame = auxfun_videos.imread(str(image_path), mode="skimage")
num_pred, num_keypoints = pred_bodyparts.shape[:2]

# Create figure and set dimensions
fig, ax = create_minimal_figure()
h, w, _ = np.shape(frame)
fig.set_size_inches(w / 100, h / 100)
ax.set_xlim(0, w)
ax.set_ylim(0, h)
ax.invert_yaxis()
ax.imshow(frame, "gray")

# Set up colors based on mode
if mode == "bodypart":
num_colors = num_keypoints
if pred_unique_bodyparts is not None:
num_colors += pred_unique_bodyparts.shape[1]
colors = get_cmap(num_colors, name=colormap)

predictions = pred_bodyparts.swapaxes(0, 1)
ground_truth = gt_bodyparts.swapaxes(0, 1)
elif mode == "individual":
colors = get_cmap(num_pred + 1, name=colormap)
predictions = pred_bodyparts
ground_truth = gt_bodyparts
else:
raise ValueError(f"Invalid mode: {mode}")

# Plot regular bodyparts
ax = make_multianimal_labeled_image(
frame,
ground_truth,
predictions[:, :, :2],
predictions[:, :, 2:],
colors,
dot_size,
alpha_value,
p_cutoff,
ax=ax,
)

# Plot unique bodyparts if present
if pred_unique_bodyparts is not None and gt_unique_bodyparts is not None:
if mode == "bodypart":
unique_predictions = pred_unique_bodyparts.swapaxes(0, 1)
unique_ground_truth = gt_unique_bodyparts.swapaxes(0, 1)
else:
unique_predictions = pred_unique_bodyparts
unique_ground_truth = gt_unique_bodyparts

ax = make_multianimal_labeled_image(
frame,
unique_ground_truth,
unique_predictions[:, :, :2],
unique_predictions[:, :, 2:],
colors[num_keypoints:],
dot_size,
alpha_value,
p_cutoff,
ax=ax,
)

# Save the labeled image
save_labeled_frame(
fig,
str(image_path),
str(output_dir),
belongs_to_train=False,
)
erase_artists(ax)
plt.close()


def evaluate_snapshot(
cfg: dict,
loader: DLCLoader,
Expand Down Expand Up @@ -222,7 +424,7 @@ def evaluate_snapshot(
parameters = PoseDatasetParameters(
bodyparts=project_bodyparts,
unique_bpts=parameters.unique_bpts,
individuals=parameters.individuals
individuals=parameters.individuals,
)

predictions = {}
Expand Down Expand Up @@ -289,6 +491,7 @@ def evaluate_snapshot(
df_ground_truth, left_index=True, right_index=True
)
unique_bodyparts = loader.get_dataset_parameters().unique_bpts

plot_evaluation_results(
df_combined=df_combined,
project_root=cfg["project_path"],
Expand Down