From ccb48bcafcb3a8bec34100ada630e5f7d9a96161 Mon Sep 17 00:00:00 2001 From: Vinay Parakala Date: Mon, 6 Apr 2026 16:54:19 -0400 Subject: [PATCH] Add dataset tags to SDK for identification (DE-7033) Expose dataset tags through the Python SDK so customers can identify datasets labeled by Scale vs other vendors via the API. - Add `tags` field to DatasetInfo model (returned by dataset.info()) - Add get_tags(), add_tags(), remove_tags() methods to Dataset class - Use POST /tags/remove instead of DELETE to avoid proxy body-stripping - Use pydantic v1/v2 compat shim for null-coercion validator - Guard against passing a bare string instead of a list Co-Authored-By: Claude Opus 4.6 (1M context) --- nucleus/data_transfer_object/dataset_info.py | 16 +++++++- nucleus/dataset.py | 43 ++++++++++++++++++++ tests/test_dataset.py | 35 ++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) diff --git a/nucleus/data_transfer_object/dataset_info.py b/nucleus/data_transfer_object/dataset_info.py index 6a9b9024..c1763843 100644 --- a/nucleus/data_transfer_object/dataset_info.py +++ b/nucleus/data_transfer_object/dataset_info.py @@ -1,4 +1,12 @@ -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from pydantic.v1 import validator +else: + try: + from pydantic.v1 import validator + except ImportError: + from pydantic import validator from nucleus.pydantic_base import DictCompatibleModel @@ -14,6 +22,7 @@ class DatasetInfo(DictCompatibleModel): slice_ids: List :class:`Slice` IDs associated with the :class:`Dataset` annotation_metadata_schema: Dict defining annotation-level metadata schema. item_metadata_schema: Dict defining item metadata schema. + tags: List of tags associated with the :class:`Dataset`. """ dataset_id: str @@ -24,3 +33,8 @@ class DatasetInfo(DictCompatibleModel): # TODO: Expand the following into pydantic models to formalize schema annotation_metadata_schema: Optional[Dict[str, Any]] = None item_metadata_schema: Optional[Dict[str, Any]] = None + tags: List[str] = [] + + @validator("tags", pre=True, always=True) # pylint: disable=used-before-assignment + def coerce_null_tags(cls, v): # pylint: disable=no-self-argument + return v if v is not None else [] diff --git a/nucleus/dataset.py b/nucleus/dataset.py index d5707923..8fda1a3d 100644 --- a/nucleus/dataset.py +++ b/nucleus/dataset.py @@ -432,6 +432,49 @@ def info(self) -> DatasetInfo: dataset_info = DatasetInfo.parse_obj(response) return dataset_info + def get_tags(self) -> List[str]: + """Fetches tags associated with the dataset. + + Returns: + List of tag strings associated with this dataset. + """ + response = self._client.make_request( + {}, f"dataset/{self.id}/tags", requests.get + ) + return response["tags"] + + def add_tags(self, tags: List[str]) -> List[str]: + """Adds tags to the dataset. + + Args: + tags: List of tag strings to add. + + Returns: + Updated list of all tags on the dataset. + """ + if isinstance(tags, str): + raise TypeError("tags must be a list of strings, not a single string") + response = self._client.make_request( + {"tags": tags}, f"dataset/{self.id}/tags", requests.post + ) + return response["tags"] + + def remove_tags(self, tags: List[str]) -> List[str]: + """Removes tags from the dataset. + + Args: + tags: List of tag strings to remove. + + Returns: + Updated list of remaining tags on the dataset. + """ + if isinstance(tags, str): + raise TypeError("tags must be a list of strings, not a single string") + response = self._client.make_request( + {"tags": tags}, f"dataset/{self.id}/tags", requests.delete + ) + return response["tags"] + @deprecated( "Model runs have been deprecated and will be removed. Use a Model instead" ) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index bc567653..35f18477 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -257,6 +257,41 @@ def test_dataset_slices(CLIENT, dataset): # TODO(gunnar): Test slice items -> Split up info! +def test_dataset_tags(CLIENT, dataset): + # Fresh dataset should have no tags + assert dataset.get_tags() == [] + + # Add tags + updated = dataset.add_tags(["Labeled by: Scale", "production"]) + assert "Labeled by: Scale" in updated + assert "production" in updated + + # Info should include tags + info = dataset.info() + assert "Labeled by: Scale" in info.tags + assert "production" in info.tags + + # Adding duplicate tags is idempotent + updated2 = dataset.add_tags(["production", "v2"]) + assert "production" in updated2 + assert "v2" in updated2 + + # Remove tags + remaining = dataset.remove_tags(["production"]) + assert "production" not in remaining + assert "Labeled by: Scale" in remaining + + # Removing non-existent tags is idempotent + remaining2 = dataset.remove_tags(["nonexistent"]) + assert remaining2 == remaining + + # String argument should raise TypeError + with pytest.raises(TypeError): + dataset.add_tags("not a list") + with pytest.raises(TypeError): + dataset.remove_tags("not a list") + + def test_dataset_append_local(CLIENT, dataset): ds_items_local_error = [ DatasetItem(