Draft
Changes from 1 commit
11 changes: 10 additions & 1 deletion sagemaker-train/src/sagemaker/train/utils.py
@@ -192,7 +192,16 @@ def safe_serialize(data):
    try:
        return json.dumps(data)
    except TypeError:
-       return str(data)
+       try:
Member:
The nested try/except is a reasonable defensive measure, but catching a bare TypeError from str(data) and silently returning the raw object is risky. This means any object whose __str__ raises TypeError (not just PipelineVariable) will be returned as-is, potentially causing unexpected behavior downstream. Consider being more explicit:

except TypeError:
    try:
        return str(data)
    except TypeError:
        if isinstance(data, PipelineVariable):
            return data
        raise

This way, only known PipelineVariable objects get the pass-through treatment, and truly broken objects still raise.
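To make the suggestion concrete, here is a self-contained sketch of that shape. `PipelineVariable` below is a stand-in class, not the real sagemaker type, and the pass-through for plain strings is assumed from this PR's own tests (`safe_serialize("hello") == "hello"`); only the nested try/except structure mirrors the suggestion above.

```python
import json


class PipelineVariable:
    """Stand-in for sagemaker's PipelineVariable; str() raises by design."""

    def __str__(self):
        raise TypeError("Pipeline variables must be interpolated at execution time")


def safe_serialize(data):
    """Sketch: JSON-serialize, fall back to str(), and pass through
    only known pipeline-variable objects."""
    if isinstance(data, str):
        return data  # plain strings pass through un-quoted (per the PR's tests)
    try:
        return json.dumps(data)
    except TypeError:
        try:
            return str(data)
        except TypeError:
            if isinstance(data, PipelineVariable):
                return data  # known special case: hand back for pipeline resolution
            raise  # anything else with a broken __str__ still surfaces an error


param = PipelineVariable()
assert safe_serialize(param) is param              # known type passes through
assert safe_serialize({"eta": 0.1}) == '{"eta": 0.1}'
```

With this shape, an unrelated object whose `__str__` raises `TypeError` still raises instead of being silently returned.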

+           return str(data)
+       except TypeError:
+           # PipelineVariable.__str__ raises TypeError by design.
+           # If the isinstance check above didn't catch it (e.g. import
+           # path mismatch), fall back to returning the object directly
Member:
Using hasattr(data, 'expr') as a duck-typing check for PipelineVariable is fragile — any object with an expr attribute would match this condition. Since there's already a PipelineVariable isinstance check earlier in the function (lines 185-186), a more robust approach would be to repeat that isinstance check here, or better yet, import and check against the specific parameter types. This would also be more maintainable.

Alternatively, consider catching TypeError more broadly:

except TypeError:
    try:
        return str(data)
    except TypeError:
        # PipelineVariable.__str__ raises TypeError by design.
        # Return the object directly so pipeline serialization can handle it.
        return data

Returning data unconditionally in the inner except is safer than a duck-type check that could silently pass through unrelated objects. If the object truly can't be serialized or stringified, it's likely a PipelineVariable that should be passed through. If you want to be defensive, re-use the PipelineVariable isinstance check instead of hasattr.
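The fragility described above can be demonstrated with a minimal stand-in (all names here are illustrative, not the real sagemaker code): under the `hasattr(data, "expr")` fallback, any unrelated object that happens to carry an `expr` attribute is passed through instead of raising.

```python
import json


def safe_serialize_ducktyped(data):
    """Variant using the hasattr fallback from the diff under review."""
    try:
        return json.dumps(data)
    except TypeError:
        try:
            return str(data)
        except TypeError:
            if hasattr(data, "expr"):
                return data
            raise


class NotAPipelineVariable:
    """Unrelated object: str() is broken and it happens to define `expr`."""

    expr = "unrelated"

    def __str__(self):
        raise TypeError("broken __str__")


obj = NotAPipelineVariable()
# The duck-type check silently passes this unrelated object through:
assert safe_serialize_ducktyped(obj) is obj
```

This is the failure mode the comment warns about: the duck-type check cannot distinguish a genuine pipeline variable from any other non-stringifiable object with an `expr` attribute.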

+           # when it looks like a PipelineVariable (has an ``expr`` property).
+           if hasattr(data, "expr"):
+               return data
+           raise


def _run_clone_command_silent(repo_url, dest_dir):
test_model_trainer_pipeline_variable.py
@@ -26,13 +26,14 @@

 from sagemaker.core.helper.session_helper import Session
 from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
-from sagemaker.core.workflow.parameters import ParameterString
+from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat
Member:
ParameterFloat and PipelineSession are imported but ParameterFloat is never used in the new tests, and PipelineSession is also unused. Remove unused imports.

 from sagemaker.train.model_trainer import ModelTrainer, Mode
 from sagemaker.train.configs import (
     Compute,
     StoppingCondition,
     OutputDataConfig,
 )
+from sagemaker.core.workflow.pipeline_context import PipelineSession
 from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE


@@ -176,3 +177,61 @@ def test_training_image_rejects_invalid_type(self):
             stopping_condition=DEFAULT_STOPPING,
             output_data_config=DEFAULT_OUTPUT,
         )


+class TestModelTrainerHyperparametersPipelineVariable:
+    """Test that PipelineVariable objects in hyperparameters survive safe_serialize."""
Member:
These tests in test_model_trainer_pipeline_variable.py are essentially duplicates of the tests already added in test_utils.py. The safe_serialize unit tests in test_utils.py are sufficient for testing the serialization behavior. These tests should instead verify the integration — that ModelTrainer correctly preserves PipelineVariable objects in its hyperparameters dict after construction, rather than just re-testing safe_serialize.


+    def test_hyperparameters_with_pipeline_variable_integer(self):
+        """ParameterInteger in hyperparameters should be passed through as-is."""
+        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+            hyperparameters={"max_depth": max_depth},
+        )
+        # safe_serialize should return the PipelineVariable object directly
Member:
Imports should be at the top of the file, not inside test functions. Move from sagemaker.train.utils import safe_serialize to the module-level imports. This applies to all three new test methods.

+        from sagemaker.train.utils import safe_serialize
+        result = safe_serialize(max_depth)
+        assert result is max_depth

+    def test_hyperparameters_with_pipeline_variable_string(self):
+        """ParameterString in hyperparameters should be passed through as-is."""
+        optimizer = ParameterString(name="Optimizer", default_value="sgd")
Member:
These tests verify that ModelTrainer preserves PipelineVariable objects in the hyperparameters dict, but they don't test the actual serialization path (i.e., when safe_serialize is called during job creation). Consider adding a test that mocks the training job creation to verify that PipelineVariable values survive the full serialization pipeline, not just assignment.

+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+            hyperparameters={"optimizer": optimizer},
+        )
+        from sagemaker.train.utils import safe_serialize
+        result = safe_serialize(optimizer)
+        assert result is optimizer

+    def test_hyperparameters_with_mixed_pipeline_and_regular_values(self):
+        """Mixed PipelineVariable and regular values should both serialize correctly."""
+        max_depth = ParameterInteger(name="MaxDepth", default_value=5)
+        trainer = ModelTrainer(
+            training_image=DEFAULT_IMAGE,
+            role=DEFAULT_ROLE,
+            compute=DEFAULT_COMPUTE,
+            stopping_condition=DEFAULT_STOPPING,
+            output_data_config=DEFAULT_OUTPUT,
+            hyperparameters={
+                "max_depth": max_depth,
+                "eta": 0.1,
+                "objective": "binary:logistic",
+            },
+        )
+        from sagemaker.train.utils import safe_serialize
+        # PipelineVariable should be returned as-is
+        assert safe_serialize(max_depth) is max_depth
+        # Float should be JSON-serialized
+        assert safe_serialize(0.1) == "0.1"
+        # String should be returned as-is
+        assert safe_serialize("binary:logistic") == "binary:logistic"
90 changes: 90 additions & 0 deletions sagemaker-train/tests/unit/train/test_utils.py
@@ -0,0 +1,90 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for sagemaker.train.utils – specifically safe_serialize."""
from __future__ import absolute_import
Member:
Per PEP 484 / SDK conventions, new modules should use from __future__ import annotations instead of from __future__ import absolute_import (which is a Python 2 artifact and unnecessary in Python 3).


import json

from sagemaker.train.utils import safe_serialize
from sagemaker.core.workflow.parameters import (
    ParameterInteger,
    ParameterString,
    ParameterFloat,
)


# ---------------------------------------------------------------------------
# PipelineVariable inputs – should be returned as-is (identity)
# ---------------------------------------------------------------------------

def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly():
    param = ParameterInteger(name="MaxDepth", default_value=5)
    result = safe_serialize(param)
    assert result is param


def test_safe_serialize_with_pipeline_variable_string_returns_object_directly():
    param = ParameterString(name="Optimizer", default_value="sgd")
    result = safe_serialize(param)
    assert result is param


def test_safe_serialize_with_pipeline_variable_float_returns_object_directly():
    param = ParameterFloat(name="LearningRate", default_value=0.01)
    result = safe_serialize(param)
    assert result is param


# ---------------------------------------------------------------------------
# Regular / primitive inputs
# ---------------------------------------------------------------------------

def test_safe_serialize_with_string_returns_string_as_is():
    assert safe_serialize("hello") == "hello"
    assert safe_serialize("12345") == "12345"


def test_safe_serialize_with_int_returns_json_string():
    assert safe_serialize(5) == "5"
    assert safe_serialize(0) == "0"


def test_safe_serialize_with_dict_returns_json_string():
    data = {"key": "value", "num": 1}
    assert safe_serialize(data) == json.dumps(data)


def test_safe_serialize_with_bool_returns_json_string():
    assert safe_serialize(True) == "true"
    assert safe_serialize(False) == "false"


def test_safe_serialize_with_custom_object_returns_str():
    class CustomObject:
        def __str__(self):
            return "CustomObject"

    obj = CustomObject()
    assert safe_serialize(obj) == "CustomObject"


def test_safe_serialize_with_none_returns_json_null():
    assert safe_serialize(None) == "null"


def test_safe_serialize_with_list_returns_json_string():
    assert safe_serialize([1, 2, 3]) == "[1, 2, 3]"


def test_safe_serialize_with_empty_string():
    assert safe_serialize("") == ""
Member:
Good test coverage overall! Consider adding a test for the specific edge case this PR is fixing — where isinstance check for PipelineVariable might fail (e.g., import path mismatch). You could mock isinstance or create a mock object whose __str__ raises TypeError to verify the nested except path:

def test_safe_serialize_with_object_whose_str_raises_typeerror():
    """Objects whose __str__ raises TypeError should be returned as-is."""
    class BadStr:
        def __str__(self):
            raise TypeError("cannot convert")
    obj = BadStr()
    result = safe_serialize(obj)
    assert result is obj

This directly tests the new fallback code path added in this PR.
