Skip to content

ModelTrainer (V3) does not support output_kms_key for source code uploads #5956

@LukasSchiffers

Description

@LukasSchiffers

Describe the bug

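In SageMaker SDK V2, the Estimator passes `output_kms_key` to the S3 upload of user training scripts. It does this in the common case where the code is staged in the same bucket as the training output, and never in local mode (see `_stage_user_code_in_s3`):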

def _stage_user_code_in_s3(self) -> UploadedCode:
"""Uploads the user training script to S3 and returns the S3 URI.
Returns: S3 URI
"""
if is_pipeline_variable(self.output_path):
if self.code_location is None:
code_bucket = self.sagemaker_session.default_bucket()
key_prefix = self.sagemaker_session.default_bucket_prefix
code_s3_prefix = self._assign_s3_prefix(key_prefix)
kms_key = None
else:
code_bucket, key_prefix = parse_s3_url(self.code_location)
code_s3_prefix = self._assign_s3_prefix(key_prefix)
output_bucket = self.sagemaker_session.default_bucket()
kms_key = self.output_kms_key if code_bucket == output_bucket else None
else:
local_mode = self.output_path.startswith("file://")
if local_mode:
if self.code_location is None:
code_bucket = self.sagemaker_session.default_bucket()
key_prefix = self.sagemaker_session.default_bucket_prefix
code_s3_prefix = self._assign_s3_prefix(key_prefix)
kms_key = None
else:
code_bucket, key_prefix = parse_s3_url(self.code_location)
code_s3_prefix = self._assign_s3_prefix(key_prefix)
kms_key = None
else:
if self.code_location is None:
code_bucket, possible_key_prefix = parse_s3_url(self.output_path)
if self._is_output_path_set_from_default_bucket_and_prefix:
# Only include possible_key_prefix if the output_path was created from the
# Session's default bucket and prefix. In that scenario, possible_key_prefix
# will either be "" or Session.default_bucket_prefix.
# Note: We cannot do `if (code_bucket == session.default_bucket() and
# key_prefix == session.default_bucket_prefix)` instead because the user
# could have passed in equivalent values themselves to output_path. And
# including the prefix in that case could result in a potentially backwards
# incompatible behavior change for the end user.
code_s3_prefix = self._assign_s3_prefix(possible_key_prefix)
else:
code_s3_prefix = self._assign_s3_prefix()
kms_key = self.output_kms_key
else:
code_bucket, key_prefix = parse_s3_url(self.code_location)
code_s3_prefix = self._assign_s3_prefix(key_prefix)
output_bucket, _ = parse_s3_url(self.output_path)
kms_key = self.output_kms_key if code_bucket == output_bucket else None
return tar_and_upload_dir(
session=self.sagemaker_session.boto_session,
bucket=code_bucket,
s3_key_prefix=code_s3_prefix,
script=self.entry_point,
directory=self.source_dir,
dependencies=self.dependencies,
kms_key=kms_key,
s3_resource=self.sagemaker_session.s3_resource,
settings=self.sagemaker_session.settings,
)

In SageMaker SDK V3, the new ModelTrainer does not apply any KMS key when uploading source code. In `create_input_data_channel`, local data is uploaded via `sagemaker_session.upload_data(...)` without any encryption setting:

def create_input_data_channel(
self,
channel_name: str,
data_source: DataSourceType,
key_prefix: Optional[str] = None,
ignore_patterns: Optional[List[str]] = None,
) -> Channel:
"""Create an input data channel for the training job.
Args:
channel_name (str): The name of the input data channel.
data_source (DataSourceType): The data source for the input data channel.
DataSourceType can be an S3 URI string, local file path string,
S3DataSource object, or FileSystemDataSource object.
key_prefix (Optional[str]): The key prefix to use when uploading data to S3.
Only applicable when data_source is a local file path string.
If not specified, local data will be uploaded to:
``s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/``
If specified, local data will be uploaded to:
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
ignore_patterns: (Optional[List[str]]) :
The ignore patterns to ignore specific files/folders when uploading to S3.
If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
"""
from sagemaker.core.helper.pipeline_variable import PipelineVariable
channel = None
if isinstance(data_source, PipelineVariable):
channel = Channel(
channel_name=channel_name,
data_source=DataSource(
s3_data_source=S3DataSource(
s3_data_type="S3Prefix",
s3_uri=data_source,
s3_data_distribution_type="FullyReplicated",
),
),
input_mode="File",
)
elif isinstance(data_source, str):
if _is_valid_s3_uri(data_source):
channel = Channel(
channel_name=channel_name,
data_source=DataSource(
s3_data_source=S3DataSource(
s3_data_type="S3Prefix",
s3_uri=data_source,
s3_data_distribution_type="FullyReplicated",
),
),
input_mode="File",
)
elif _is_valid_path(data_source):
if self.training_mode == Mode.LOCAL_CONTAINER:
channel = Channel(
channel_name=channel_name,
data_source=DataSource(
file_system_data_source=FileSystemDataSource.model_construct(
directory_path=data_source,
file_system_type="EFS",
),
),
input_mode="File",
)
else:
key_prefix = (
f"{key_prefix}/{channel_name}"
if key_prefix
else f"{self.base_job_name}/input/{channel_name}"
)
if self.sagemaker_session.default_bucket_prefix:
key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
if ignore_patterns and _is_valid_path(data_source, path_type="Directory"):
tmp_dir = TemporaryDirectory()
copied_path = os.path.join(
tmp_dir.name, os.path.basename(os.path.normpath(data_source))
)
shutil.copytree(
data_source,
copied_path,
dirs_exist_ok=True,
ignore=shutil.ignore_patterns(*ignore_patterns),
)
s3_uri = self.sagemaker_session.upload_data(
path=copied_path,
bucket=self.sagemaker_session.default_bucket(),
key_prefix=key_prefix,
)
else:
s3_uri = self.sagemaker_session.upload_data(
path=data_source,
bucket=self.sagemaker_session.default_bucket(),
key_prefix=key_prefix,
)
channel = Channel(
channel_name=channel_name,
data_source=DataSource(
s3_data_source=S3DataSource(
s3_data_type="S3Prefix",
s3_uri=s3_uri,
s3_data_distribution_type="FullyReplicated",
),
),
input_mode="File",
)
else:
raise ValueError(f"Not a valid S3 URI or local file path: {data_source}.")
elif isinstance(data_source, S3DataSource):
channel = Channel(
channel_name=channel_name, data_source=DataSource(s3_data_source=data_source)
)
elif isinstance(data_source, FileSystemDataSource):
channel = Channel(
channel_name=channel_name,
data_source=DataSource(file_system_data_source=data_source),
)
else:
raise ValueError(f"Unsupported data_source type: {type(data_source)}")
return channel

Additionally, whenever source code is specified, ModelTrainer uploads its driver files from a local temporary directory. This happens even when the SourceCode object points to an S3 URI rather than a local path:

sm_drivers_channel = self.create_input_data_channel(
channel_name=SM_DRIVERS,
data_source=self._temp_code_dir.name,
key_prefix=input_data_key_prefix,
ignore_patterns=self.source_code.ignore_patterns,
)

These driver uploads go through the same `create_input_data_channel` code path, so they do not use a KMS key either.

Expected behavior

ModelTrainer should:

  • either respect a user-provided KMS key (similar to output_kms_key in V2), OR
  • allow configuration of a KMS key for all S3 uploads related to source code.

Actual behavior

  • Source code and driver files are uploaded to S3 without KMS encryption.
  • There is no apparent way to configure a KMS key for these uploads.

Impact

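In restricted environments (like ours), S3 bucket policies require server-side encryption with KMS (SSE-KMS). Uploads that do not specify a KMS key are therefore rejected.
As a result, the source-code and driver uploads fail, and ModelTrainer cannot be used with custom training scripts.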

This blocks use cases that rely on custom code, such as MLflow serverless integration.

Steps to reproduce

  1. Create a ModelTrainer with a SourceCode object
  2. Provide either:
    • a local path, or
    • an S3 URI
  3. Observe that S3 uploads occur without KMS encryption

Possible solution

Expose a parameter similar to V2's `output_kms_key`, or reuse an existing encryption setting. Apply it to every source-code and driver upload performed in `create_input_data_channel`.

Additional context

This is a regression compared to V2 Estimator behavior and impacts secure environments with strict S3 encryption policies.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions