|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import json |
| 5 | +from enum import Enum |
5 | 6 | from pathlib import Path |
6 | 7 | from typing import Any, Literal |
7 | 8 |
|
|
10 | 11 | Rotation = Literal[0, 90, 180, 270] |
11 | 12 | TileLayout = Literal["auto", "2x2", "1x4", "4x1"] |
12 | 13 | Precision = Literal["FP32", "FP16"] |
| 14 | +ModelType = Literal["pytorch", "tensorflow"] |
13 | 15 |
|
14 | 16 |
|
15 | 17 | class CameraSettings(BaseModel): |
@@ -239,14 +241,46 @@ class DLCProcessorSettings(BaseModel): |
239 | 241 | resize: float = Field(default=1.0, gt=0) |
240 | 242 | precision: Precision = "FP32" |
241 | 243 | additional_options: dict[str, Any] = Field(default_factory=dict) |
242 | | - model_type: Literal["pytorch"] = "pytorch" |
| 244 | + model_type: ModelType = "pytorch" |
243 | 245 | single_animal: bool = True |
244 | 246 |
|
245 | 247 | @field_validator("dynamic", mode="before") |
246 | 248 | @classmethod |
247 | 249 | def _coerce_dynamic(cls, v): |
248 | 250 | return DynamicCropModel.from_tupleish(v) |
249 | 251 |
|
| 252 | + @field_validator("model_type", mode="before") |
| 253 | + @classmethod |
| 254 | + def _coerce_model_type(cls, v): |
| 255 | + """ |
| 256 | + Accept: |
| 257 | + - "pytorch"/"tensorflow"/etc as strings |
| 258 | + - Enum instances (e.g. Engine.PYTORCH) and store their .value |
| 259 | + Always return a lowercase string. |
| 260 | + """ |
| 261 | + if v is None or v == "": |
| 262 | + return "pytorch" |
| 263 | + |
| 264 | + # If caller passed Engine enum or any Enum, use its value |
| 265 | + if isinstance(v, Enum): |
| 266 | + v = v.value |
| 267 | + |
| 268 | + # If caller passed something with a `.value` attribute (defensive) |
| 269 | + if not isinstance(v, str) and hasattr(v, "value"): |
| 270 | + v = v.value |
| 271 | + |
| 272 | + if not isinstance(v, str): |
| 273 | + raise TypeError(f"model_type must be a string or Enum, got {type(v)!r}") |
| 274 | + |
| 275 | + v = v.strip().lower() |
| 276 | + |
| 277 | + # Optional: enforce allowed values |
| 278 | + allowed = {"pytorch", "tensorflow"} |
| 279 | + if v not in allowed: |
| 280 | + raise ValueError(f"Unknown model type: {v!r}. Allowed: {sorted(allowed)}") |
| 281 | + |
| 282 | + return v |
| 283 | + |
250 | 284 |
|
251 | 285 | class BoundingBoxSettings(BaseModel): |
252 | 286 | enabled: bool = False |
|
0 commit comments