19 changes: 9 additions & 10 deletions deeplabcut/pose_estimation_pytorch/runners/base.py
@@ -63,21 +63,20 @@ def load_snapshot(
snapshot_path: str | Path,
device: str,
model: ModelType,
optimizer: torch.optim.Optimizer | None = None,
) -> int:
"""
) -> dict:
"""Loads the state dict for a model from a file

This method loads a file containing a DeepLabCut PyTorch model snapshot onto
a given device, and sets the model weights using the state_dict.

Args:
snapshot_path: the path containing the model weights to load
device: the device on which the model should be loaded
model: the model for which the weights are loaded
optimizer: if defined, the optimizer weights to load

Returns:
the number of epochs the model was trained for
The content of the snapshot file.
"""
snapshot = torch.load(snapshot_path, map_location=device)
model.load_state_dict(snapshot['model'])
if optimizer is not None and 'optimizer' in snapshot:
optimizer.load_state_dict(snapshot["optimizer"])

return snapshot.get("metadata", {}).get("epoch", 0)
model.load_state_dict(snapshot["model"])
return snapshot
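
Since `load_snapshot` now returns the full snapshot dict instead of just the epoch count, callers own the restoration of optimizer (and scheduler) state. The following is a minimal sketch of that pattern, assuming the snapshot layout shown above (`model`, optional `optimizer`, and `metadata.epoch` keys); the helper name `resume_states` is illustrative, not part of this PR:

```python
import torch


def resume_states(snapshot_path, device, model, optimizer=None):
    # Illustrative sketch only: load a snapshot dict and restore the states
    # that used to be handled inside load_snapshot itself.
    snapshot = torch.load(snapshot_path, map_location=device)
    model.load_state_dict(snapshot["model"])
    if optimizer is not None and "optimizer" in snapshot:
        optimizer.load_state_dict(snapshot["optimizer"])
    # The epoch to resume from, defaulting to 0 for older snapshots
    return snapshot.get("metadata", {}).get("epoch", 0)
```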
39 changes: 35 additions & 4 deletions deeplabcut/pose_estimation_pytorch/runners/schedulers.py
@@ -16,10 +16,10 @@

class LRListScheduler(_LRScheduler):
"""
Definition of the class object Scheduler.
You can achieve increased performance and faster training by using a learning rate that changes
during training. A scheduler makes the learning rate adaptative. Given a list of learning rates
and milestones modifies the learning rate accordingly during training
You can achieve increased performance and faster training by using a learning rate
that changes during training. A scheduler makes the learning rate adaptive. Given a
list of learning rates and milestones, it modifies the learning rate accordingly
during training.
"""

def __init__(self, optimizer, milestones, lr_list, last_epoch=-1) -> None:
@@ -78,3 +78,34 @@ def build_scheduler(
scheduler = getattr(torch.optim.lr_scheduler, scheduler_cfg["type"])

return scheduler(optimizer=optimizer, **scheduler_cfg["params"])


def load_scheduler_state(
scheduler: torch.optim.lr_scheduler.LRScheduler,
state_dict: dict,
) -> None:
"""
Args:
scheduler: The scheduler for which to load the state dict.
state_dict: The state dict to load

Raises:
ValueError: if the state dict fails to load.
"""
try:
scheduler.load_state_dict(state_dict)
except Exception as err:
raise ValueError(f"Failed to load state dict: {err}")

param_groups = scheduler.optimizer.param_groups
resume_lrs = scheduler.get_last_lr()

if len(param_groups) != len(resume_lrs):
raise ValueError(
f"Number of optimizer parameter groups ({len(param_groups)}) did not match "
f"number of learning rates to resume from ({len(scheduler.get_last_lr())})."
)

# Update the learning rate for the optimizer based on the scheduler
for group, resume_lr in zip(param_groups, resume_lrs):
group['lr'] = resume_lr
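
For reference, here is a hedged usage sketch of the new `load_scheduler_state` helper with a toy optimizer and a stock PyTorch scheduler; the model, optimizer, and scheduler values are illustrative, and the import path assumes the module changed in this PR:

```python
import torch
from torch import nn

from deeplabcut.pose_estimation_pytorch.runners.schedulers import load_scheduler_state

# Toy setup: any optimizer/scheduler pair works, the values are illustrative.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

# In practice this state would come from a saved snapshot, e.g. snapshot["scheduler"].
saved_state = scheduler.state_dict()

# Restores the scheduler state and writes the resumed learning rates back into
# the optimizer's parameter groups; raises ValueError if the state fails to load.
load_scheduler_state(scheduler, saved_state)
print(optimizer.param_groups[0]["lr"])
```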
133 changes: 100 additions & 33 deletions deeplabcut/pose_estimation_pytorch/runners/train.py
@@ -23,6 +23,7 @@
from torch.nn.parallel import DataParallel

import deeplabcut.core.metrics as metrics
import deeplabcut.pose_estimation_pytorch.runners.schedulers as schedulers
from deeplabcut.pose_estimation_pytorch.models.detectors import BaseDetector
from deeplabcut.pose_estimation_pytorch.models.model import PoseModel
from deeplabcut.pose_estimation_pytorch.runners.base import ModelType, Runner
@@ -31,47 +31,61 @@
CSVLogger,
ImageLoggerMixin,
)
from deeplabcut.pose_estimation_pytorch.runners.schedulers import build_scheduler
from deeplabcut.pose_estimation_pytorch.runners.snapshots import TorchSnapshotManager
from deeplabcut.pose_estimation_pytorch.task import Task


class TrainingRunner(Runner, Generic[ModelType], metaclass=ABCMeta):
"""Runner base class
"""Base TrainingRunner class.

A runner takes a model and runs actions on it, such as training or inference
A TrainingRunner is used to fit models to datasets. Subclasses must implement the
``step(self, batch, mode)`` method, which performs a single training or validation
step on a batch of data. The step is different depending on the model type (e.g.
a pose model step vs. an object detector step).

Args:
model: The model to fit.
optimizer: The optimizer to use to fit the model.
snapshot_manager: Manages how snapshots are saved to disk during training.
device: The device on which to run training (e.g. 'cpu', 'cuda', 'cuda:0').
gpus: Used to specify the GPU indices for multi-GPU training (e.g. [0, 1, 2, 3]
to train on 4 GPUs). When a GPUs list is given, the device must be 'cuda'.
eval_interval: The interval at which the model will be evaluated while training
(e.g. `eval_interval=5` means the model will be evaluated every 5 epochs).
snapshot_path: If continuing to train a model, the path to the snapshot to
resume training from.
scheduler: The learning rate scheduler (or its configuration), if one should be
used.
load_scheduler_state_dict: When resuming training (snapshot_path is not None),
attempts to load the scheduler state dict from the snapshot. If you've
modified your scheduler, set this to False or the old scheduler parameters
might be used.
logger: Logger to monitor training (e.g. a WandBLogger).
log_filename: Name of the file in which to store training stats.
"""

def __init__(
self,
model: ModelType,
optimizer: torch.optim.Optimizer,
optimizer: dict | torch.optim.Optimizer,
snapshot_manager: TorchSnapshotManager,
device: str = "cpu",
gpus: list[int] | None = None,
eval_interval: int = 1,
snapshot_path: str | Path | None = None,
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
scheduler: dict | torch.optim.lr_scheduler.LRScheduler | None = None,
load_scheduler_state_dict: bool = True,
logger: BaseLogger | None = None,
log_filename: str = "learning_stats.csv",
):
"""
Args:
model: the model to run actions on
optimizer: the optimizer to use when fitting the model
snapshot_manager: the module to use to manage snapshots
device: the device to use (e.g. {'cpu', 'cuda:0', 'mps'})
gpus: the list of GPU indices to use for multi-GPU training
eval_interval: how often evaluation is run on the test set (in epochs)
snapshot_path: if defined, the path of a snapshot from which to load
pretrained weights
scheduler: scheduler for adjusting the lr of the optimizer
logger: logger to monitor training (e.g WandB logger)
log_filename: name of the file in which to store training stats
"""
super().__init__(
model=model, device=device, gpus=gpus, snapshot_path=snapshot_path
)
if isinstance(optimizer, dict):
optimizer = build_optimizer(model, optimizer)
if isinstance(scheduler, dict):
scheduler = schedulers.build_scheduler(scheduler, optimizer)

self.eval_interval = eval_interval
self.optimizer = optimizer
self.scheduler = scheduler
@@ -88,28 +88,34 @@ def __init__(
# some models cannot compute a validation loss (e.g. detectors)
self._print_valid_loss = True

if self.snapshot_path is not None and self.snapshot_path != "":
self.starting_epoch = self.load_snapshot(
self.snapshot_path,
self.device,
self.model,
self.optimizer,
)
if self.snapshot_path:
snapshot = self.load_snapshot(self.snapshot_path, self.device, self.model)
self.starting_epoch = snapshot.get("metadata", {}).get("epoch", 0)

if "optimizer" in snapshot:
self.optimizer.load_state_dict(snapshot["optimizer"])

self._load_scheduler_state_dict(load_scheduler_state_dict, snapshot)

self._metadata = dict(epoch=self.starting_epoch, metrics=dict(), losses=dict())
self._epoch_ground_truth = {}
self._epoch_predictions = {}

def state_dict(self) -> dict:
"""Returns: the state dict for the runner"""
model = self.model
if self._data_parallel:
model = self.model.module

return {
"metadata": self._metadata,
"model": model.state_dict(),
"optimizer": self.optimizer.state_dict(),
}
state_dict_ = dict(
metadata=self._metadata,
model=model.state_dict(),
optimizer=self.optimizer.state_dict(),
)
if self.scheduler is not None:
state_dict_["scheduler"] = self.scheduler.state_dict()

return state_dict_

@abstractmethod
def step(
@@ -256,7 +256,9 @@ def _epoch(

if len(epoch_loss) > 0:
epoch_loss = np.mean(epoch_loss).item()
self.history[f"{mode}_loss"].append(epoch_loss)
else:
epoch_loss = 0
self.history[f"{mode}_loss"].append(epoch_loss)

metrics_to_log = {}
if perf_metrics:
@@ -279,6 +279,29 @@

return epoch_loss

def _load_scheduler_state_dict(self, load_state_dict: bool, snapshot: dict) -> None:
if self.scheduler is None:
return

loaded_state_dict = False
if load_state_dict and "scheduler" in snapshot:
try:
schedulers.load_scheduler_state(self.scheduler, snapshot["scheduler"])
loaded_state_dict = True
except ValueError as err:
logging.warning(
"Failed to load the scheduler state_dict. The scheduler will "
"restart at epoch 0. This is expected if the scheduler "
"configuration was edited since the original snapshot was "
f"trained. Error: {err}"
)

if not loaded_state_dict and self.starting_epoch > 0:
logging.info(
f"Setting the scheduler starting epoch to {self.starting_epoch}"
)
self.scheduler.last_epoch = self.starting_epoch


class PoseTrainingRunner(TrainingRunner[PoseModel]):
"""Runner to train pose estimation models"""
@@ -596,11 +596,13 @@ def build_training_runner(
optim_cfg = runner_config["optimizer"]
optim_cls = getattr(torch.optim, optim_cfg["type"])
optimizer = optim_cls(params=model.parameters(), **optim_cfg["params"])
scheduler = build_scheduler(runner_config.get("scheduler"), optimizer)
scheduler = schedulers.build_scheduler(runner_config.get("scheduler"), optimizer)

# if no custom snapshot prefix is defined, use the default one
snapshot_prefix = runner_config.get("snapshot_prefix")
if snapshot_prefix is None or len(snapshot_prefix) == 0:
snapshot_prefix = task.snapshot_prefix

kwargs = dict(
model=model,
optimizer=optimizer,
@@ -618,9 +618,28 @@
eval_interval=runner_config.get("eval_interval"),
snapshot_path=snapshot_path,
scheduler=scheduler,
load_scheduler_state_dict=runner_config.get("load_scheduler_state_dict", True),
logger=logger,
)
if task == Task.DETECT:
return DetectorTrainingRunner(**kwargs)

return PoseTrainingRunner(**kwargs)


def build_optimizer(
model: nn.Module,
optimizer_config: dict,
) -> torch.optim.Optimizer:
"""Builds an optimizer from a configuration.

Args:
model: The model to optimize.
optimizer_config: The configuration for the optimizer.

Returns:
The optimizer for the model built according to the given configuration.
"""
optim_cls = getattr(torch.optim, optimizer_config["type"])
optimizer = optim_cls(params=model.parameters(), **optimizer_config["params"])
return optimizer
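
Below is a short sketch of the dict-based construction path the runner now accepts, combining `build_optimizer` and `schedulers.build_scheduler`. The model and the concrete config values are placeholders, the import paths assume the modules changed in this PR, and it assumes `build_scheduler` resolves the `type` from `torch.optim.lr_scheduler` as shown above:

```python
import torch
from torch import nn

from deeplabcut.pose_estimation_pytorch.runners import schedulers
from deeplabcut.pose_estimation_pytorch.runners.train import build_optimizer

model = nn.Linear(4, 2)  # placeholder for a pose model or detector

# Configuration dicts of the shape accepted by TrainingRunner.__init__
optim_cfg = {"type": "AdamW", "params": {"lr": 1e-4}}
sched_cfg = {"type": "StepLR", "params": {"step_size": 100, "gamma": 0.1}}

optimizer = build_optimizer(model, optim_cfg)
scheduler = schedulers.build_scheduler(sched_cfg, optimizer)
```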
15 changes: 14 additions & 1 deletion docs/pytorch/pytorch_config.md
@@ -288,6 +288,7 @@ runner:
...
scheduler: # optional: a learning rate scheduler
...
load_scheduler_state_dict: true/false # whether to load the scheduler state when resuming training from a snapshot
snapshots: # parameters for the TorchSnapshotManager
max_snapshots: 5 # the maximum number of snapshots to save (the "best" model does not count as one of them)
save_epochs: 25 # the interval between each snapshot save
@@ -327,7 +327,7 @@ https://pytorch.org/docs/stable/optim.html). Examples:
lr: 1e-4
```

**Scheduler**: YYou can use [any scheduler](
**Scheduler**: You can use [any scheduler](
https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) defined in
`torch.optim.lr_scheduler`, where the arguments given are arguments of the scheduler.
The default scheduler is an LRListScheduler, which changes the learning rates at each
@@ -410,6 +410,12 @@ continue to train from the 10th epoch on.
resume_training_from: /Users/john/dlc-project-2021-06-22/dlc-models-pytorch/iteration-0/dlcJun22-trainset95shuffle0/train/snapshot-010.pt
```

When continuing to train a model, you may want to modify the learning rate scheduling
that was being used (by editing the configuration under the `scheduler` key). When doing
so, you *must set `load_scheduler_state_dict: false`* in your `runner` config!
Otherwise, the scheduler parameters you started training with will be loaded from the
state dictionary, and your edits might not be kept!
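
For example, a runner configuration resuming from a snapshot with an edited scheduler might look like the following; the path, learning rates, and milestones are illustrative, and the key placement follows the examples above:

```yaml
runner:
  resume_training_from: /Users/john/dlc-project-2021-06-22/dlc-models-pytorch/iteration-0/dlcJun22-trainset95shuffle0/train/snapshot-010.pt
  load_scheduler_state_dict: false  # keep the edited scheduler configuration below
  scheduler:
    type: LRListScheduler
    params:
      lr_list: [[1e-5], [1e-6]]
      milestones: [90, 120]
```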

## Training Top-Down Models

Top-down models are split into two main elements: a detector (localizing individuals in
@@ -479,3 +479,9 @@ detector:
# weights from which to resume training
resume_training_from: /Users/john/dlc-project-2021-06-22/dlc-models-pytorch/iteration-0/dlcJun22-trainset95shuffle0/train/snapshot-detector-020.pt
```

When continuing to train a detector, you may want to modify the learning rate scheduling
that was being used (by editing the configuration under the `scheduler` key). When doing
so, you *must set `load_scheduler_state_dict: false`* in the `runner` section of your
`detector` config! Otherwise, the scheduler parameters you started training with will be
loaded from the state dictionary, and your edits might not be kept!