Skip to content

Commit 27368b0

Browse files
committed
add possibility to pick dataset splits for NER training
1 parent 6d21e9f commit 27368b0

1 file changed

Lines changed: 17 additions & 6 deletions

File tree

renard/ner_utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from more_itertools import flatten
77
import torch
88
from torch.utils.data import Dataset
9-
from datasets import Dataset as HGDataset
9+
from datasets import Dataset as HGDataset, DatasetDict as HGDatasetDict
1010
from datasets import Sequence, ClassLabel
1111
from transformers import (
1212
AutoModelForTokenClassification,
@@ -316,9 +316,21 @@ def _tokenize_and_align_labels(
316316

317317
def train_ner_model(
318318
hg_id: str,
319-
dataset: HGDataset,
319+
dataset: Union[HGDataset, HGDatasetDict],
320320
targs: TrainingArguments,
321+
train_split: str = "train",
322+
valid_split: str = "valid",
321323
) -> PreTrainedModel:
324+
"""Train a NER model on the given dataset.
325+
326+
:param hg_id: huggingface ID of the model to train
327+
:param dataset: huggingface dataset on which to train. The
328+
'labels' column is assumed to contain NER labels.
329+
:param targs: training arguments for the huggingface
330+
trainer.
331+
:param train_split: split of the dataset used for training.
332+
:param valid_split: split of the dataset used for validation.
333+
"""
322334
from transformers import DataCollatorForTokenClassification
323335

324336
# BERT tokenizer splits tokens into subtokens. The
@@ -328,9 +340,8 @@ def train_ner_model(
328340
dataset = dataset.map(
329341
ft.partial(_tokenize_and_align_labels, tokenizer=tokenizer), batched=True
330342
)
331-
dataset = dataset.train_test_split(test_size=0.1)
332343

333-
label_lst = dataset["train"].features["labels"].feature.names
344+
label_lst = dataset[train_split].features["labels"].feature.names
334345
model = AutoModelForTokenClassification.from_pretrained(
335346
hg_id,
336347
num_labels=len(label_lst),
@@ -341,8 +352,8 @@ def train_ner_model(
341352
trainer = Trainer(
342353
model,
343354
targs,
344-
train_dataset=dataset["train"],
345-
eval_dataset=dataset["test"],
355+
train_dataset=dataset[train_split],
356+
eval_dataset=dataset[valid_split],
346357
# data_collator=DataCollatorForTokenClassificationWithBatchEncoding(tokenizer),
347358
data_collator=DataCollatorForTokenClassification(tokenizer),
348359
tokenizer=tokenizer,

0 commit comments

Comments
 (0)