
Commit cba71f4

change huggingface prefix from HG to HF
1 parent dfd8fe4 commit cba71f4

2 files changed: 25 additions & 17 deletions


renard/ner_utils.py

Lines changed: 17 additions & 9 deletions
@@ -6,7 +6,7 @@
 from more_itertools import flatten
 import torch
 from torch.utils.data import Dataset
-from datasets import Dataset as HGDataset, DatasetDict as HGDatasetDict
+from datasets import Dataset as HFDataset, DatasetDict as HFDatasetDict
 from datasets import Sequence, ClassLabel
 from transformers import (
     AutoModelForTokenClassification,
@@ -235,14 +235,14 @@ def load_conll2002_bio(
     return sents, list(flatten(sents)), entities


-def hgdataset_from_conll2002(
+def hfdataset_from_conll2002(
     path: str,
     tag_conversion_map: Optional[Dict[str, str]] = None,
     separator: str = "\t",
     max_sent_len: Optional[int] = None,
     labels: Optional[list[str]] = None,
     **kwargs,
-) -> HGDataset:
+) -> HFDataset:
     """Load a CoNLL-2002 file as a Huggingface Dataset.

     :param path: passed to :func:`.load_conll2002_bio`
@@ -277,13 +277,21 @@ def hgdataset_from_conll2002(
         for sent_start, sent_end in zip(sent_starts, sent_ends)
     ]

-    dataset = HGDataset.from_dict({"tokens": sentences, "labels": sent_tags})
+    dataset = HFDataset.from_dict({"tokens": sentences, "labels": sent_tags})
     if labels is None:
         labels = sorted(set(tags))
     dataset = dataset.cast_column("labels", Sequence(ClassLabel(names=labels)))
     return dataset


+def hgdataset_from_conll2002(**kwargs) -> HFDataset:
+    """
+    Deprecated function that only exists for retrocompatibility, you
+    should call :func:`.hfdataset_from_conll2002` instead.
+    """
+    return hfdataset_from_conll2002(**kwargs)
+
+
 def _tokenize_and_align_labels(
     examples, tokenizer: PreTrainedTokenizerFast, label_all_tokens: bool = True
 ):
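
The renamed loader keeps its old entry point through the shim above. A minimal usage sketch, not part of the commit ("train.conll" is a hypothetical CoNLL-2002 BIO file); note that the deprecated alias only forwards **kwargs, so legacy callers must pass arguments by keyword:

from renard.ner_utils import hfdataset_from_conll2002, hgdataset_from_conll2002

# New name: positional arguments work as before.
dataset = hfdataset_from_conll2002("train.conll", separator="\t")

# Deprecated alias: forwards **kwargs only, so use keyword arguments.
legacy = hgdataset_from_conll2002(path="train.conll", separator="\t")

assert dataset.features == legacy.features  # same HF Dataset either way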
@@ -324,15 +332,15 @@ def _tokenize_and_align_labels(


 def train_ner_model(
-    hg_id: str,
-    dataset: Union[HGDataset, HGDatasetDict],
+    hf_id: str,
+    dataset: Union[HFDataset, HFDatasetDict],
     targs: TrainingArguments,
     train_split: str = "train",
     valid_split: str = "valid",
 ) -> PreTrainedModel:
     """Train a NER model on the given dataset.

-    :param hg_id: huggingface ID of the model to train
+    :param hf_id: huggingface ID of the model to train
     :param dataset: huggingface dataset on which to train. The
         'labels' column is assumed to contain NER labels.
     :param TrainingArguments: training arguments for the huggingface
@@ -345,14 +353,14 @@ def train_ner_model(
     # BERT tokenizer splits tokens into subtokens. The
     # tokenize_and_align_labels function correctly aligns labels and
     # subtokens.
-    tokenizer = AutoTokenizer.from_pretrained(hg_id)
+    tokenizer = AutoTokenizer.from_pretrained(hf_id)
     dataset = dataset.map(
         ft.partial(_tokenize_and_align_labels, tokenizer=tokenizer), batched=True
     )

     label_lst = dataset[train_split].features["labels"].feature.names
     model = AutoModelForTokenClassification.from_pretrained(
-        hg_id,
+        hf_id,
         num_labels=len(label_lst),
         id2label={i: label for i, label in enumerate(label_lst)},
         label2id={label: i for i, label in enumerate(label_lst)},
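
For context, a sketch of calling the renamed function end to end. Everything here is an assumption rather than part of the commit: "bert-base-cased" as the model ID, local CoNLL files, a toy tag set, and an explicit labels list so both splits share one ClassLabel mapping:

from datasets import DatasetDict
from transformers import TrainingArguments
from renard.ner_utils import hfdataset_from_conll2002, train_ner_model

labels = ["O", "B-PER", "I-PER"]  # hypothetical tag set
dataset = DatasetDict(
    {
        "train": hfdataset_from_conll2002("train.conll", labels=labels),
        "valid": hfdataset_from_conll2002("valid.conll", labels=labels),
    }
)
# hf_id is now the first positional parameter (previously hg_id).
model = train_ner_model(
    "bert-base-cased",
    dataset,
    TrainingArguments(output_dir="./ner_model"),
)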

renard/pipeline/relation_extraction.py

Lines changed: 8 additions & 8 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Union, Optional, Literal
 import ast, re
 import functools as ft
-from datasets import load_dataset, Dataset as HGDataset
+from datasets import load_dataset, Dataset as HFDataset
 import torch
 from transformers import (
     AutoModelForSeq2SeqLM,
@@ -12,7 +12,7 @@
     DataCollatorForSeq2Seq,
     PreTrainedModel,
     EvalPrediction,
-    pipeline as hg_pipeline,
+    pipeline as hf_pipeline,
     BatchEncoding,
 )
 from more_itertools import flatten
@@ -46,7 +46,7 @@ def format_rel(rel: dict) -> str:
     return batch


-def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HGDataset:
+def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HFDataset:
     """
     Load the Artificial Relationships in Fiction dataset
     (https://huggingface.co/datasets/Despina/project_gutenberg) by
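
The dataset URL in the docstring is the only grounding for this sketch; load_ARF_dataset additionally preprocesses with the given tokenizer, whereas this just pulls the raw data for inspection (split and column layout are not specified by the commit):

from datasets import load_dataset

# Raw dataset named in the docstring above; inspect the splits and
# columns before relying on any particular layout.
arf = load_dataset("Despina/project_gutenberg")
print(arf)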
@@ -154,7 +154,7 @@ def __init__(
         self.model = (
             GenerativeRelationExtractor.DEFAULT_MODEL if model is None else model
         )
-        self.hg_pipeline = None
+        self.hf_pipeline = None
         self.batch_size = batch_size
         if device == "auto":
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -163,7 +163,7 @@ def __init__(

     def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs):
         super()._pipeline_init_(lang, progress_reporter, **kwargs)
-        self.hg_pipeline = hg_pipeline(
+        self.hf_pipeline = hf_pipeline(
             "text2text-generation",
             torch_dtype=torch.bfloat16,
             model=self.model,
@@ -173,19 +173,19 @@ def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs):
     def __call__(
         self, sentences: list[list[str]], characters: list[Character], **kwargs
     ) -> dict[str, Any]:
-        assert not self.hg_pipeline is None
+        assert not self.hf_pipeline is None

         sentence_relations = []

         # chunk as in the ARF dataset
-        dataset = HGDataset.from_list(
+        dataset = HFDataset.from_list(
             [
                 {"text": GenerativeRelationExtractor.task_prompt(" ".join(sent))}
                 for sent in sentences
             ]
         )
         for out in self._progress_(
-            self.hg_pipeline(KeyDataset(dataset, "text"), batch_size=self.batch_size),
+            self.hf_pipeline(KeyDataset(dataset, "text"), batch_size=self.batch_size),
             total=len(dataset),
         ):
             text_relations = out[0]["generated_text"]
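
The __call__ body above streams prompts through the renamed pipeline attribute. A standalone sketch of the same pattern (model ID and prompt text are placeholders, not from the commit):

from datasets import Dataset as HFDataset
from transformers import pipeline as hf_pipeline
from transformers.pipelines.pt_utils import KeyDataset

pipe = hf_pipeline("text2text-generation", model="google/flan-t5-small")
dataset = HFDataset.from_list([{"text": "extract relations: Alice met Bob."}])
# Iterating over KeyDataset lets the pipeline batch the generation calls.
for out in pipe(KeyDataset(dataset, "text"), batch_size=8):
    print(out[0]["generated_text"])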
