|
7 | 7 | import pytest |
8 | 8 | from pydantic import ValidationError |
9 | 9 |
|
| 10 | +from aignostics.platform import DEFAULT_GPU_PROVISIONING_MODE, DEFAULT_GPU_TYPE, DEFAULT_MAX_GPUS_PER_SLIDE |
10 | 11 | from aignostics.platform._sdk_metadata import ( |
11 | 12 | ITEM_SDK_METADATA_SCHEMA_VERSION, |
12 | 13 | SDK_METADATA_SCHEMA_VERSION, |
13 | 14 | VALIDATION_CASE_TAG_PREFIX, |
| 15 | + GPUConfig, |
| 16 | + GPUType, |
| 17 | + ProvisioningMode, |
14 | 18 | ValidationCase, |
15 | 19 | build_item_sdk_metadata, |
16 | 20 | build_run_sdk_metadata, |
@@ -996,8 +1000,6 @@ def test_pipeline_config_defaults() -> None: |
996 | 1000 | """Test that pipeline configuration uses correct defaults.""" |
997 | 1001 | from aignostics.platform import ( |
998 | 1002 | DEFAULT_CPU_PROVISIONING_MODE, |
999 | | - DEFAULT_GPU_PROVISIONING_MODE, |
1000 | | - DEFAULT_GPU_TYPE, |
1001 | 1003 | DEFAULT_MAX_GPUS_PER_SLIDE, |
1002 | 1004 | PipelineConfig, |
1003 | 1005 | ) |
@@ -1255,3 +1257,30 @@ def test_metadata_with_invalid_validation_case_tag() -> None: |
1255 | 1257 | with pytest.raises(ValidationError) as exc: |
1256 | 1258 | validate_run_sdk_metadata(metadata) |
1257 | 1259 | assert "validation_case" in str(exc.value) |
| 1260 | + |
| 1261 | + |
| 1262 | +class TestGPUConfig: |
| 1263 | + """Test cases for GPUConfig model.""" |
| 1264 | + |
| 1265 | + @pytest.mark.unit |
| 1266 | + @staticmethod |
| 1267 | + def test_model_dump_should_include_flex_start_max_duration_if_provided() -> None: |
| 1268 | + """Test that flex_start_max_run_duration_minutes is included in model dump if provided.""" |
| 1269 | + config = GPUConfig( |
| 1270 | + gpu_type=GPUType.L4, |
| 1271 | + provisioning_mode=ProvisioningMode.FLEX_START, |
| 1272 | + max_gpus_per_slide=DEFAULT_MAX_GPUS_PER_SLIDE, |
| 1273 | + flex_start_max_run_duration_minutes=1, |
| 1274 | + ) |
| 1275 | + assert "flex_start_max_run_duration_minutes" in config.model_dump() |
| 1276 | + |
| 1277 | + @pytest.mark.unit |
| 1278 | + @staticmethod |
| 1279 | + def test_model_dump_should_exclude_flex_start_max_duration_if_not_provided() -> None: |
| 1280 | + """Test that flex_start_max_run_duration_minutes is excluded in model dump if not provided.""" |
| 1281 | + config = GPUConfig( |
| 1282 | + gpu_type=GPUType.L4, |
| 1283 | + provisioning_mode=ProvisioningMode.SPOT, |
| 1284 | + max_gpus_per_slide=DEFAULT_MAX_GPUS_PER_SLIDE, |
| 1285 | + ) |
| 1286 | + assert "flex_start_max_run_duration_minutes" not in config.model_dump() |
0 commit comments