Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions tests/test_auxiliaryfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
from pathlib import Path
import pytest
from deeplabcut.utils import auxiliaryfunctions
from deeplabcut.utils.auxfun_videos import SUPPORTED_VIDEOS


def test_find_analyzed_data(tmpdir_factory):
import os
import pytest

fake_folder = tmpdir_factory.mktemp("videos")
SUPPORTED_VIDEOS = ["avi"]
n_ext = len(SUPPORTED_VIDEOS)
Expand Down Expand Up @@ -119,3 +118,95 @@ def _create_fake_file(filename):
videotype=ext,
)
assert len(videos) == 1


def test_write_config_has_skeleton(tmpdir_factory):
""" Required for backward compatibility """
fake_folder = tmpdir_factory.mktemp("fakeConfigs")
fake_config_file = fake_folder / Path("fakeConfig")
auxiliaryfunctions.write_config(fake_config_file, {})
config_data = auxiliaryfunctions.read_config(fake_config_file)
assert "skeleton" in config_data


@pytest.mark.parametrize(
"multianimal, bodyparts, ma_bpts, unique_bpts, comparison_bpts, expected_bpts",
[
(False, ["head", "shoulders", "knees", "toes"], None, None, {"knees", "others", "toes"}, ["knees", "toes"]),
(True, None, ["head", "shoulders", "knees"], ["toes"], {"knees", "others", "toes"}, ["knees", "toes"]),
]
)
def test_intersection_of_body_parts_and_ones_given_by_user(
multianimal, bodyparts, ma_bpts, unique_bpts, comparison_bpts, expected_bpts
):
cfg = {
"multianimalproject": multianimal,
"bodyparts": bodyparts,
"multianimalbodyparts": ma_bpts,
"uniquebodyparts": unique_bpts,
}

if multianimal:
all_bodyparts = list(set(ma_bpts + unique_bpts))
else:
all_bodyparts = bodyparts

filtered_bpts = auxiliaryfunctions.intersection_of_body_parts_and_ones_given_by_user(
cfg, comparisonbodyparts="all"
)
print(all_bodyparts)
print(filtered_bpts)
assert len(all_bodyparts) == len(filtered_bpts)
assert all([bpt in all_bodyparts for bpt in filtered_bpts])

filtered_bpts = auxiliaryfunctions.intersection_of_body_parts_and_ones_given_by_user(
cfg, comparisonbodyparts=comparison_bpts,
)
print(filtered_bpts)
assert len(expected_bpts) == len(filtered_bpts)
assert all([bpt in expected_bpts for bpt in filtered_bpts])


class MockPath:

def __init__(self, path: Path, st_mtime: int):
self.path = path
self.parent = self.path.parent
self.st_mtime = st_mtime

def lstat(self):
return self



# labeled_folders: (has_H5, H5_st_mtime, folder_name)
@pytest.mark.parametrize(
"labeled_folders, next_folder_name",
[
([(True, 1, "a"), (False, None, "b"), (False, None, "c")], "b"),
([(False, None, "a"), (True, 123, "d"), (False, None, "f")], "f"),
]
)
def test_find_next_unlabeled_folder(
tmpdir_factory, monkeypatch, labeled_folders, next_folder_name,
):
project_folder = tmpdir_factory.mktemp("project")
fake_cfg = Path(project_folder / "cfg.yaml")
auxiliaryfunctions.write_config(fake_cfg, {"project_path": str(project_folder)})

data_folder = project_folder / "labeled-data"
data_folder.mkdir()
rglob_results = []
for has_h5, h5_last_mod_time, folder_name in labeled_folders:
labeled_folder_path = Path(data_folder / folder_name)
labeled_folder_path.mkdir()
if has_h5:
h5_path = Path(labeled_folder_path / "data.h5")
rglob_results.append(MockPath(h5_path, h5_last_mod_time))

def get_rglob_results(*args, **kwargs):
return rglob_results

monkeypatch.setattr(Path, "rglob", get_rglob_results)
next_folder = auxiliaryfunctions.find_next_unlabeled_folder(fake_cfg)
assert str(next_folder) == str(Path(data_folder / next_folder_name))
75 changes: 75 additions & 0 deletions tests/test_frame_selection_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
""" Tests for frame selection tools """
import math
from unittest.mock import Mock
import pytest
import deeplabcut.utils.frameselectiontools as fst


@pytest.mark.parametrize(
"fps, duration, n_to_pick, start, end, index",
[
(32, 10, 10, 0, 1, None),
(16, 100, 50, 0, 1, list(range(100, 500, 5))),
(16, 100, 5, 0.25, 0.3, list(range(100, 500, 5))),
]
)
def test_uniform_frames(fps, duration, n_to_pick, start, end, index):
start_idx = int(math.floor(start * duration * fps))
end_idx = int(math.ceil(end * duration * fps))
if index is None:
valid_indices = list(range(start_idx, end_idx))
else:
valid_indices = [idx for idx in index if start_idx <= idx <= end_idx]

clip = Mock()
clip.fps = fps
clip.duration = duration
frames = fst.UniformFrames(clip, n_to_pick, start, end, index)
print(f"FPS: {fps}")
print(f"Duration: {duration}")
print(f"Selected Frames: {frames}")
print(f"Valid Indices: {valid_indices}")

# Check that we get the number of frames we asked for
assert len(frames) == n_to_pick, f"Wrong nb. of frames: {n_to_pick}!={len(frames)}"
# Check that all indices are valid
for index in frames:
assert index in valid_indices, f"Invalid index: {index} not in {valid_indices}"
# Check that all frames are unique
assert len(set(frames)) == len(frames), "Duplicate indices found"



@pytest.mark.parametrize(
"fps, nframes, n_to_pick, start, end, index",
[
(32, 320, 10, 0, 1, None),
(16, 1600, 50, 0, 1, list(range(100, 500, 5))),
(16, 1600, 5, 0.25, 0.3, list(range(100, 500, 5))),

]
)
def test_uniform_frames_cv2(fps, nframes, n_to_pick, start, end, index):
start_idx = int(math.floor(start * nframes))
end_idx = int(math.ceil(end * nframes))
if index is None:
valid_indices = list(range(start_idx, end_idx))
else:
valid_indices = [idx for idx in index if start_idx <= idx <= end_idx]

cap = Mock()
cap.fps = fps
cap.__len__ = Mock(return_value=nframes)
frames = fst.UniformFramescv2(cap, n_to_pick, start, end, index)
print(f"FPS: {fps}")
print(f"Nframes: {nframes}")
print(f"Selected Frames: {frames}")
print(f"Valid Indices: {valid_indices}")

# Check that we get the number of frames we asked for
assert len(frames) == n_to_pick, f"Wrong nb. of frames: {n_to_pick}!={len(frames)}"
# Check that all indices are valid
for index in frames:
assert index in valid_indices, f"Invalid index: {index} not in {valid_indices}"
# Check that all frames are unique
assert len(set(frames)) == len(frames), "Duplicate indices found"