66from more_itertools import flatten
77import torch
88from torch .utils .data import Dataset
9- from datasets import Dataset as HGDataset
9+ from datasets import Dataset as HGDataset , DatasetDict as HGDatasetDict
1010from datasets import Sequence , ClassLabel
1111from transformers import (
1212 AutoModelForTokenClassification ,
@@ -316,9 +316,21 @@ def _tokenize_and_align_labels(
316316
317317def train_ner_model (
318318 hg_id : str ,
319- dataset : HGDataset ,
319+ dataset : Union [ HGDataset , HGDatasetDict ] ,
320320 targs : TrainingArguments ,
321+ train_split : str = "train" ,
322+ valid_split : str = "valid" ,
321323) -> PreTrainedModel :
324+ """Train a NER model on the given dataset.
325+
326+ :param hg_id: huggingface ID of the model to train
327+ :param dataset: huggingface dataset on which to train. The
328+ 'labels' column is assumed to contain NER labels.
329 :param targs: training arguments for the huggingface
330 trainer.
331 :param train_split: split of the dataset used for training.
332+ :param valid_split: split of the dataset used for validation.
333+ """
322334 from transformers import DataCollatorForTokenClassification
323335
324336 # BERT tokenizer splits tokens into subtokens. The
@@ -328,9 +340,8 @@ def train_ner_model(
328340 dataset = dataset .map (
329341 ft .partial (_tokenize_and_align_labels , tokenizer = tokenizer ), batched = True
330342 )
331- dataset = dataset .train_test_split (test_size = 0.1 )
332343
333- label_lst = dataset ["train" ].features ["labels" ].feature .names
344+ label_lst = dataset [train_split ].features ["labels" ].feature .names
334345 model = AutoModelForTokenClassification .from_pretrained (
335346 hg_id ,
336347 num_labels = len (label_lst ),
@@ -341,8 +352,8 @@ def train_ner_model(
341352 trainer = Trainer (
342353 model ,
343354 targs ,
344- train_dataset = dataset ["train" ],
345- eval_dataset = dataset ["test" ],
355+ train_dataset = dataset [train_split ],
356+ eval_dataset = dataset [valid_split ],
346357 # data_collator=DataCollatorForTokenClassificationWithBatchEncoding(tokenizer),
347358 data_collator = DataCollatorForTokenClassification (tokenizer ),
348359 tokenizer = tokenizer ,
0 commit comments