Skip to content

Commit 0906526

Browse files
amitmodiAmit Modi
andauthored
fix: Add PipelineVariable support to ModelTrainer fields (fixes #5524) (#5608)
* fix: Add PipelineVariable support to ModelTrainer fields (fixes #5524) Extend StrPipeVar type to ModelTrainer's direct fields: - training_image: Optional[str] -> Optional[StrPipeVar] - algorithm_name: Optional[str] -> Optional[StrPipeVar] - training_input_mode: Optional[str] -> Optional[StrPipeVar] - environment: Dict[str, str] -> Dict[str, StrPipeVar] This follows the existing V3 pattern already used by SourceCode, OutputDataConfig, and Compute (for instance_type). The StrPipeVar type alias and PipelineVariable.__get_pydantic_core_schema__() already exist in the codebase. This unblocks V2->V3 migration for SageMaker Pipelines users who need to pass ParameterString to ModelTrainer fields. Fixes #5524 * test: Add unit tests for PipelineVariable support + fix PipelineVariable-safe logging - Add test_model_trainer_pipeline_variable.py with 9 tests: - 4 PipelineVariable acceptance tests (training_image, algorithm_name, training_input_mode, environment) - 4 regression tests (real string values still work) - 1 invalid type rejection test - Fix PipelineVariable-safe logging in model_post_init (avoid __str__ on PipelineVariable which raises TypeError) All 57 tests pass (48 existing + 9 new, 0 regressions). --------- Co-authored-by: Amit Modi <modiamit@amazon.com>
1 parent 55a4ee5 commit 0906526

File tree

2 files changed

+188
-5
lines changed

2 files changed

+188
-5
lines changed

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
from sagemaker.core.jumpstart.utils import get_eula_url
117117
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
118118
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
119+
from sagemaker.core.helper.pipeline_variable import StrPipeVar
119120

120121
from sagemaker.train.local.local_container import _LocalContainer
121122

@@ -235,14 +236,14 @@ class ModelTrainer(BaseModel):
235236
compute: Optional[Compute] = None
236237
networking: Optional[Networking] = None
237238
stopping_condition: Optional[StoppingCondition] = None
238-
training_image: Optional[str] = None
239+
training_image: Optional[StrPipeVar] = None
239240
training_image_config: Optional[TrainingImageConfig] = None
240-
algorithm_name: Optional[str] = None
241+
algorithm_name: Optional[StrPipeVar] = None
241242
output_data_config: Optional[shapes.OutputDataConfig] = None
242243
input_data_config: Optional[List[Union[Channel, InputData]]] = None
243244
checkpoint_config: Optional[shapes.CheckpointConfig] = None
244-
training_input_mode: Optional[str] = "File"
245-
environment: Optional[Dict[str, str]] = {}
245+
training_input_mode: Optional[StrPipeVar] = "File"
246+
environment: Optional[Dict[str, StrPipeVar]] = {}
246247
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
247248
tags: Optional[List[Tag]] = None
248249
local_container_root: Optional[str] = os.getcwd()
@@ -545,7 +546,11 @@ def model_post_init(self, __context: Any):
545546
)
546547

547548
if self.training_image:
548-
logger.info(f"Training image URI: {self.training_image}")
549+
from sagemaker.core.helper.pipeline_variable import PipelineVariable
550+
if isinstance(self.training_image, PipelineVariable):
551+
logger.info("Training image URI: (PipelineVariable - resolved at pipeline execution)")
552+
else:
553+
logger.info(f"Training image URI: {self.training_image}")
549554

550555

551556
def _create_training_job_args(
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests for PipelineVariable support in ModelTrainer (GH#5524).
14+
15+
Verifies that ModelTrainer fields accept PipelineVariable objects
16+
(e.g., ParameterString) in addition to their concrete types, following
17+
the existing V3 pattern established by SourceCode and OutputDataConfig.
18+
19+
See: https://github.com/aws/sagemaker-python-sdk/issues/5524
20+
"""
21+
from __future__ import absolute_import
22+
23+
import pytest
24+
from pydantic import ValidationError
25+
from unittest.mock import patch, MagicMock
26+
27+
from sagemaker.core.helper.session_helper import Session
28+
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
29+
from sagemaker.core.workflow.parameters import ParameterString
30+
from sagemaker.train.model_trainer import ModelTrainer, Mode
31+
from sagemaker.train.configs import (
32+
Compute,
33+
StoppingCondition,
34+
OutputDataConfig,
35+
)
36+
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
37+
38+
39+
DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
40+
DEFAULT_BUCKET = "sagemaker-us-west-2-000000000000"
41+
DEFAULT_ROLE = "arn:aws:iam::000000000000:role/test-role"
42+
DEFAULT_BUCKET_PREFIX = "sample-prefix"
43+
DEFAULT_REGION = "us-west-2"
44+
DEFAULT_COMPUTE = Compute(instance_type=DEFAULT_INSTANCE_TYPE, instance_count=1)
45+
DEFAULT_STOPPING = StoppingCondition(max_runtime_in_seconds=3600)
46+
DEFAULT_OUTPUT = OutputDataConfig(
47+
s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/test-job",
48+
)
49+
50+
51+
@pytest.fixture(scope="module", autouse=True)
52+
def modules_session():
53+
with patch("sagemaker.train.Session", spec=Session) as session_mock:
54+
session_instance = session_mock.return_value
55+
session_instance.default_bucket.return_value = DEFAULT_BUCKET
56+
session_instance.get_caller_identity_arn.return_value = DEFAULT_ROLE
57+
session_instance.default_bucket_prefix = DEFAULT_BUCKET_PREFIX
58+
session_instance.boto_session = MagicMock(spec="boto3.session.Session")
59+
session_instance.boto_region_name = DEFAULT_REGION
60+
yield session_instance
61+
62+
63+
class TestModelTrainerPipelineVariableAcceptance:
64+
"""Test that ModelTrainer fields accept PipelineVariable objects."""
65+
66+
def test_training_image_accepts_parameter_string(self):
67+
"""ModelTrainer.training_image should accept ParameterString (GH#5524)."""
68+
param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
69+
trainer = ModelTrainer(
70+
training_image=param,
71+
base_job_name="pipeline-test-job", # Required: PipelineVariable can't generate job name
72+
role=DEFAULT_ROLE,
73+
compute=DEFAULT_COMPUTE,
74+
stopping_condition=DEFAULT_STOPPING,
75+
output_data_config=DEFAULT_OUTPUT,
76+
)
77+
assert trainer.training_image is param
78+
79+
def test_algorithm_name_accepts_parameter_string(self):
80+
"""ModelTrainer.algorithm_name should accept ParameterString."""
81+
param = ParameterString(name="AlgorithmName", default_value="my-algo-arn")
82+
trainer = ModelTrainer(
83+
algorithm_name=param,
84+
base_job_name="pipeline-test-job", # Required: PipelineVariable can't generate job name
85+
role=DEFAULT_ROLE,
86+
compute=DEFAULT_COMPUTE,
87+
stopping_condition=DEFAULT_STOPPING,
88+
output_data_config=DEFAULT_OUTPUT,
89+
)
90+
assert trainer.algorithm_name is param
91+
92+
def test_training_input_mode_accepts_parameter_string(self):
93+
"""ModelTrainer.training_input_mode should accept ParameterString."""
94+
param = ParameterString(name="InputMode", default_value="File")
95+
trainer = ModelTrainer(
96+
training_image=DEFAULT_IMAGE,
97+
training_input_mode=param,
98+
role=DEFAULT_ROLE,
99+
compute=DEFAULT_COMPUTE,
100+
stopping_condition=DEFAULT_STOPPING,
101+
output_data_config=DEFAULT_OUTPUT,
102+
)
103+
assert trainer.training_input_mode is param
104+
105+
def test_environment_values_accept_parameter_string(self):
106+
"""ModelTrainer.environment dict values should accept ParameterString."""
107+
param = ParameterString(name="DatasetVersion", default_value="v1")
108+
trainer = ModelTrainer(
109+
training_image=DEFAULT_IMAGE,
110+
environment={"DATASET_VERSION": param, "STATIC_VAR": "hello"},
111+
role=DEFAULT_ROLE,
112+
compute=DEFAULT_COMPUTE,
113+
stopping_condition=DEFAULT_STOPPING,
114+
output_data_config=DEFAULT_OUTPUT,
115+
)
116+
assert trainer.environment["DATASET_VERSION"] is param
117+
assert trainer.environment["STATIC_VAR"] == "hello"
118+
119+
120+
class TestModelTrainerRealValuesStillWork:
121+
"""Regression tests: verify that passing real values still works after the change."""
122+
123+
def test_training_image_accepts_real_string(self):
124+
"""ModelTrainer.training_image should still accept a plain string."""
125+
trainer = ModelTrainer(
126+
training_image=DEFAULT_IMAGE,
127+
role=DEFAULT_ROLE,
128+
compute=DEFAULT_COMPUTE,
129+
stopping_condition=DEFAULT_STOPPING,
130+
output_data_config=DEFAULT_OUTPUT,
131+
)
132+
assert trainer.training_image == DEFAULT_IMAGE
133+
134+
def test_algorithm_name_accepts_real_string(self):
135+
"""ModelTrainer.algorithm_name should still accept a plain string."""
136+
trainer = ModelTrainer(
137+
algorithm_name="arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo",
138+
role=DEFAULT_ROLE,
139+
compute=DEFAULT_COMPUTE,
140+
stopping_condition=DEFAULT_STOPPING,
141+
output_data_config=DEFAULT_OUTPUT,
142+
)
143+
assert trainer.algorithm_name == "arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo"
144+
145+
def test_training_input_mode_accepts_real_string(self):
146+
"""ModelTrainer.training_input_mode should still accept a plain string."""
147+
trainer = ModelTrainer(
148+
training_image=DEFAULT_IMAGE,
149+
training_input_mode="Pipe",
150+
role=DEFAULT_ROLE,
151+
compute=DEFAULT_COMPUTE,
152+
stopping_condition=DEFAULT_STOPPING,
153+
output_data_config=DEFAULT_OUTPUT,
154+
)
155+
assert trainer.training_input_mode == "Pipe"
156+
157+
def test_environment_accepts_real_string_values(self):
158+
"""ModelTrainer.environment should still accept plain string values."""
159+
trainer = ModelTrainer(
160+
training_image=DEFAULT_IMAGE,
161+
environment={"KEY1": "value1", "KEY2": "value2"},
162+
role=DEFAULT_ROLE,
163+
compute=DEFAULT_COMPUTE,
164+
stopping_condition=DEFAULT_STOPPING,
165+
output_data_config=DEFAULT_OUTPUT,
166+
)
167+
assert trainer.environment == {"KEY1": "value1", "KEY2": "value2"}
168+
169+
def test_training_image_rejects_invalid_type(self):
170+
"""ModelTrainer.training_image should still reject invalid types (e.g., int)."""
171+
with pytest.raises(ValidationError):
172+
ModelTrainer(
173+
training_image=12345,
174+
role=DEFAULT_ROLE,
175+
compute=DEFAULT_COMPUTE,
176+
stopping_condition=DEFAULT_STOPPING,
177+
output_data_config=DEFAULT_OUTPUT,
178+
)

0 commit comments

Comments
 (0)