Skip to content

Commit 01daaac

Browse files
committed
allow specifying the list of possible labels when loading a dataset with hgdataset_from_conll2002
1 parent 27368b0 commit 01daaac

1 file changed

Lines changed: 9 additions & 4 deletions

File tree

renard/ner_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def hgdataset_from_conll2002(
236236
tag_conversion_map: Optional[Dict[str, str]] = None,
237237
separator: str = "\t",
238238
max_sent_len: Optional[int] = None,
239+
labels: Optional[list[str]] = None,
239240
**kwargs,
240241
) -> HGDataset:
241242
"""Load a CoNLL-2002 file as a Huggingface Dataset.
@@ -244,9 +245,13 @@ def hgdataset_from_conll2002(
244245
:param tag_conversion_map: passed to :func:`load_conll2002_bio`
245246
:param separator: passed to :func:`load_conll2002_bio`
246247
:param max_sent_len: passed to :func:`load_conll2002_bio`
248+
:param labels: the list of all possible labels. If ``None``, will
249+
automatically be assigned to the sorted list of possible tags
250+
found in the input file.
247251
:param kwargs: additional kwargs for :func:`open`
248252
249-
:return: a :class:`datasets.Dataset` with features 'tokens' and 'labels'.
253+
:return: a :class:`datasets.Dataset` with features 'tokens' and
254+
'labels'.
250255
"""
251256
sentences, tokens, entities = load_conll2002_bio(
252257
path, tag_conversion_map, separator, max_sent_len, **kwargs
@@ -269,9 +274,9 @@ def hgdataset_from_conll2002(
269274
]
270275

271276
dataset = HGDataset.from_dict({"tokens": sentences, "labels": sent_tags})
272-
dataset = dataset.cast_column(
273-
"labels", Sequence(ClassLabel(names=sorted(set(tags))))
274-
)
277+
if labels is None:
278+
labels = sorted(set(tags))
279+
dataset = dataset.cast_column("labels", Sequence(ClassLabel(names=labels)))
275280
return dataset
276281

277282

0 commit comments

Comments
 (0)