diff --git a/sagemaker-core/tests/unit/workflow/test_utilities.py b/sagemaker-core/tests/unit/workflow/test_utilities.py index 5e9ed7bbbd..0b96ae3f49 100644 --- a/sagemaker-core/tests/unit/workflow/test_utilities.py +++ b/sagemaker-core/tests/unit/workflow/test_utilities.py @@ -32,6 +32,7 @@ from sagemaker.core.workflow.entities import Entity from sagemaker.core.workflow.parameters import Parameter from sagemaker.core.workflow.pipeline_context import _StepArguments +from sagemaker.core.workflow.utilities import get_code_hash class MockEntity(Entity): @@ -308,6 +309,49 @@ def test_get_training_code_hash_entry_point_only(self): assert len(result_with_deps) == 64 assert result_no_deps != result_with_deps + def test_get_training_code_hash_with_source_dir_none_dependencies(self): + """Test get_training_code_hash with source_dir and None dependencies does not raise TypeError""" + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + + # This is the exact scenario from the bug report: dependencies=None + result = get_training_code_hash( + entry_point=str(entry_file), source_dir=temp_dir, dependencies=None + ) + + assert result is not None + assert len(result) == 64 + + def test_get_training_code_hash_entry_point_only_none_dependencies(self): + """Test get_training_code_hash with entry_point only and None dependencies does not raise TypeError""" + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + + # entry_point only, no source_dir, dependencies=None + result = get_training_code_hash( + entry_point=str(entry_file), source_dir=None, dependencies=None + ) + + assert result is not None + assert len(result) == 64 + + def test_get_training_code_hash_default_dependencies(self): + """Test get_training_code_hash with default dependencies parameter (not passed)""" + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + + # Not passing dependencies at all - should use default None + result = get_training_code_hash( + entry_point=str(entry_file), source_dir=temp_dir + ) + + assert result is not None + assert len(result) == 64 + + def test_get_training_code_hash_s3_uri(self): """Test get_training_code_hash with S3 URI returns None""" result = get_training_code_hash( @@ -325,6 +369,54 @@ def test_get_training_code_hash_pipeline_variable(self): assert result is None + @pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests") + def test_get_code_hash_training_step_none_requirements(self): + """Test get_code_hash with TrainingStep where source_code.requirements is None""" + from sagemaker.mlops.workflow.steps import TrainingStep + + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + + mock_source_code = Mock() + mock_source_code.source_dir = temp_dir + mock_source_code.requirements = None # This is the bug scenario + mock_source_code.entry_script = str(entry_file) + + mock_model_trainer = Mock() + mock_model_trainer.source_code = mock_source_code + + mock_step_args = Mock(spec=_StepArguments) + mock_step_args.func_args = [mock_model_trainer] + + mock_step = Mock(spec=TrainingStep) + mock_step.step_args = mock_step_args + + # This should not raise TypeError + result = get_code_hash(mock_step) + + assert result is not None + assert len(result) == 64 + + @pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests") + def test_get_code_hash_training_step_no_source_code(self): + """Test get_code_hash with TrainingStep where source_code is None""" + from sagemaker.mlops.workflow.steps import TrainingStep + + mock_model_trainer = Mock() + mock_model_trainer.source_code = None + + mock_step_args = Mock(spec=_StepArguments) + mock_step_args.func_args = [mock_model_trainer] + + mock_step = Mock(spec=TrainingStep) + mock_step.step_args = mock_step_args + + result = get_code_hash(mock_step) + + assert result is None + + def test_validate_step_args_input_valid(self): """Test validate_step_args_input with valid input""" step_args = _StepArguments(