From c4682b7fc0512a721cf9f38713a5b8bb4f6026a5 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 30 Mar 2026 20:24:28 -0400 Subject: [PATCH] fix: ModelTrainer and HyperparameterTuner missing environment variables (5613) --- sagemaker-train/src/sagemaker/train/tuner.py | 2 +- .../tests/unit/train/test_tuner.py | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index cde1598481..dad9cd2406 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -1514,7 +1514,7 @@ def _build_training_job_definition(self, inputs): # Pass through environment variables from model_trainer env = getattr(model_trainer, "environment", None) - if env and isinstance(env, dict): + if env and isinstance(env, dict) and len(env) > 0: definition.environment = env # Pass through VPC config from model_trainer diff --git a/sagemaker-train/tests/unit/train/test_tuner.py b/sagemaker-train/tests/unit/train/test_tuner.py index c0255eac47..46f74b010a 100644 --- a/sagemaker-train/tests/unit/train/test_tuner.py +++ b/sagemaker-train/tests/unit/train/test_tuner.py @@ -574,3 +574,65 @@ def test_build_training_job_definition_includes_internal_channels(self): assert "train" in channel_names, "User 'train' channel should be included" assert "validation" in channel_names, "User 'validation' channel should be included" assert len(channel_names) == 4, "Should have exactly 4 channels" + + def test_build_training_job_definition_includes_environment_variables(self): + """Test that _build_training_job_definition includes environment variables. + + This test verifies the fix for GitHub issue #5613 where tuning jobs were missing + environment variables that were set on the ModelTrainer. + """ + mock_trainer = _create_mock_model_trainer() + mock_trainer.environment = { + "RANDOM_STATE": "42", + "MY_VAR": "hello", + } + + tuner = HyperparameterTuner( + model_trainer=mock_trainer, + objective_metric_name="accuracy", + hyperparameter_ranges=_create_single_hp_range(), + ) + + definition = tuner._build_training_job_definition(None) + + assert definition.environment is not None, "Environment should be set" + assert definition.environment == { + "RANDOM_STATE": "42", + "MY_VAR": "hello", + }, "Environment variables should match those set on ModelTrainer" + + def test_build_training_job_definition_with_empty_environment(self): + """Test that _build_training_job_definition handles empty environment dict.""" + mock_trainer = _create_mock_model_trainer() + mock_trainer.environment = {} + + tuner = HyperparameterTuner( + model_trainer=mock_trainer, + objective_metric_name="accuracy", + hyperparameter_ranges=_create_single_hp_range(), + ) + + definition = tuner._build_training_job_definition(None) + + # Empty dict should not be set (falsy check) + assert not getattr(definition, "environment", None), ( + "Empty environment dict should not be propagated" + ) + + def test_build_training_job_definition_with_none_environment(self): + """Test that _build_training_job_definition handles None environment.""" + mock_trainer = _create_mock_model_trainer() + mock_trainer.environment = None + + tuner = HyperparameterTuner( + model_trainer=mock_trainer, + objective_metric_name="accuracy", + hyperparameter_ranges=_create_single_hp_range(), + ) + + definition = tuner._build_training_job_definition(None) + + # None environment should not cause an error and should not be set + assert not getattr(definition, "environment", None), ( + "None environment should not be propagated" + )