Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion nucleus/data_transfer_object/dataset_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
from pydantic.v1 import validator
else:
try:
from pydantic.v1 import validator
except ImportError:
from pydantic import validator

from nucleus.pydantic_base import DictCompatibleModel

Expand All @@ -14,6 +22,7 @@ class DatasetInfo(DictCompatibleModel):
slice_ids: List :class:`Slice` IDs associated with the :class:`Dataset`
annotation_metadata_schema: Dict defining annotation-level metadata schema.
item_metadata_schema: Dict defining item metadata schema.
tags: List of tags associated with the :class:`Dataset`.
"""

dataset_id: str
Expand All @@ -24,3 +33,8 @@ class DatasetInfo(DictCompatibleModel):
# TODO: Expand the following into pydantic models to formalize schema
annotation_metadata_schema: Optional[Dict[str, Any]] = None
item_metadata_schema: Optional[Dict[str, Any]] = None
tags: List[str] = []

@validator("tags", pre=True, always=True) # pylint: disable=used-before-assignment
def coerce_null_tags(cls, v): # pylint: disable=no-self-argument
return v if v is not None else []
43 changes: 43 additions & 0 deletions nucleus/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,49 @@ def info(self) -> DatasetInfo:
dataset_info = DatasetInfo.parse_obj(response)
return dataset_info

def get_tags(self) -> List[str]:
"""Fetches tags associated with the dataset.

Returns:
List of tag strings associated with this dataset.
"""
response = self._client.make_request(
{}, f"dataset/{self.id}/tags", requests.get
)
return response["tags"]

def add_tags(self, tags: List[str]) -> List[str]:
"""Adds tags to the dataset.

Args:
tags: List of tag strings to add.

Returns:
Updated list of all tags on the dataset.
"""
if isinstance(tags, str):
raise TypeError("tags must be a list of strings, not a single string")
response = self._client.make_request(
{"tags": tags}, f"dataset/{self.id}/tags", requests.post
)
return response["tags"]

def remove_tags(self, tags: List[str]) -> List[str]:
"""Removes tags from the dataset.

Args:
tags: List of tag strings to remove.

Returns:
Updated list of remaining tags on the dataset.
"""
if isinstance(tags, str):
raise TypeError("tags must be a list of strings, not a single string")
response = self._client.make_request(
{"tags": tags}, f"dataset/{self.id}/tags", requests.delete
)
return response["tags"]

@deprecated(
"Model runs have been deprecated and will be removed. Use a Model instead"
)
Expand Down
35 changes: 35 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,41 @@ def test_dataset_slices(CLIENT, dataset):
# TODO(gunnar): Test slice items -> Split up info!


def test_dataset_tags(CLIENT, dataset):
# Fresh dataset should have no tags
assert dataset.get_tags() == []

# Add tags
updated = dataset.add_tags(["Labeled by: Scale", "production"])
assert "Labeled by: Scale" in updated
assert "production" in updated

# Info should include tags
info = dataset.info()
assert "Labeled by: Scale" in info.tags
assert "production" in info.tags

# Adding duplicate tags is idempotent
updated2 = dataset.add_tags(["production", "v2"])
assert "production" in updated2
assert "v2" in updated2

# Remove tags
remaining = dataset.remove_tags(["production"])
assert "production" not in remaining
assert "Labeled by: Scale" in remaining

# Removing non-existent tags is idempotent
remaining2 = dataset.remove_tags(["nonexistent"])
assert remaining2 == remaining

# String argument should raise TypeError
with pytest.raises(TypeError):
dataset.add_tags("not a list")
with pytest.raises(TypeError):
dataset.remove_tags("not a list")


def test_dataset_append_local(CLIENT, dataset):
ds_items_local_error = [
DatasetItem(
Expand Down