Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions sagemaker-train/src/sagemaker/train/tuner.py
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,11 @@ def _build_training_job_definition(self, inputs):
model_trainer.stopping_condition.max_wait_time_in_seconds
)

Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
# Get environment variables from model_trainer
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The truthiness check `not env` evaluates to `True` for an empty dict `{}`, so an empty dict causes `env` to be set to `None`. While the test covers this, silently converting a user-provided `{}` to `None` could be surprising. Consider whether an empty dict should be passed through as-is (the API would accept it), or whether this normalization is intentional. If it is intentional, a brief comment explaining why empty dicts are normalized to `None` would help future maintainers.

Also, the `isinstance(env, dict)` check is defensive — if `ModelTrainer.environment` has a type annotation of `dict | None`, Pydantic validation should already enforce this. Is this guard necessary?

env = getattr(model_trainer, "environment", None)
if not env or not isinstance(env, dict):
env = None

definition = HyperParameterTrainingJobDefinition(
algorithm_specification=algorithm_spec,
role_arn=model_trainer.role,
Expand All @@ -1513,13 +1518,9 @@ def _build_training_job_definition(self, inputs):
stopping_condition=stopping_condition,
static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {},
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
environment=env,
)

# Pass through environment variables from model_trainer
env = getattr(model_trainer, "environment", None)
if env and isinstance(env, dict):
definition.environment = env

# Pass through VPC config from model_trainer
networking = getattr(model_trainer, "networking", None)
if networking and hasattr(networking, "_to_vpc_config"):
Expand Down
58 changes: 58 additions & 0 deletions sagemaker-train/tests/unit/train/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,61 @@ def test_build_training_job_definition_includes_spot_params(self):
assert isinstance(
definition.stopping_condition.max_wait_time_in_seconds, int
Comment thread
aviruthen marked this conversation as resolved.
), "Max wait time should be set"

Comment thread
aviruthen marked this conversation as resolved.
def test_build_training_job_definition_includes_environment_variables(self):
    """Check that environment variables reach the built training job definition.

    Regression test for GitHub issue #5613, where tuning jobs were missing
    environment variables that were set on the ModelTrainer.
    """
    expected_env = {"FOO": "bar", "RANDOM_STATE": "42"}

    trainer = _create_mock_model_trainer()
    # Hand the trainer its own copy so the comparison below is made against
    # an independent dict, just like comparing against a fresh literal.
    trainer.environment = dict(expected_env)

    job_definition = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )._build_training_job_definition(None)

    actual_env = job_definition.environment
    assert actual_env is not None, "Environment should not be None"
    assert (
        actual_env == expected_env
    ), "Environment variables should match those set on ModelTrainer"

def test_build_training_job_definition_with_none_environment(self):
    """Check that a trainer whose environment is None yields a None environment."""
    trainer = _create_mock_model_trainer()
    trainer.environment = None

    tuner_kwargs = {
        "model_trainer": trainer,
        "objective_metric_name": "accuracy",
        "hyperparameter_ranges": _create_single_hp_range(),
    }
    job_definition = HyperparameterTuner(**tuner_kwargs)._build_training_job_definition(None)

    assert job_definition.environment is None, "Environment should be None when not set"

def test_build_training_job_definition_with_empty_environment(self):
"""Test that _build_training_job_definition handles empty environment gracefully."""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = {}

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment is None, (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The assertion message says "Environment should be None when empty dict is provided" — this documents the behavior, but consider whether it is actually the desired UX. A user who explicitly sets `environment={}` might not expect it to be silently dropped. If this is intentional, it's fine, but it is worth confirming with the team.

"Environment should be None when empty dict is provided"
)
Loading