Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion roboflow/cli/handlers/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def upload_alias(
project: Annotated[str, typer.Option("-p", "--project", help="Project ID")],
annotation: Annotated[Optional[str], typer.Option("-a", "--annotation", help="Annotation file")] = None,
labelmap: Annotated[Optional[str], typer.Option("-m", "--labelmap", help="Labelmap file")] = None,
split: Annotated[str, typer.Option("-s", "--split", help="Split (train/valid/test)")] = "train",
split: Annotated[
Optional[str],
typer.Option("-s", "--split", help="Override split for all uploaded images (default: infer from folder)"),
] = None,
num_retries: Annotated[int, typer.Option("-r", "--retries", help="Retry count")] = 0,
batch: Annotated[Optional[str], typer.Option("-b", "--batch", help="Batch name")] = None,
tag_names: Annotated[Optional[str], typer.Option("-t", "--tag", help="Tag names")] = None,
Expand Down
11 changes: 9 additions & 2 deletions roboflow/cli/handlers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ def upload_image(
annotation: Annotated[
Optional[str], typer.Option("-a", "--annotation", help="Path to annotation file (single upload)")
] = None,
split: Annotated[str, typer.Option("-s", "--split", help="Dataset split")] = "train",
split: Annotated[
Optional[str],
typer.Option(
"-s",
"--split",
help="Override split for all images (default: infer from folder for dirs, 'train' for files)",
),
] = None,
batch: Annotated[Optional[str], typer.Option("-b", "--batch", help="Batch name")] = None,
tag: Annotated[Optional[str], typer.Option("-t", "--tag", help="Comma-separated tag names")] = None,
metadata: Annotated[Optional[str], typer.Option(help="JSON string of key-value metadata")] = None,
Expand Down Expand Up @@ -237,7 +244,7 @@ def _handle_upload_single(args, api_key: str, path: str) -> None: # noqa: ANN00
image_path=path,
annotation_path=args.annotation,
annotation_labelmap=getattr(args, "labelmap", None),
split=args.split,
split=args.split or "train",
num_retry_uploads=retries,
batch_name=args.batch,
tag_names=tag_names,
Expand Down
6 changes: 5 additions & 1 deletion roboflow/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ def upload_dataset(
is_prediction (bool, optional): whether the annotations provided in the dataset are predictions and not ground truth. Defaults to False.
use_zip_upload (bool, optional): opt-in to the zip flow for a directory input (the SDK zips it client-side). Ignored when dataset_path is already a `.zip`.
tags (list[str], optional): zip flow only — tags to apply to the uploaded batch.
split (str, optional): zip flow only — dataset split for the uploaded batch.
split (str, optional): dataset split for the uploaded batch. In per-image directory
uploads, this overrides inferred splits for every image.
wait (bool, optional): zip flow only — poll for processing completion. Defaults to True.
poll_interval (float, optional): zip flow only — seconds between status polls.
poll_timeout (float, optional): zip flow only — total seconds to wait before timing out.
Expand Down Expand Up @@ -489,6 +490,9 @@ def upload_dataset(
is_classification = project.type == "classification"
parsed_dataset = folderparser.parsefolder(dataset_path, is_classification=is_classification)
images = parsed_dataset["images"]
if split is not None:
for image in images:
image["split"] = split

location = parsed_dataset["location"]

Expand Down
68 changes: 68 additions & 0 deletions tests/cli/test_image_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,74 @@ def test_zip_upload_flag_defaults_false(self, mock_rf_cls):
_, kwargs = mock_ws.upload_dataset.call_args
self.assertEqual(kwargs.get("use_zip_upload"), False)

@patch("roboflow.Roboflow")
def test_upload_directory_omits_default_split_when_not_explicit(self, mock_rf_cls):
from roboflow.cli.handlers.image import _handle_upload

with tempfile.TemporaryDirectory() as tmpdir:
mock_ws = MagicMock()
mock_rf_cls.return_value.workspace.return_value = mock_ws

args = _make_args(
json=True,
path=tmpdir,
project="proj",
annotation=None,
split=None,
batch=None,
tag=None,
metadata=None,
concurrency=10,
retries=0,
labelmap=None,
is_prediction=False,
)

buf = io.StringIO()
old = sys.stdout
sys.stdout = buf
try:
_handle_upload(args)
finally:
sys.stdout = old

_, kwargs = mock_ws.upload_dataset.call_args
self.assertIsNone(kwargs.get("split"))

@patch("roboflow.Roboflow")
def test_upload_directory_forwards_explicit_split(self, mock_rf_cls):
from roboflow.cli.handlers.image import _handle_upload

with tempfile.TemporaryDirectory() as tmpdir:
mock_ws = MagicMock()
mock_rf_cls.return_value.workspace.return_value = mock_ws

args = _make_args(
json=True,
path=tmpdir,
project="proj",
annotation=None,
split="valid",
batch=None,
tag=None,
metadata=None,
concurrency=10,
retries=0,
labelmap=None,
is_prediction=False,
)

buf = io.StringIO()
old = sys.stdout
sys.stdout = buf
try:
_handle_upload(args)
finally:
sys.stdout = old

_, kwargs = mock_ws.upload_dataset.call_args
self.assertEqual(kwargs.get("split"), "valid")


class TestImageDelete(unittest.TestCase):
"""Test the delete handler."""
Expand Down
9 changes: 9 additions & 0 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@ def test_project_upload_dataset(self):
},
"assertions": {"upload": {"count": 1, "kwargs": {"batch_name": "test-batch", "num_retry_uploads": 3}}},
},
{
"name": "explicit_split_overrides_parsed_directory_splits",
"dataset": [
{"file": "image1.jpg", "split": "train"},
{"file": "image2.jpg", "split": "test"},
],
"params": {"split": "valid", "num_workers": 1},
"assertions": {"upload": {"count": 2, "kwargs": {"split": "valid"}}},
},
{
"name": "project_creation",
"dataset": None,
Expand Down
Loading