pti_lda.py
from __future__ import print_function
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
import pickle


class PTILDA:
    n_features = 5000
    n_topics = 300
    tfv_file = "tfv.pkl"
    tf_file = "tf.pkl"
    model_file = "lda.pkl"

    def __init__(self):
        pass

    def train(self, dataset):
        # Use tf (raw term count) features for LDA.
        self.tf_vectorizer = CountVectorizer(max_df=0.95, min_df=5,
                                             max_features=self.n_features)
        self.tf = self.tf_vectorizer.fit_transform(dataset)
        # Train the model. The number-of-topics argument is `n_components`
        # in scikit-learn >= 0.21 (it was called `n_topics` in older releases).
        self.lda = LatentDirichletAllocation(n_components=self.n_topics, max_iter=5,
                                             learning_method='online',
                                             learning_offset=50.,
                                             random_state=0)
        self.lda.fit(self.tf)
        # Document-topic distribution for the training corpus.
        self.matrix = self.lda.transform(self.tf)
        # Save the vectorizer, the term-count matrix, and the fitted model.
        pickle.dump(self.tf_vectorizer, open(self.tfv_file, "wb"))
        pickle.dump(self.tf, open(self.tf_file, "wb"))
        pickle.dump(self.lda, open(self.model_file, "wb"))

    def load(self):
        self.tf_vectorizer = pickle.load(open(self.tfv_file, "rb"))
        self.tf = pickle.load(open(self.tf_file, "rb"))
        self.lda = pickle.load(open(self.model_file, "rb"))
        self.matrix = self.lda.transform(self.tf)

    def print_topic_word(self, n_top_words, topic_num=None):
        # get_feature_names() was removed in scikit-learn 1.2;
        # get_feature_names_out() is the current equivalent.
        feature_names = self.tf_vectorizer.get_feature_names_out()
        for topic_idx, topic in enumerate(self.lda.components_):
            if topic_num is None or topic_num == topic_idx:
                print("Topic #{}:".format(topic_idx))
                print(" ".join([feature_names[i]
                                for i in topic.argsort()[:-n_top_words - 1:-1]]))
                print()

    def print_doc_topic(self, n_top, doc_id=None):
        # Print the n_top most probable topics for one document (or every document).
        for i, doc in enumerate(self.matrix):
            if doc_id is None or i == doc_id:
                print("Doc {}:".format(i))
                topics = {}
                for topic_num, prob in enumerate(doc):
                    topics[topic_num] = prob
                count = 0
                for k, v in sorted(topics.items(), key=lambda x: x[1], reverse=True):
                    print("Topic:{}, prob:{}".format(k, v))
                    count += 1
                    if count >= n_top:
                        break
                print()

    def print_doc_threshold_by_topic(self, topic, threshold=0.5):
        # List the documents whose probability for `topic` is at least `threshold`.
        count = 0
        for i, doc in enumerate(self.matrix):
            if doc[topic] >= threshold:
                print("Doc:{}".format(i))
                count += 1
        print("Count:{}".format(count))

    def dump_topic(self, idx2doc):
        # Write each document's full (topic, probability) list to its own pickle.
        for i, doc in enumerate(self.matrix):
            result = []
            for j, x in enumerate(doc):
                result.append((j, x))
            pickle.dump(result, open("/var/pti/topic/{}.pkl".format(idx2doc[i]), "wb"))
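

# --- Minimal usage sketch (not part of the original module) ---
# Assumes a small synthetic in-memory corpus; a real run would pass the actual
# document collection instead. n_topics is overridden so the toy example
# finishes quickly. Note that train() also writes tfv.pkl / tf.pkl / lda.pkl
# into the current directory, exactly as it would for a real corpus.
if __name__ == "__main__":
    docs = ["apple banana cherry grape"] * 10 + ["dog cat fish bird"] * 10

    model = PTILDA()
    model.n_topics = 2                # far fewer topics than the default 300
    model.train(docs)                 # fit CountVectorizer + LDA, then pickle them
    model.print_topic_word(3)         # top 3 words for each topic
    model.print_doc_topic(1)          # most probable topic for every document
    model.print_doc_threshold_by_topic(0, threshold=0.5)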