44 changes: 37 additions & 7 deletions deeplabcut/compat.py
@@ -53,7 +53,7 @@ def get_available_aug_methods(engine: Engine) -> tuple[str, ...]:
if engine == Engine.TF:
return "imgaug", "default", "deterministic", "scalecrop", "tensorpack"
elif engine == Engine.PYTORCH:
return ("albumentations", )
return ("albumentations",)

raise RuntimeError(f"Unknown augmentation for engine: {engine}")

@@ -218,6 +218,7 @@ def train_network(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import train_network

if max_snapshots_to_keep is None:
max_snapshots_to_keep = 5

@@ -239,6 +240,7 @@
)
elif engine == Engine.PYTORCH:
from deeplabcut.pose_estimation_pytorch.apis import train_network

_update_device(gputouse, torch_kwargs)
if "display_iters" not in torch_kwargs:
torch_kwargs["display_iters"] = displayiters
@@ -299,14 +301,18 @@ def return_train_network_path(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import return_train_network_path

return return_train_network_path(
config,
shuffle=shuffle,
trainingsetindex=trainingsetindex,
modelprefix=modelprefix,
)
elif engine == Engine.PYTORCH:
from deeplabcut.pose_estimation_pytorch.apis.utils import return_train_network_path
from deeplabcut.pose_estimation_pytorch.apis.utils import (
return_train_network_path,
)

return return_train_network_path(
config,
shuffle=shuffle,
@@ -458,6 +464,7 @@ def evaluate_network(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import evaluate_network

return evaluate_network(
str(config),
Shuffles=Shuffles,
@@ -473,6 +480,7 @@
)
elif engine == Engine.PYTORCH:
from deeplabcut.pose_estimation_pytorch.apis import evaluate_network

_update_device(gputouse, torch_kwargs)
return evaluate_network(
config,
@@ -553,6 +561,7 @@ def return_evaluate_network_data(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import return_evaluate_network_data

return return_evaluate_network_data(
config,
shuffle=shuffle,
@@ -816,6 +825,7 @@ def analyze_videos(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import analyze_videos

kwargs = {}
if use_openvino is not None: # otherwise default comes from tensorflow API
kwargs["use_openvino"] = use_openvino
@@ -846,6 +856,7 @@
)
elif engine == Engine.PYTORCH:
from deeplabcut.pose_estimation_pytorch.apis import analyze_videos

_update_device(gputouse, torch_kwargs)

if batchsize is not None:
@@ -907,6 +918,7 @@ def create_tracking_dataset(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import create_tracking_dataset

return create_tracking_dataset(
config,
videos,
@@ -994,6 +1006,7 @@ def analyze_time_lapse_frames(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import analyze_time_lapse_frames

return analyze_time_lapse_frames(
config,
directory,
@@ -1120,6 +1133,7 @@ def convert_detections2tracklets(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import convert_detections2tracklets

return convert_detections2tracklets(
config,
videos,
@@ -1233,6 +1247,7 @@ def extract_maps(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import extract_maps

return extract_maps(
config,
shuffle=shuffle,
@@ -1244,6 +1259,7 @@
)
elif engine == Engine.PYTORCH:
from deeplabcut.pose_estimation_pytorch import extract_maps

return extract_maps(
config,
shuffle=shuffle,
@@ -1404,6 +1420,7 @@ def extract_save_all_maps(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import extract_save_all_maps

return extract_save_all_maps(
config,
shuffle=shuffle,
@@ -1419,6 +1436,7 @@
)
elif engine == Engine.PYTORCH:
from deeplabcut.pose_estimation_pytorch import extract_save_all_maps

return extract_save_all_maps(
config,
shuffle=shuffle,
@@ -1450,14 +1468,12 @@ def export_model(
wipepaths: bool = False,
modelprefix: str = "",
engine: Engine | None = None,
):
) -> None:
"""Export DeepLabCut models for the model zoo or for live inference.

Saves the pose configuration, snapshot files, and frozen TF graph of the model to
directory named exported-models within the project directory

This function is only implemented for tensorflow models/shuffles, and will throw
an error if called with a PyTorch shuffle.
directory named exported-models within the project directory (and an
`exported-models-pytorch` directory for PyTorch models).

Parameters
-----------
@@ -1514,6 +1530,7 @@ def export_model(

if engine == Engine.TF:
from deeplabcut.pose_estimation_tensorflow import export_model

return export_model(
cfg_path=cfg_path,
shuffle=shuffle,
@@ -1526,6 +1543,19 @@
wipepaths=wipepaths,
modelprefix=modelprefix,
)
elif engine == Engine.PYTORCH:
from deeplabcut.pose_estimation_pytorch.apis.export import export_model

return export_model(
config=cfg_path,
shuffle=shuffle,
trainingsetindex=trainingsetindex,
snapshotindex=snapshotindex,
iteration=iteration,
overwrite=overwrite,
wipe_paths=wipepaths,
modelprefix=modelprefix,
)

raise NotImplementedError(f"This function is not implemented for {engine}")

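With the PyTorch branch above in place, the top-level deeplabcut.export_model wrapper now routes PyTorch shuffles to the new exporter instead of falling through to NotImplementedError. A minimal usage sketch (the project path and shuffle number are hypothetical; leaving engine as None assumes the wrapper resolves the engine from the shuffle, as the other compat wrappers do):

import deeplabcut

# Hypothetical project; shuffle 3 is assumed to have been created with the PyTorch engine.
deeplabcut.export_model(
    "/analysis/project/reaching-task/config.yaml",
    shuffle=3,
    snapshotindex=-1,  # export the last stored snapshot
    wipepaths=True,    # strip local paths from the exported pytorch_config.yaml
)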
1 change: 1 addition & 0 deletions deeplabcut/pose_estimation_pytorch/apis/__init__.py
@@ -18,6 +18,7 @@
convert_detections2tracklets,
)
from deeplabcut.pose_estimation_pytorch.apis.evaluate import evaluate_network
from deeplabcut.pose_estimation_pytorch.apis.export import export_model
from deeplabcut.pose_estimation_pytorch.apis.train import train_network
from deeplabcut.pose_estimation_pytorch.apis.visualization import (
extract_maps,
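Because export_model is now re-exported from the apis package, the PyTorch exporter can also be called directly, bypassing the engine dispatch in compat.py. A sketch using the same hypothetical project as above:

from deeplabcut.pose_estimation_pytorch.apis import export_model

export_model(
    "/analysis/project/reaching-task/config.yaml",
    shuffle=3,
    snapshotindex=-1,
)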
191 changes: 191 additions & 0 deletions deeplabcut/pose_estimation_pytorch/apis/export.py
@@ -0,0 +1,191 @@
#
# DeepLabCut Toolbox (deeplabcut.org)
# © A. & M.W. Mathis Labs
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
"""Code to export DeepLabCut models for DLCLive inference"""
import copy
from pathlib import Path

import torch

import deeplabcut.pose_estimation_pytorch.apis.utils as utils
import deeplabcut.pose_estimation_pytorch.data as dlc3_data
import deeplabcut.utils.auxiliaryfunctions as af
from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
from deeplabcut.pose_estimation_pytorch.task import Task


def export_model(
config: str | Path,
shuffle: int = 1,
trainingsetindex: int = 0,
snapshotindex: int | None = None,
detector_snapshot_index: int | None = None,
iteration: int | None = None,
overwrite: bool = False,
wipe_paths: bool = False,
modelprefix: str | None = None,
) -> None:
"""Export DeepLabCut models for live inference.

Saves the pytorch_config.yaml configuration and snapshot files of the model to a
directory named exported-models-pytorch within the project directory.

Args:
config: Path of the project configuration file
shuffle: The shuffle of the model to export.
trainingsetindex: The index of the training fraction for the model you wish to
export.
snapshotindex: The snapshot index for the weights you wish to export. If None,
uses the snapshotindex as defined in ``config.yaml``.
detector_snapshot_index: Only for TD models. If defined, uses the detector with
the given index for pose estimation. If None, uses the snapshotindex as
defined in the project ``config.yaml``.
iteration: The project iteration (active learning loop) you wish to export. If
None, the iteration listed in the project config file is used.
overwrite: If the model you wish to export has already been exported, whether to
overwrite it. Default is False.
wipe_paths: Whether to remove the project path and the ``init_weights`` paths
from the exported ``pytorch_config.yaml``.
modelprefix: Directory containing the deeplabcut models to use when evaluating
the network. By default, the models are assumed to exist in the project
folder.

Raises:
ValueError: If no snapshots could be found for the shuffle.
ValueError: If a top-down model is exported but no detector snapshots are found.

Examples:
Export the last stored snapshot for model trained with shuffle 3:
>>> import deeplabcut
>>> deeplabcut.export_model(
>>> "/analysis/project/reaching-task/config.yaml",
>>> shuffle=3,
>>> snapshotindex=-1,
>>> )
"""
cfg = af.read_config(str(config))
if iteration is not None:
cfg["iteration"] = iteration

loader = dlc3_data.DLCLoader(
config=cfg,
trainset_index=trainingsetindex,
shuffle=shuffle,
modelprefix="" if modelprefix is None else modelprefix,
)

if snapshotindex is None:
snapshotindex = loader.project_cfg["snapshotindex"]
snapshots = utils.get_model_snapshots(
snapshotindex, loader.model_folder, loader.pose_task
)

if len(snapshots) == 0:
raise ValueError(
f"Could not find any snapshots to export in ``{loader.model_folder}`` for "
f"``snapshotindex={snapshotindex}``."
)

detector_snapshots = [None]
if loader.pose_task == Task.TOP_DOWN:
if detector_snapshot_index is None:
detector_snapshot_index = loader.project_cfg["detector_snapshotindex"]
detector_snapshots = utils.get_model_snapshots(
detector_snapshot_index, loader.model_folder, Task.DETECT
)

if len(detector_snapshots) == 0:
raise ValueError(
"Attempting to export a top-down pose estimation model but no detector "
f"snapshots were found in ``{loader.model_folder}`` for "
f"``detector_snapshot_index={detector_snapshot_index}``. You must "
f"export a detector snapshot with a top-down pose estimation model."
)

export_folder_name = get_export_folder_name(loader)
export_dir = loader.project_path / "exported-models-pytorch" / export_folder_name
export_dir.mkdir(exist_ok=True, parents=True)

load_kwargs = dict(map_location="cpu", weights_only=True)
for det_snapshot in detector_snapshots:
detector_weights = None
if det_snapshot is not None:
detector_weights = torch.load(det_snapshot.path, **load_kwargs)["model"]

for snapshot in snapshots:
export_filename = get_export_filename(loader, snapshot, det_snapshot)
export_path = export_dir / export_filename
if export_path.exists() and not overwrite:
continue

model_cfg = copy.deepcopy(loader.model_cfg)
if wipe_paths:
wipe_paths_from_model_config(model_cfg)

pose_weights = torch.load(snapshot.path, **load_kwargs)["model"]
export_dict = dict(config=model_cfg, pose=pose_weights)
if detector_weights is not None:
export_dict["detector"] = detector_weights

torch.save(export_dict, export_path)


def get_export_folder_name(loader: dlc3_data.DLCLoader) -> str:
"""
Args:
loader: The loader for the shuffle for which we want to export models.

Returns:
The name of the folder in which exported models should be placed for a shuffle.
"""
return (
f"DLC_{loader.project_cfg['Task']}_{loader.model_cfg['net_type']}_"
f"iteration-{loader.project_cfg['iteration']}_shuffle-{loader.shuffle}"
)


def get_export_filename(
loader: dlc3_data.DLCLoader,
snapshot: Snapshot,
detector_snapshot: Snapshot | None = None,
) -> str:
"""
Args:
loader: The loader for the shuffle for which we want to export models.
snapshot: The pose model snapshot to export.
detector_snapshot: The detector snapshot to export, for top-down models.

Returns:
The name of the file in which the exported model should be stored.
"""
export_filename = get_export_folder_name(loader)
if detector_snapshot is not None:
export_filename += "_snapshot-detector" + detector_snapshot.uid()
export_filename += "_snapshot-" + snapshot.uid()
return export_filename + ".pt"


def wipe_paths_from_model_config(model_cfg: dict) -> None:
"""
Removes all paths from the contents of the ``pytorch_config`` file.

Args:
model_cfg: The model configuration to wipe.
"""
model_cfg["metadata"]["project_path"] = ""
model_cfg["metadata"]["pose_config_path"] = ""
if "weight_init" in model_cfg["train_settings"]:
model_cfg["train_settings"]["weight_init"] = None
if "resume_training_from" in model_cfg:
model_cfg["resume_training_from"] = None
if "resume_training_from" in model_cfg.get("detector", {}):
model_cfg["detector"]["resume_training_from"] = None