Skip to content

Commit 420c548

Browse files
authored
fix: exclude null flex_start_max_run_duration_minutes in GPUConfig (#391)
1 parent 6b66c8f commit 420c548

2 files changed

Lines changed: 32 additions & 2 deletions

File tree

src/aignostics/platform/_sdk_metadata.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class GPUConfig(BaseModel):
8686
le=60 * 60,
8787
description="Maximum run duration in minutes when using FLEX_START provisioning mode (1-3600). "
8888
"Required when provisioning_mode is FLEX_START, must be None otherwise.",
89+
exclude_if=lambda v: v is None, # Exclude from serialization if None
8990
)
9091

9192
@model_validator(mode="after")

tests/aignostics/platform/sdk_metadata_test.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
import pytest
88
from pydantic import ValidationError
99

10+
from aignostics.platform import DEFAULT_GPU_PROVISIONING_MODE, DEFAULT_GPU_TYPE, DEFAULT_MAX_GPUS_PER_SLIDE
1011
from aignostics.platform._sdk_metadata import (
1112
ITEM_SDK_METADATA_SCHEMA_VERSION,
1213
SDK_METADATA_SCHEMA_VERSION,
1314
VALIDATION_CASE_TAG_PREFIX,
15+
GPUConfig,
16+
GPUType,
17+
ProvisioningMode,
1418
ValidationCase,
1519
build_item_sdk_metadata,
1620
build_run_sdk_metadata,
@@ -996,8 +1000,6 @@ def test_pipeline_config_defaults() -> None:
9961000
"""Test that pipeline configuration uses correct defaults."""
9971001
from aignostics.platform import (
9981002
DEFAULT_CPU_PROVISIONING_MODE,
999-
DEFAULT_GPU_PROVISIONING_MODE,
1000-
DEFAULT_GPU_TYPE,
10011003
DEFAULT_MAX_GPUS_PER_SLIDE,
10021004
PipelineConfig,
10031005
)
@@ -1255,3 +1257,30 @@ def test_metadata_with_invalid_validation_case_tag() -> None:
12551257
with pytest.raises(ValidationError) as exc:
12561258
validate_run_sdk_metadata(metadata)
12571259
assert "validation_case" in str(exc.value)
1260+
1261+
1262+
class TestGPUConfig:
1263+
"""Test cases for GPUConfig model."""
1264+
1265+
@pytest.mark.unit
1266+
@staticmethod
1267+
def test_model_dump_should_include_flex_start_max_duration_if_provided() -> None:
1268+
"""Test that flex_start_max_run_duration_minutes is included in model dump if provided."""
1269+
config = GPUConfig(
1270+
gpu_type=GPUType.L4,
1271+
provisioning_mode=ProvisioningMode.FLEX_START,
1272+
max_gpus_per_slide=DEFAULT_MAX_GPUS_PER_SLIDE,
1273+
flex_start_max_run_duration_minutes=1,
1274+
)
1275+
assert "flex_start_max_run_duration_minutes" in config.model_dump()
1276+
1277+
@pytest.mark.unit
1278+
@staticmethod
1279+
def test_model_dump_should_exclude_flex_start_max_duration_if_not_provided() -> None:
1280+
"""Test that flex_start_max_run_duration_minutes is excluded in model dump if not provided."""
1281+
config = GPUConfig(
1282+
gpu_type=GPUType.L4,
1283+
provisioning_mode=ProvisioningMode.SPOT,
1284+
max_gpus_per_slide=DEFAULT_MAX_GPUS_PER_SLIDE,
1285+
)
1286+
assert "flex_start_max_run_duration_minutes" not in config.model_dump()

0 commit comments

Comments
 (0)