This repository was archived by the owner on Mar 23, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 773
Expand file tree
/
Copy pathtext_classification.py
More file actions
66 lines (54 loc) · 2.59 KB
/
text_classification.py
File metadata and controls
66 lines (54 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import List
from taskweaver.plugin import Plugin, register_plugin
try:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
except ImportError:
raise ImportError("Please install transformers with `pip install transformers`.")
try:
import torch
except ImportError:
raise ImportError(
"Please install torch according to your OS and CUDA availability. You may try `pip install torch`",
)
class TextClassificationModelInference:
    """Zero-shot text classification via an NLI model from Hugging Face transformers.

    The sequence to classify is posed as the NLI premise and a hypothesis is
    constructed from each candidate label ("This example is {label}"); the label
    whose hypothesis scores highest on the entailment logit wins.
    More details: https://huggingface.co/facebook/bart-large-mnli
    """

    def __init__(self, model_name: str, entailment_id: int = -1) -> None:
        """Load the tokenizer/model and resolve the entailment class index.

        Args:
            model_name: Hugging Face model id or local path of an NLI
                sequence-classification model.
            entailment_id: index of the "entailment" class in the model's
                output logits. Defaults to -1, meaning "auto-detect from
                ``model.config.id2label``".

        Raises:
            ValueError: if ``entailment_id`` is not given and no label in the
                model config starts with "entail".
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        # Inference-only usage: disable dropout / put norm layers in eval mode.
        self.model.eval()

        self.entailment_id = entailment_id
        if self.entailment_id == -1:
            # Auto-detect the entailment class index; label order differs
            # between NLI checkpoints, so it must not be hard-coded.
            for idx, label in self.model.config.id2label.items():
                if label.lower().startswith("entail"):
                    self.entailment_id = int(idx)
        if self.entailment_id == -1:
            raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")

    def predict(self, inputs: List[str], label_list: List[str]) -> List[str]:
        """Return the best-matching label from ``label_list`` for each input.

        Args:
            inputs: sequences to classify.
            label_list: candidate labels; must be non-empty.

        Returns:
            One predicted label per input sequence, in order.

        Raises:
            ValueError: if ``label_list`` is empty (argmax over zero
                hypotheses would otherwise fail with an opaque torch error).
        """
        if not label_list:
            raise ValueError("label_list must not be empty.")
        predicted_labels = []
        for sequence in inputs:
            # One (premise, hypothesis) pair per candidate label, scored in a
            # single batched forward pass.
            tokenized_inputs = self.tokenizer(
                [sequence] * len(label_list),
                [f"This example is {label}" for label in label_list],
                return_tensors="pt",
                padding=True,  # pad to longest in batch (cheaper than max_length)
                truncation=True,  # guard against over-long premises
            )
            with torch.no_grad():
                logits = self.model(**tokenized_inputs.to(self.device)).logits
            # BUG FIX: select the entailment column resolved in __init__ rather
            # than the previously hard-coded position 2, which is only correct
            # for models whose entailment class happens to be index 2.
            label_id = torch.argmax(logits[:, self.entailment_id]).item()
            predicted_labels.append(label_list[label_id])
        return predicted_labels
@register_plugin
class TextClassification(Plugin):
    """TaskWeaver plugin exposing zero-shot text classification.

    The underlying model is loaded lazily on the first call so that merely
    registering the plugin incurs no model-download cost.
    """

    # Shared inference helper; stays None until the first invocation.
    model: TextClassificationModelInference = None

    def _init(self) -> None:
        """Instantiate the inference helper with the default NLI checkpoint."""
        self.model = TextClassificationModelInference("facebook/bart-large-mnli")

    def __call__(self, inputs: List[str], label_list: List[str]) -> List[str]:
        """Classify each text in ``inputs`` into one label from ``label_list``."""
        if self.model is None:
            self._init()
        return self.model.predict(inputs, label_list)