-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #5689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -192,7 +192,16 @@ def safe_serialize(data): | |
| try: | ||
| return json.dumps(data) | ||
| except TypeError: | ||
| return str(data) | ||
| try: | ||
| return str(data) | ||
| except TypeError: | ||
| # PipelineVariable.__str__ raises TypeError by design. | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # If the isinstance check above didn't catch it (e.g. import | ||
| # path mismatch), fall back to returning the object directly | ||
|
||
| # when it looks like a PipelineVariable (has an ``expr`` property). | ||
| if hasattr(data, "expr"): | ||
| return data | ||
| raise | ||
|
|
||
|
|
||
| def _run_clone_command_silent(repo_url, dest_dir): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,13 +26,14 @@ | |
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
| from sagemaker.core.workflow.parameters import ParameterString | ||
| from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat | ||
|
||
| from sagemaker.train.model_trainer import ModelTrainer, Mode | ||
| from sagemaker.train.configs import ( | ||
| Compute, | ||
| StoppingCondition, | ||
| OutputDataConfig, | ||
| ) | ||
| from sagemaker.core.workflow.pipeline_context import PipelineSession | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE | ||
|
|
||
|
|
||
|
|
@@ -176,3 +177,61 @@ def test_training_image_rejects_invalid_type(self): | |
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
|
|
||
|
|
||
| class TestModelTrainerHyperparametersPipelineVariable: | ||
| """Test that PipelineVariable objects in hyperparameters survive safe_serialize.""" | ||
|
||
|
|
||
| def test_hyperparameters_with_pipeline_variable_integer(self): | ||
| """ParameterInteger in hyperparameters should be passed through as-is.""" | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"max_depth": max_depth}, | ||
| ) | ||
| # safe_serialize should return the PipelineVariable object directly | ||
|
||
| from sagemaker.train.utils import safe_serialize | ||
| result = safe_serialize(max_depth) | ||
| assert result is max_depth | ||
|
|
||
| def test_hyperparameters_with_pipeline_variable_string(self): | ||
| """ParameterString in hyperparameters should be passed through as-is.""" | ||
| optimizer = ParameterString(name="Optimizer", default_value="sgd") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests verify that |
||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"optimizer": optimizer}, | ||
| ) | ||
| from sagemaker.train.utils import safe_serialize | ||
| result = safe_serialize(optimizer) | ||
| assert result is optimizer | ||
|
|
||
| def test_hyperparameters_with_mixed_pipeline_and_regular_values(self): | ||
| """Mixed PipelineVariable and regular values should both serialize correctly.""" | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={ | ||
| "max_depth": max_depth, | ||
| "eta": 0.1, | ||
| "objective": "binary:logistic", | ||
| }, | ||
| ) | ||
| from sagemaker.train.utils import safe_serialize | ||
| # PipelineVariable should be returned as-is | ||
| assert safe_serialize(max_depth) is max_depth | ||
| # Float should be JSON-serialized | ||
| assert safe_serialize(0.1) == "0.1" | ||
| # String should be returned as-is | ||
| assert safe_serialize("binary:logistic") == "binary:logistic" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Unit tests for sagemaker.train.utils – specifically safe_serialize.""" | ||
| from __future__ import absolute_import | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Per PEP 484 / SDK conventions, new modules should use |
||
|
|
||
| import json | ||
|
|
||
| from sagemaker.train.utils import safe_serialize | ||
| from sagemaker.core.workflow.parameters import ( | ||
| ParameterInteger, | ||
| ParameterString, | ||
| ParameterFloat, | ||
| ) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # PipelineVariable inputs – should be returned as-is (identity) | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| def test_safe_serialize_with_pipeline_variable_integer_returns_object_directly(): | ||
| param = ParameterInteger(name="MaxDepth", default_value=5) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
|
|
||
| def test_safe_serialize_with_pipeline_variable_string_returns_object_directly(): | ||
| param = ParameterString(name="Optimizer", default_value="sgd") | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
|
|
||
| def test_safe_serialize_with_pipeline_variable_float_returns_object_directly(): | ||
| param = ParameterFloat(name="LearningRate", default_value=0.01) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Regular / primitive inputs | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| def test_safe_serialize_with_string_returns_string_as_is(): | ||
| assert safe_serialize("hello") == "hello" | ||
| assert safe_serialize("12345") == "12345" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_int_returns_json_string(): | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| assert safe_serialize(5) == "5" | ||
| assert safe_serialize(0) == "0" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_dict_returns_json_string(): | ||
| data = {"key": "value", "num": 1} | ||
| assert safe_serialize(data) == json.dumps(data) | ||
|
|
||
|
|
||
| def test_safe_serialize_with_bool_returns_json_string(): | ||
| assert safe_serialize(True) == "true" | ||
| assert safe_serialize(False) == "false" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_custom_object_returns_str(): | ||
| class CustomObject: | ||
| def __str__(self): | ||
| return "CustomObject" | ||
|
|
||
| obj = CustomObject() | ||
| assert safe_serialize(obj) == "CustomObject" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_none_returns_json_null(): | ||
| assert safe_serialize(None) == "null" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_list_returns_json_string(): | ||
| assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" | ||
|
|
||
|
|
||
| def test_safe_serialize_with_empty_string(): | ||
| assert safe_serialize("") == "" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good test coverage overall! Consider adding a test for the specific edge case this PR is fixing — where def test_safe_serialize_with_object_whose_str_raises_typeerror():
"""Objects whose __str__ raises TypeError should be returned as-is."""
class BadStr:
def __str__(self):
raise TypeError("cannot convert")
obj = BadStr()
result = safe_serialize(obj)
assert result is objThis directly tests the new fallback code path added in this PR. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The nested try/except is a reasonable defensive measure, but catching a bare
TypeErrorfromstr(data)and silently returning the raw object is risky. This means any object whose__str__raisesTypeError(not justPipelineVariable) will be returned as-is, potentially causing unexpected behavior downstream. Consider being more explicit:This way, only known
PipelineVariableobjects get the pass-through treatment, and truly broken objects still raise.