-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclassifier.py
86 lines (63 loc) · 2.49 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import pickle
import requests
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from src.nlu import Intent
from config import Config
# Shared configuration instance; Classifier reads its MODEL_* settings
# (remote URLs, local copy paths and the base model identifier) from it.
config = Config()
# Run on the CPU: the model is moved to this device and the saved weights
# are mapped onto it when they are loaded.
device = torch.device("cpu")
class Classifier:
    """Intent classifier built on ``BertForSequenceClassification``.

    On construction, the label mapping and the model weights are downloaded
    from the URLs held by the module-level ``config`` and written to the
    local paths configured there. The model is then instantiated on the
    module-level ``device`` and the downloaded weights are loaded into it.
    """

    def __init__(self):
        # The labels must be loaded first: _load_model sizes the
        # classification head with len(self.labels).
        self.labels = self._load_labels()
        self.model = self._load_model()

    @staticmethod
    def __load_remote_file(url: str, local: str):
        """Stream the content served at ``url`` into the file ``local``.

        Args:
            url: Address of the remote file to download.
            local: Path of the local file to (over)write.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        with requests.get(url, stream=True) as response:
            # Check the status before touching the local file, so an error
            # page is never saved in place of the expected payload and an
            # existing local copy is not truncated on failure.
            response.raise_for_status()
            with open(local, 'wb') as handle:
                # Write chunk by chunk to avoid holding the whole file in memory
                for chunk in response.iter_content(chunk_size=8192):
                    handle.write(chunk)

    def _load_labels(self) -> dict:
        """Download the pickled labels and return the unpickled object.

        The file at ``config.MODEL_CLASSES_URL`` is saved to
        ``config.MODEL_CLASSES_LOCAL_COPY`` and then unpickled.

        Returns:
            The unpickled label mapping. ``predict`` indexes it with the
            predicted class index (an ``int``).
        """
        self.__load_remote_file(config.MODEL_CLASSES_URL, config.MODEL_CLASSES_LOCAL_COPY)
        with open(config.MODEL_CLASSES_LOCAL_COPY, 'rb') as handle:
            # SECURITY: pickle.load can execute arbitrary code; this is only
            # safe as long as MODEL_CLASSES_URL points to a trusted source.
            return pickle.load(handle)

    def _load_model(self) -> BertForSequenceClassification:
        """Download the weights (around 500 MB), then build and return the model.

        The file at ``config.MODEL_WEIGHT_URL`` is saved to
        ``config.MODEL_WEIGHT_LOCAL_COPY``. The model is created from
        ``config.MODEL_CLASSIFIER`` with one output per label, moved to
        ``device``, and the downloaded state dict is loaded into it.

        Returns:
            The model, on ``device``, in evaluation mode.
        """
        self.__load_remote_file(config.MODEL_WEIGHT_URL, config.MODEL_WEIGHT_LOCAL_COPY)
        model = BertForSequenceClassification.from_pretrained(
            config.MODEL_CLASSIFIER,
            num_labels=len(self.labels),
            output_attentions=False,
            output_hidden_states=False
        )
        model.to(device)
        # SECURITY: torch.load unpickles the file; this is only safe as long
        # as MODEL_WEIGHT_URL points to a trusted source.
        state_dict = torch.load(config.MODEL_WEIGHT_LOCAL_COPY, map_location=device)
        model.load_state_dict(state_dict)
        # The model is only used for inference: set evaluation mode
        # explicitly instead of relying on the loader's default.
        model.eval()
        return model

    def predict(self, dataset: BertTokenizer) -> Intent:
        """Run the model on ``dataset`` and return the predicted intent.

        Args:
            dataset: Object exposing ``input_ids`` and ``attention_mask``.
                NOTE(review): annotated as ``BertTokenizer`` but used like
                the encoded output of a tokenizer call - confirm with callers.

        Returns:
            The ``Intent`` built from the label of the highest-scoring class
            of the first item in the batch; any other items are ignored.
        """
        # Inference only: skip gradient tracking to save memory and time
        with torch.no_grad():
            outputs = self.model(
                input_ids=dataset.input_ids,
                token_type_ids=None,
                attention_mask=dataset.attention_mask
            )
        # outputs[0] holds the per-class scores; take the best class per item
        _, predicted_index = torch.max(outputs[0], dim=1)
        return Intent(self.labels[predicted_index[0].item()])