-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: ModelTrainer and HyperparameterTuner missing environment variables (5613) #5725
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f1ea9d5
dec47ab
862ff2d
1e8e693
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
aviruthen marked this conversation as resolved.
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1504,6 +1504,11 @@ def _build_training_job_definition(self, inputs): | |
| model_trainer.stopping_condition.max_wait_time_in_seconds | ||
| ) | ||
|
|
||
|
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
|
||
| # Get environment variables from model_trainer | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The truthiness check […] Also, the […] |
||
| env = getattr(model_trainer, "environment", None) | ||
| if not env or not isinstance(env, dict): | ||
| env = None | ||
|
|
||
| definition = HyperParameterTrainingJobDefinition( | ||
| algorithm_specification=algorithm_spec, | ||
| role_arn=model_trainer.role, | ||
|
|
@@ -1513,13 +1518,9 @@ def _build_training_job_definition(self, inputs): | |
| stopping_condition=stopping_condition, | ||
| static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {}, | ||
| enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training, | ||
| environment=env, | ||
| ) | ||
|
|
||
| # Pass through environment variables from model_trainer | ||
| env = getattr(model_trainer, "environment", None) | ||
| if env and isinstance(env, dict): | ||
| definition.environment = env | ||
|
|
||
| # Pass through VPC config from model_trainer | ||
| networking = getattr(model_trainer, "networking", None) | ||
| if networking and hasattr(networking, "_to_vpc_config"): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -596,3 +596,61 @@ def test_build_training_job_definition_includes_spot_params(self): | |
| assert isinstance( | ||
| definition.stopping_condition.max_wait_time_in_seconds, int | ||
|
aviruthen marked this conversation as resolved.
|
||
| ), "Max wait time should be set" | ||
|
|
||
|
aviruthen marked this conversation as resolved.
|
||
| def test_build_training_job_definition_includes_environment_variables(self): | ||
| """Test that _build_training_job_definition includes environment variables. | ||
|
|
||
| This test verifies the fix for GitHub issue #5613 where tuning jobs were | ||
| missing environment variables that were set on the ModelTrainer. | ||
| """ | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = { | ||
| "FOO": "bar", | ||
| "RANDOM_STATE": "42", | ||
| } | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
| definition = tuner._build_training_job_definition(None) | ||
|
|
||
| assert definition.environment is not None, "Environment should not be None" | ||
| assert definition.environment == { | ||
| "FOO": "bar", | ||
| "RANDOM_STATE": "42", | ||
| }, "Environment variables should match those set on ModelTrainer" | ||
|
|
||
| def test_build_training_job_definition_with_none_environment(self): | ||
| """Test that _build_training_job_definition handles None environment gracefully.""" | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = None | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
| definition = tuner._build_training_job_definition(None) | ||
|
|
||
| assert definition.environment is None, "Environment should be None when not set" | ||
|
|
||
| def test_build_training_job_definition_with_empty_environment(self): | ||
| """Test that _build_training_job_definition handles empty environment gracefully.""" | ||
| mock_trainer = _create_mock_model_trainer() | ||
| mock_trainer.environment = {} | ||
|
|
||
| tuner = HyperparameterTuner( | ||
| model_trainer=mock_trainer, | ||
| objective_metric_name="accuracy", | ||
| hyperparameter_ranges=_create_single_hp_range(), | ||
| ) | ||
|
|
||
| definition = tuner._build_training_job_definition(None) | ||
|
|
||
| assert definition.environment is None, ( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: The assertion message says "Environment should be None when empty dict is provided" — this documents the behavior, but consider whether this is actually the desired UX. A user who explicitly sets […] |
||
| "Environment should be None when empty dict is provided" | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.