Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions sagemaker-core/tests/unit/workflow/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from sagemaker.core.workflow.entities import Entity
from sagemaker.core.workflow.parameters import Parameter
from sagemaker.core.workflow.pipeline_context import _StepArguments
from sagemaker.core.workflow.utilities import get_code_hash


class MockEntity(Entity):
Expand Down Expand Up @@ -308,6 +309,49 @@ def test_get_training_code_hash_entry_point_only(self):
assert len(result_with_deps) == 64
assert result_no_deps != result_with_deps

def test_get_training_code_hash_with_source_dir_none_dependencies(self):
"""Test get_training_code_hash with source_dir and None dependencies does not raise TypeError"""
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

# This is the exact scenario from the bug report: dependencies=None
result = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir, dependencies=None
)

assert result is not None
assert len(result) == 64

def test_get_training_code_hash_entry_point_only_none_dependencies(self):
"""Test get_training_code_hash with entry_point only and None dependencies does not raise TypeError"""
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

# entry_point only, no source_dir, dependencies=None
result = get_training_code_hash(
entry_point=str(entry_file), source_dir=None, dependencies=None
)

assert result is not None
assert len(result) == 64

def test_get_training_code_hash_default_dependencies(self):
"""Test get_training_code_hash with default dependencies parameter (not passed)"""
with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

# Not passing dependencies at all - should use default None
result = get_training_code_hash(
entry_point=str(entry_file), source_dir=temp_dir
)

assert result is not None
assert len(result) == 64


def test_get_training_code_hash_s3_uri(self):
"""Test get_training_code_hash with S3 URI returns None"""
result = get_training_code_hash(
Expand All @@ -325,6 +369,54 @@ def test_get_training_code_hash_pipeline_variable(self):

assert result is None

@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
def test_get_code_hash_training_step_none_requirements(self):
"""Test get_code_hash with TrainingStep where source_code.requirements is None"""
from sagemaker.mlops.workflow.steps import TrainingStep

with tempfile.TemporaryDirectory() as temp_dir:
entry_file = Path(temp_dir, "train.py")
entry_file.write_text("print('training')")

mock_source_code = Mock()
mock_source_code.source_dir = temp_dir
mock_source_code.requirements = None # This is the bug scenario
mock_source_code.entry_script = str(entry_file)

mock_model_trainer = Mock()
mock_model_trainer.source_code = mock_source_code

mock_step_args = Mock(spec=_StepArguments)
mock_step_args.func_args = [mock_model_trainer]

mock_step = Mock(spec=TrainingStep)
mock_step.step_args = mock_step_args

# This should not raise TypeError
result = get_code_hash(mock_step)

assert result is not None
assert len(result) == 64

@pytest.mark.skip(reason="Requires sagemaker-mlops module which is not installed in sagemaker-core tests")
def test_get_code_hash_training_step_no_source_code(self):
"""Test get_code_hash with TrainingStep where source_code is None"""
from sagemaker.mlops.workflow.steps import TrainingStep

mock_model_trainer = Mock()
mock_model_trainer.source_code = None

mock_step_args = Mock(spec=_StepArguments)
mock_step_args.func_args = [mock_model_trainer]

mock_step = Mock(spec=TrainingStep)
mock_step.step_args = mock_step_args

result = get_code_hash(mock_step)

assert result is None


def test_validate_step_args_input_valid(self):
"""Test validate_step_args_input with valid input"""
step_args = _StepArguments(
Expand Down
Loading