Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions deeplabcut/modelzoo/video_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def video_inference_superanimal(
NotImplementedError:
If the model is not found in the modelzoo.
Warning: If the superanimal_name will be deprecated in the future.

(Model Explanation) SuperAnimal-Quadruped:
`superanimal_quadruped` models aim to work across a large range of quadruped
animals, from horses, dogs, sheep, rodents, to elephants. The camera perspective is
Expand Down Expand Up @@ -276,9 +276,7 @@ def video_inference_superanimal(
_video_inference_superanimal,
)

weight_folder = (
get_snapshot_folder_path() / f"{superanimal_name}_{model_name}"
)
weight_folder = get_snapshot_folder_path() / f"{superanimal_name}_{model_name}"
if not weight_folder.exists():
download_huggingface_model(
superanimal_name, target_dir=str(weight_folder), rename_mapping=None
Expand All @@ -299,6 +297,11 @@ def video_inference_superanimal(
pseudo_threshold,
)
elif framework == "pytorch":
if detector_name is None:
raise ValueError(
"You have to specify a detector_name when using the Pytorch framework."
)

from deeplabcut.pose_estimation_pytorch.modelzoo.inference import (
_video_inference_superanimal,
)
Expand Down Expand Up @@ -373,10 +376,9 @@ def video_inference_superanimal(
video_to_frames(video_path, pseudo_dataset_folder, cropping=cropping)

anno_folder = pseudo_dataset_folder / "annotations"
if (
(anno_folder / "train.json").exists()
and (anno_folder / "test.json").exists()
):
if (anno_folder / "train.json").exists() and (
anno_folder / "test.json"
).exists():
print(
f"{anno_folder} exists, skipping the annotation construction. "
f"Delete the folder if you want to re-construct pseudo annotations"
Expand Down Expand Up @@ -405,16 +407,24 @@ def video_inference_superanimal(
bbox_threshold=bbox_threshold,
)

model_snapshot_prefix = f"snapshot-{model_name}"
detector_snapshot_prefix = f"snapshot-{detector_name}"

config["runner"]["snapshot_prefix"] = model_snapshot_prefix
config["detector"]["runner"]["snapshot_prefix"] = detector_snapshot_prefix

# the model config's parameters need to be updated for adaptation training
model_config_path = model_folder / "pytorch_config.yaml"
with open(model_config_path, "w") as f:
yaml = YAML()
yaml.dump(config, f)

adapted_detector_checkpoint = (
model_folder / f"snapshot-detector-{detector_epochs:03}.pt"
model_folder / f"{detector_snapshot_prefix}-{detector_epochs:03}.pt"
)
adapted_pose_checkpoint = (
model_folder / f"{model_snapshot_prefix}-{pose_epochs:03}.pt"
)
adapted_pose_checkpoint = model_folder / f"snapshot-{pose_epochs:03}.pt"

if (
adapted_detector_checkpoint.exists()
Expand Down
4 changes: 3 additions & 1 deletion deeplabcut/pose_estimation_pytorch/apis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def get_model_snapshots(
ValueError: If the index given is not valid
ValueError: If index=="best" but there is no saved best model
"""
snapshot_manager = TorchSnapshotManager(model_folder=model_folder, task=task)
snapshot_manager = TorchSnapshotManager(
model_folder=model_folder, snapshot_prefix=task.snapshot_prefix
)
if isinstance(index, str) and index.lower() == "best":
best_snapshot = snapshot_manager.best()
if best_snapshot is None:
Expand Down
12 changes: 5 additions & 7 deletions deeplabcut/pose_estimation_pytorch/runners/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import numpy as np
import torch

from deeplabcut.pose_estimation_pytorch.task import Task


@dataclass(frozen=True)
class Snapshot:
Expand All @@ -43,7 +41,7 @@ class TorchSnapshotManager:
"""Class handling model checkpoint I/O

Attributes:
task: The task that the model is performing.
snapshot_prefix: The prefix to use when saving snapshots.
model_folder: The path to the directory where model snapshots should be stored.
key_metric: If defined, the metric is used to save the best model. Otherwise no
best model is used.
Expand All @@ -60,7 +58,7 @@ class TorchSnapshotManager:
model: nn.Module
loader = DLCLoader(...)
snapshot_manager = TorchSnapshotManager(
Task.BOTTOM_UP,
"snapshot",
loader.model_folder,
key_metric="test.mAP",
)
Expand All @@ -76,7 +74,7 @@ class TorchSnapshotManager:
})
"""

task: Task
snapshot_prefix: str
model_folder: Path
key_metric: str | None = None
key_metric_asc: bool = True
Expand Down Expand Up @@ -191,7 +189,7 @@ def _sort_key(snapshot: Snapshot) -> int:
def _sort_key_best_as_last(snapshot: Snapshot) -> tuple[int, int]:
return 1 if snapshot.best else 0, snapshot.epochs

pattern = r"^(" + self.task.snapshot_prefix + r"(-best)?-\d+\.pt)$"
pattern = r"^(" + self.snapshot_prefix + r"(-best)?-\d+\.pt)$"
snapshots = [
Snapshot.from_path(f)
for f in self.model_folder.iterdir()
Expand All @@ -216,4 +214,4 @@ def snapshot_path(self, epoch: int, best: bool = False) -> Path:
uid = f"{epoch:03}"
if best:
uid = f"best-{uid}"
return self.model_folder / f"{self.task.snapshot_prefix}-{uid}.pt"
return self.model_folder / f"{self.snapshot_prefix}-{uid}.pt"
6 changes: 5 additions & 1 deletion deeplabcut/pose_estimation_pytorch/runners/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,11 +597,15 @@ def build_training_runner(
optim_cls = getattr(torch.optim, optim_cfg["type"])
optimizer = optim_cls(params=model.parameters(), **optim_cfg["params"])
scheduler = build_scheduler(runner_config.get("scheduler"), optimizer)
# if no custom snapshot prefix is defined, use the default one
snapshot_prefix = runner_config.get("snapshot_prefix")
if snapshot_prefix is None or len(snapshot_prefix) == 0:
snapshot_prefix = task.snapshot_prefix
kwargs = dict(
model=model,
optimizer=optimizer,
snapshot_manager=TorchSnapshotManager(
task=task,
snapshot_prefix=snapshot_prefix,
model_folder=model_folder,
key_metric=runner_config.get("key_metric"),
key_metric_asc=runner_config.get("key_metric_asc"),
Expand Down