Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-train/src/sagemaker/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,7 @@ def _build_training_job_definition(self, inputs):

# Pass through environment variables from model_trainer
env = getattr(model_trainer, "environment", None)
if env and isinstance(env, dict):
if env and isinstance(env, dict) and len(env) > 0:
definition.environment = env

# Pass through VPC config from model_trainer
Expand Down
62 changes: 62 additions & 0 deletions sagemaker-train/tests/unit/train/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,65 @@ def test_build_training_job_definition_includes_internal_channels(self):
assert "train" in channel_names, "User 'train' channel should be included"
assert "validation" in channel_names, "User 'validation' channel should be included"
assert len(channel_names) == 4, "Should have exactly 4 channels"

def test_build_training_job_definition_includes_environment_variables(self):
"""Test that _build_training_job_definition includes environment variables.

This test verifies the fix for GitHub issue #5613 where tuning jobs were missing
environment variables that were set on the ModelTrainer.
"""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = {
"RANDOM_STATE": "42",
"MY_VAR": "hello",
}

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment is not None, "Environment should be set"
assert definition.environment == {
"RANDOM_STATE": "42",
"MY_VAR": "hello",
}, "Environment variables should match those set on ModelTrainer"

def test_build_training_job_definition_with_empty_environment(self):
"""Test that _build_training_job_definition handles empty environment dict."""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = {}

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

# Empty dict should not be set (falsy check)
assert not getattr(definition, "environment", None), (
"Empty environment dict should not be propagated"
)

def test_build_training_job_definition_with_none_environment(self):
"""Test that _build_training_job_definition handles None environment."""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = None

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

# None environment should not cause an error and should not be set
assert not getattr(definition, "environment", None), (
"None environment should not be propagated"
)
Loading