Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deeplabcut/create_project/modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def create_pretrained_project(
auxiliaryfunctions.edit_config(cfg, dict_)

# downloading base encoder / not required unless on re-trains (but when a training set is created this happens anyway)
# model_path, num_shuffles=auxfun_models.Check4weights(pose_cfg['net_type'], parent_path, num_shuffles= 1)
# model_path = auxfun_models.check_for_weights(pose_cfg['net_type'], parent_path)

# Updating training and test pose_cfg:
snapshotname = [fn for fn in os.listdir(train_dir) if ".meta" in fn][0].split(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def create_multianimaltraining_dataset(
# Loading the encoder (if necessary downloading from TF)
dlcparent_path = auxiliaryfunctions.get_deeplabcut_path()
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
model_path, num_shuffles = auxfun_models.check_for_weights(
net_type, Path(dlcparent_path), num_shuffles
model_path = auxfun_models.check_for_weights(
net_type, Path(dlcparent_path)
)

if Shuffles is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -916,8 +916,8 @@ def create_training_dataset(
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
elif posecfg_template:
defaultconfigfile = posecfg_template
model_path, num_shuffles = auxfun_models.check_for_weights(
net_type, Path(dlcparent_path), num_shuffles
model_path = auxfun_models.check_for_weights(
net_type, Path(dlcparent_path)
)

if Shuffles is None:
Expand Down
6 changes: 3 additions & 3 deletions deeplabcut/utils/auxfun_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@
}


def check_for_weights(modeltype, parent_path, num_shuffles):
def check_for_weights(modeltype, parent_path):
"""gets local path to network weights and checks if they are present. If not, downloads them from tensorflow.org"""

if modeltype not in MODELTYPE_FILEPATH_MAP.keys():
print(
"Currently ResNet (50, 101, 152), MobilenetV2 (1, 0.75, 0.5 and 0.35) and EfficientNet (b0-b6) are supported, please change 'resnet' entry in config.yaml!"
)
# Exit the function early if an unknown modeltype is provided.
return parent_path, -1
return parent_path

exists = False
model_path = parent_path / MODELTYPE_FILEPATH_MAP[modeltype]
Expand All @@ -70,7 +70,7 @@ def check_for_weights(modeltype, parent_path, num_shuffles):
else:
download_weights(modeltype, model_path)

return str(model_path), num_shuffles
return str(model_path)


def download_weights(modeltype, model_path):
Expand Down
7 changes: 3 additions & 4 deletions tests/test_auxfun_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_filepaths_for_modeltypes(self):
"deeplabcut.utils.auxfun_models.download_weights"
) as mocked_download:
for modeltype, expected_path in MODELTYPE_FILEPATH_MAP.items():
actual_path, _ = check_for_weights(modeltype, Path(tmpdir), 1)
actual_path = check_for_weights(modeltype, Path(tmpdir))
self.assertIn(str(expected_path), actual_path)
if "efficientnet" in modeltype:
mocked_download.assert_called_with(
Expand All @@ -37,8 +37,7 @@ def test_filepaths_for_modeltypes(self):
)

def test_bad_modeltype(self):
actual_path, actual_num_shuffles = check_for_weights(
"dummymodel", "nonexistentpath", 1
actual_path = check_for_weights(
"dummymodel", "nonexistentpath"
)
self.assertEqual(actual_path, "nonexistentpath")
self.assertEqual(actual_num_shuffles, -1)