Skip to content

Commit c4682b7

Browse files
committed
fix: ModelTrainer and HyperparameterTuner missing environment variables (5613)
1 parent e161199 commit c4682b7

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1514,7 +1514,7 @@ def _build_training_job_definition(self, inputs):
15141514

15151515
# Pass through environment variables from model_trainer
15161516
env = getattr(model_trainer, "environment", None)
1517-
if env and isinstance(env, dict):
1517+
if env and isinstance(env, dict) and len(env) > 0:
15181518
definition.environment = env
15191519

15201520
# Pass through VPC config from model_trainer

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,65 @@ def test_build_training_job_definition_includes_internal_channels(self):
574574
assert "train" in channel_names, "User 'train' channel should be included"
575575
assert "validation" in channel_names, "User 'validation' channel should be included"
576576
assert len(channel_names) == 4, "Should have exactly 4 channels"
577+
578+
def test_build_training_job_definition_includes_environment_variables(self):
579+
"""Test that _build_training_job_definition includes environment variables.
580+
581+
This test verifies the fix for GitHub issue #5613 where tuning jobs were missing
582+
environment variables that were set on the ModelTrainer.
583+
"""
584+
mock_trainer = _create_mock_model_trainer()
585+
mock_trainer.environment = {
586+
"RANDOM_STATE": "42",
587+
"MY_VAR": "hello",
588+
}
589+
590+
tuner = HyperparameterTuner(
591+
model_trainer=mock_trainer,
592+
objective_metric_name="accuracy",
593+
hyperparameter_ranges=_create_single_hp_range(),
594+
)
595+
596+
definition = tuner._build_training_job_definition(None)
597+
598+
assert definition.environment is not None, "Environment should be set"
599+
assert definition.environment == {
600+
"RANDOM_STATE": "42",
601+
"MY_VAR": "hello",
602+
}, "Environment variables should match those set on ModelTrainer"
603+
604+
def test_build_training_job_definition_with_empty_environment(self):
605+
"""Test that _build_training_job_definition handles empty environment dict."""
606+
mock_trainer = _create_mock_model_trainer()
607+
mock_trainer.environment = {}
608+
609+
tuner = HyperparameterTuner(
610+
model_trainer=mock_trainer,
611+
objective_metric_name="accuracy",
612+
hyperparameter_ranges=_create_single_hp_range(),
613+
)
614+
615+
definition = tuner._build_training_job_definition(None)
616+
617+
# Empty dict should not be set (falsy check)
618+
assert not getattr(definition, "environment", None), (
619+
"Empty environment dict should not be propagated"
620+
)
621+
622+
def test_build_training_job_definition_with_none_environment(self):
623+
"""Test that _build_training_job_definition handles None environment."""
624+
mock_trainer = _create_mock_model_trainer()
625+
mock_trainer.environment = None
626+
627+
tuner = HyperparameterTuner(
628+
model_trainer=mock_trainer,
629+
objective_metric_name="accuracy",
630+
hyperparameter_ranges=_create_single_hp_range(),
631+
)
632+
633+
definition = tuner._build_training_job_definition(None)
634+
635+
# None environment should not cause an error and should not be set
636+
assert not getattr(definition, "environment", None), (
637+
"None environment should not be propagated"
638+
)

0 commit comments

Comments
 (0)