-
Notifications
You must be signed in to change notification settings - Fork 54
Expand file tree
/
Copy pathutils.py
More file actions
203 lines (160 loc) · 7.1 KB
/
utils.py
File metadata and controls
203 lines (160 loc) · 7.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
Utils for the DLC-Live Model Zoo
"""
# NOTE JR 2026-01-23: This file contains code duplicated from the DeepLabCut main repository.
# It should be removed once a solution is found to avoid the duplication.
import copy
import logging
from pathlib import Path
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
from dlclibrary.dlcmodelzoo.modelzoo_download import (
_load_model_names as huggingface_model_paths,
)
from ruamel.yaml import YAML
from dlclive.modelzoo.resolve_config import update_config
from dlclive.pose_estimation_pytorch.models.detectors.torchvision import (
SUPPORTED_TORCHVISION_DETECTORS,
)
# Directory containing this module. The bundled "model_configs", "project_configs"
# and "snapshots" folders used by the helpers below are resolved relative to it.
_MODELZOO_PATH = Path(__file__).parent
def get_super_animal_model_config_path(model_name: str) -> Path:
    """Locate the bundled model configuration file for ``model_name``.

    Args:
        model_name: stem of a YAML file inside the ``model_configs`` folder.

    Returns:
        The path ``model_configs/<model_name>.yaml`` next to this module.

    Raises:
        FileNotFoundError: if that file does not exist; the message lists the
            model names that are available.
    """
    config_file = _MODELZOO_PATH.joinpath("model_configs", f"{model_name}.yaml")
    if config_file.exists():
        return config_file
    raise FileNotFoundError(
        f"Modelzoo model configuration file not found: {config_file} Available models: {list_available_models()}"
    )
def get_super_animal_project_config_path(super_animal: str) -> Path:
    """Locate the bundled project configuration file for a SuperAnimal.

    Args:
        super_animal: stem of a YAML file inside the ``project_configs`` folder.

    Returns:
        The path ``project_configs/<super_animal>.yaml`` next to this module.

    Raises:
        FileNotFoundError: if that file does not exist; the message lists the
            project names that are available.
    """
    project_file = _MODELZOO_PATH.joinpath("project_configs", f"{super_animal}.yaml")
    if project_file.exists():
        return project_file
    raise FileNotFoundError(
        f"Modelzoo project configuration file not found: {project_file} Available projects: {list_available_projects()}"
    )
def get_snapshot_folder_path() -> Path:
    """Return the ``snapshots`` directory next to this module, where downloaded snapshots are kept."""
    snapshots_dir = _MODELZOO_PATH.joinpath("snapshots")
    return snapshots_dir
def list_available_models() -> list[str]:
    """Return the names of the bundled model configurations.

    Names are the stems of the ``*.yaml`` files in ``model_configs``. The list
    is sorted because ``Path.glob`` yields entries in directory-listing order,
    which is filesystem-dependent; sorting keeps listings and the error
    messages that embed them deterministic.
    """
    return sorted(p.stem for p in _MODELZOO_PATH.glob("model_configs/*.yaml"))
def list_available_projects() -> list[str]:
    """Return the names of the bundled SuperAnimal project configurations.

    Names are the stems of the ``*.yaml`` files in ``project_configs``. The
    list is sorted because ``Path.glob`` yields entries in directory-listing
    order, which is filesystem-dependent; sorting keeps listings and the error
    messages that embed them deterministic.
    """
    return sorted(p.stem for p in _MODELZOO_PATH.glob("project_configs/*.yaml"))
def list_available_combinations() -> list[str]:
    """Return the model names known to dlclibrary's Hugging Face model registry.

    ``huggingface_model_paths`` is an alias for dlclibrary's ``_load_model_names``.
    That is a loader function that returns the model-name -> URL mapping, not
    the mapping itself. It must therefore be called; calling ``.keys()`` on the
    function object raises AttributeError.
    """
    model_paths = huggingface_model_paths()
    # Iterating a mapping yields its keys, i.e. the model names.
    return list(model_paths)
def read_config_as_dict(config_path: str | Path) -> dict:
    """Load a YAML configuration file.

    Args:
        config_path: the path to the configuration file to load

    Returns:
        The parsed configuration, built from plain Python types by ruamel's
        pure-Python ``safe`` loader.
    """
    # Be explicit about the encoding. Without it, open() uses the platform's
    # locale encoding (e.g. cp1252 on Windows), which can mis-decode non-ASCII
    # characters in the YAML files.
    with open(config_path, encoding="utf-8") as f:
        cfg = YAML(typ="safe", pure=True).load(f)
    return cfg
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
# from deeplabcut/pose_estimation_pytorch/config/make_pose_config.py
def add_metadata(
    project_config: dict,
    config: dict,
) -> dict:
    """Return a copy of a PyTorch pose configuration with a ``metadata`` section.

    ``config`` itself is not modified; a deep copy is annotated and returned.

    Args:
        project_config: the project configuration. ``project_path`` is required,
            and ``bodyparts`` is required unless ``multianimalbodyparts`` holds a
            non-empty value.
        config: the PyTorch pose configuration to annotate.

    Returns:
        A deep copy of ``config`` whose ``metadata`` key holds the project path,
        an empty ``pose_config_path``, the body parts, the unique body parts
        (default ``[]``), the individuals (default ``["animal"]``) and the
        identity flag (default ``False``).
    """
    annotated = copy.deepcopy(config)
    project_path = project_config["project_path"]
    # Multi-animal projects list body parts under "multianimalbodyparts"; fall
    # back to "bodyparts" when that key is missing or empty.
    bodyparts = project_config.get("multianimalbodyparts") or project_config["bodyparts"]
    annotated["metadata"] = dict(
        project_path=project_path,
        pose_config_path="",
        bodyparts=bodyparts,
        unique_bodyparts=project_config.get("uniquebodyparts", []),
        individuals=project_config.get("individuals", ["animal"]),
        with_identity=project_config.get("identity", False),
    )
    return annotated
def _get_torchvision_detector_config(detector_name: str | None) -> dict:
    """Build the torchvision detector model config for the superanimal humanbody model.

    Args:
        detector_name: the torchvision detector to use; must be one of
            ``SUPPORTED_TORCHVISION_DETECTORS``.

    Returns:
        A ``TorchvisionDetectorAdaptor`` model configuration using the
        ``COCO_V1`` weights and a box score threshold of 0.6.

    Raises:
        ValueError: if ``detector_name`` is None or not a supported detector.
    """
    supported = SUPPORTED_TORCHVISION_DETECTORS
    if detector_name is None:
        problem = f"Detector name is required for superanimal humanbody models. Must be one of {supported}."
    elif detector_name not in supported:
        problem = f"Unsupported humanbody detector {detector_name}. Should be one of {supported}"
    else:
        return {
            "type": "TorchvisionDetectorAdaptor",
            "model": detector_name,
            "weights": "COCO_V1",
            "num_classes": None,
            "box_score_thresh": 0.6,
        }
    raise ValueError(problem)
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
def load_super_animal_config(
    super_animal: str,
    model_name: str,
    detector_name: str | None = None,
    max_individuals: int = 30,
    device: str | None = None,
) -> dict:
    """Assemble the model configuration for a SuperAnimal-pretrained model.

    The project and model YAML files are loaded, metadata from the project is
    attached, and ``update_config`` applies ``max_individuals`` and ``device``.
    Without a detector the model runs bottom-up (``method="BU"``). With one it
    runs top-down (``method="TD"``), and the detector's own configuration is
    stored under ``"detector"``.

    Args:
        super_animal: The name of the SuperAnimal for which to create the model config.
        model_name: The name of the model for which to create the model config.
        detector_name: The name of the detector for which to create the model config.
        max_individuals: The maximum number of detections to make in an image
        device: The device to use to train/run inference on the model

    Returns:
        The model configuration for a SuperAnimal-pretrained model.

    Raises:
        FileNotFoundError: if the project, model or detector configuration
            file does not exist.
        ValueError: for ``superanimal_humanbody`` when ``detector_name`` is
            missing or not a supported torchvision detector.
    """
    project_config = read_config_as_dict(
        get_super_animal_project_config_path(super_animal=super_animal)
    )
    model_config = read_config_as_dict(
        get_super_animal_model_config_path(model_name=model_name)
    )
    model_config = update_config(
        add_metadata(project_config, model_config), max_individuals, device
    )

    top_down = detector_name is not None
    model_config["method"] = "TD" if top_down else "BU"
    if top_down:
        detector_cfg_file = get_super_animal_model_config_path(model_name=detector_name)
        model_config["detector"] = read_config_as_dict(detector_cfg_file)

    if super_animal == "superanimal_humanbody":
        # The humanbody detector is a torchvision model. The helper raises
        # ValueError unless detector_name is in SUPPORTED_TORCHVISION_DETECTORS,
        # so "detector" is always present by the time it is indexed here.
        model_config["detector"]["model"] = _get_torchvision_detector_config(detector_name)

    return model_config
def download_super_animal_snapshot(dataset: str, model_name: str) -> Path:
    """Download a SuperAnimal snapshot unless it is already present.

    The snapshot is stored as ``<dataset>_<model_name>.pt`` in the folder
    returned by ``get_snapshot_folder_path()``. An existing file is returned
    without downloading.

    Args:
        dataset: The name of the SuperAnimal dataset for which to download a snapshot.
        model_name: The name of the model for which to download a snapshot.

    Returns:
        The path to the (downloaded or pre-existing) snapshot.

    Raises:
        RuntimeError: if the download call returns but the snapshot file is
            still missing.
        Exception: any error raised by ``download_huggingface_model`` is
            logged and re-raised unchanged.
    """
    # Use a named logger. The module-level logging.info()/error() helpers call
    # logging.basicConfig() when the root logger has no handlers, which
    # reconfigures logging for the whole host application.
    logger = logging.getLogger(__name__)

    snapshot_dir = get_snapshot_folder_path()
    # The downloaded snapshot is identified as "<dataset>_<model_name>".
    # Keep this in a separate local instead of reassigning the parameter.
    snapshot_name = f"{dataset}_{model_name}"
    model_filename = f"{snapshot_name}.pt"
    model_path = snapshot_dir / model_filename
    if model_path.exists():
        logger.info("Snapshot %s already exists, skipping download", model_path)
        return model_path

    try:
        download_huggingface_model(
            snapshot_name, target_dir=str(snapshot_dir), rename_mapping=model_filename
        )
        if not model_path.exists():
            raise RuntimeError(f"Failed to download {snapshot_name} to {model_path}")
    except Exception as e:
        logger.error(
            "Failed to download superanimal snapshot %s to %s: %s",
            snapshot_name,
            model_path,
            e,
        )
        # A bare raise re-raises the active exception with its original traceback.
        raise
    return model_path