|
def create_input_data_channel( |
|
self, |
|
channel_name: str, |
|
data_source: DataSourceType, |
|
key_prefix: Optional[str] = None, |
|
ignore_patterns: Optional[List[str]] = None, |
|
) -> Channel: |
|
"""Create an input data channel for the training job. |
|
|
|
Args: |
|
channel_name (str): The name of the input data channel. |
|
data_source (DataSourceType): The data source for the input data channel. |
|
DataSourceType can be an S3 URI string, local file path string, |
|
S3DataSource object, or FileSystemDataSource object. |
|
key_prefix (Optional[str]): The key prefix to use when uploading data to S3. |
|
Only applicable when data_source is a local file path string. |
|
If not specified, local data will be uploaded to: |
|
``s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/`` |
|
|
|
If specified, local data will be uploaded to: |
|
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/`` |
|
ignore_patterns: (Optional[List[str]]) : |
|
The ignore patterns to ignore specific files/folders when uploading to S3. |
|
If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints']. |
|
""" |
|
from sagemaker.core.helper.pipeline_variable import PipelineVariable |
|
|
|
channel = None |
|
if isinstance(data_source, PipelineVariable): |
|
channel = Channel( |
|
channel_name=channel_name, |
|
data_source=DataSource( |
|
s3_data_source=S3DataSource( |
|
s3_data_type="S3Prefix", |
|
s3_uri=data_source, |
|
s3_data_distribution_type="FullyReplicated", |
|
), |
|
), |
|
input_mode="File", |
|
) |
|
elif isinstance(data_source, str): |
|
if _is_valid_s3_uri(data_source): |
|
channel = Channel( |
|
channel_name=channel_name, |
|
data_source=DataSource( |
|
s3_data_source=S3DataSource( |
|
s3_data_type="S3Prefix", |
|
s3_uri=data_source, |
|
s3_data_distribution_type="FullyReplicated", |
|
), |
|
), |
|
input_mode="File", |
|
) |
|
elif _is_valid_path(data_source): |
|
if self.training_mode == Mode.LOCAL_CONTAINER: |
|
channel = Channel( |
|
channel_name=channel_name, |
|
data_source=DataSource( |
|
file_system_data_source=FileSystemDataSource.model_construct( |
|
directory_path=data_source, |
|
file_system_type="EFS", |
|
), |
|
), |
|
input_mode="File", |
|
) |
|
else: |
|
key_prefix = ( |
|
f"{key_prefix}/{channel_name}" |
|
if key_prefix |
|
else f"{self.base_job_name}/input/{channel_name}" |
|
) |
|
if self.sagemaker_session.default_bucket_prefix: |
|
key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}" |
|
if ignore_patterns and _is_valid_path(data_source, path_type="Directory"): |
|
tmp_dir = TemporaryDirectory() |
|
copied_path = os.path.join( |
|
tmp_dir.name, os.path.basename(os.path.normpath(data_source)) |
|
) |
|
shutil.copytree( |
|
data_source, |
|
copied_path, |
|
dirs_exist_ok=True, |
|
ignore=shutil.ignore_patterns(*ignore_patterns), |
|
) |
|
s3_uri = self.sagemaker_session.upload_data( |
|
path=copied_path, |
|
bucket=self.sagemaker_session.default_bucket(), |
|
key_prefix=key_prefix, |
|
) |
|
else: |
|
s3_uri = self.sagemaker_session.upload_data( |
|
path=data_source, |
|
bucket=self.sagemaker_session.default_bucket(), |
|
key_prefix=key_prefix, |
|
) |
|
channel = Channel( |
|
channel_name=channel_name, |
|
data_source=DataSource( |
|
s3_data_source=S3DataSource( |
|
s3_data_type="S3Prefix", |
|
s3_uri=s3_uri, |
|
s3_data_distribution_type="FullyReplicated", |
|
), |
|
), |
|
input_mode="File", |
|
) |
|
else: |
|
raise ValueError(f"Not a valid S3 URI or local file path: {data_source}.") |
|
elif isinstance(data_source, S3DataSource): |
|
channel = Channel( |
|
channel_name=channel_name, data_source=DataSource(s3_data_source=data_source) |
|
) |
|
elif isinstance(data_source, FileSystemDataSource): |
|
channel = Channel( |
|
channel_name=channel_name, |
|
data_source=DataSource(file_system_data_source=data_source), |
|
) |
|
else: |
|
raise ValueError(f"Unsupported data_source type: {type(data_source)}") |
|
return channel |
Describe the bug
In SageMaker SDK V2, the
Estimatoruses theoutput_kms_keywhen uploading user training scripts to S3 (see_stage_user_code_in_s3):sagemaker-python-sdk/src/sagemaker/estimator.py
Lines 1044 to 1108 in 5b3b127
In SageMaker SDK V3, the new
ModelTrainerdoes not apply any KMS key when uploading source code while creating the input data channel:sagemaker-python-sdk/sagemaker-train/src/sagemaker/train/model_trainer.py
Lines 834 to 953 in 9101cef
Additionally, even when providing an S3 URI in the
SourceCodeobject (instead of a local path), theModelTrainerstill uploads additional driver files whenever source code is specified:sagemaker-python-sdk/sagemaker-train/src/sagemaker/train/model_trainer.py
Lines 684 to 689 in 9101cef
These uploads do not use a KMS key.
Expected behavior
ModelTrainershould:output_kms_keyin V2), ORActual behavior
Impact
In restricted environments (like ours), S3 policies enforce server-side encryption with KMS.
As a result,
ModelTrainercannot be used with custom training scripts.This blocks use cases that rely on custom code, such as MLflow serverless integration.
Steps to reproduce
ModelTrainerwith aSourceCodeobjectPossible solution
Expose a parameter similar to
output_kms_keyin V2, or reuse existing encryption configuration mechanisms.Additional context
This is a regression compared to V2
Estimatorbehavior and impacts secure environments with strict S3 encryption policies.