# A partial implementation of https://arxiv.org/abs/2109.02157

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

# Note: this example requires the napkinXC library: https://napkinxc.readthedocs.io/
from napkinxc.datasets import load_dataset
from napkinxc.measures import precision_at_k

from tqdm import tqdm
import torchhd
from torchhd import embeddings, HRRTensor
import torchhd.tensors
from scipy.sparse import vstack, lil_matrix


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))


DIMENSIONS = 400
NUMBER_OF_EPOCHS = 1
BATCH_SIZE = 1
DATASET_NAME = "eurlex-4k"  # tested on "eurlex-4k" and "Wiki10-31K"
FC_LAYER_SIZE = 512
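# DIMENSIONS is the width of the HRR hypervectors the FCHRR model below
# predicts into; it is deliberately much smaller than the number of classes,
# which is where its parameter saving over the plain FC baseline comes from.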


def sparse_batch_collate(batch: list):
    """
    Collate function that stacks a batch of scipy CSR rows into a single
    PyTorch sparse COO tensor and the label vectors into one dense tensor.
    """
    data_batch, targets_batch = zip(*batch)

    data_batch = vstack(data_batch).tocoo()
    data_batch = torch.sparse_coo_tensor(
        data_batch.nonzero(), data_batch.data, data_batch.shape
    )

    targets_batch = torch.stack(targets_batch)

    return data_batch, targets_batch
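
# Example shapes: a batch of B single-row CSR matrices becomes one
# (B, N_features) sparse COO tensor plus a dense (B, N_classes) multi-hot
# target; keeping the features sparse avoids densifying the bag-of-words input.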

class multilabel_dataset(Dataset):
    def __init__(self, x, y, n_classes) -> None:
        self.x = x
        self.y = y
        self.n_classes = n_classes

    # Define the length of the dataset.
    def __len__(self):
        return self.x.shape[0]

    # Return a single sample from the dataset.
    def __getitem__(self, idx):
        labels = torch.zeros(self.n_classes, dtype=torch.int64)
        labels[self.y[idx]] = 1
        return self.x[idx], labels
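
# Each sample pairs a (1, N_features) sparse row with a multi-hot label vector,
# e.g. a document tagged with classes {3, 17} yields ones at indices 3 and 17.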

X_train, Y_train = load_dataset(DATASET_NAME, "train", verbose=True)
X_test, Y_test = load_dataset(DATASET_NAME, "test", verbose=True)


if DATASET_NAME == "Wiki10-31K":  # Because of this issue: https://github.com/mwydmuch/napkinXC/issues/18
    X_train = lil_matrix(X_train[:, :-1])

N_features = X_train.shape[1]
N_classes = max(max(classes) for classes in Y_train if classes != []) + 1

train_dataset = multilabel_dataset(X_train, Y_train, N_classes)
train_dataloader = DataLoader(train_dataset, BATCH_SIZE, collate_fn=sparse_batch_collate)
test_dataset = multilabel_dataset(X_test, Y_test, N_classes)
test_dataloader = DataLoader(test_dataset, collate_fn=sparse_batch_collate)
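
# BATCH_SIZE = 1 keeps the example simple; larger training batches also work
# because both loss() methods handle a batch dimension, but the test loader
# keeps the default batch size of 1, which the pred() methods assume.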

print("Training on \033[1m{}\033[0m. It has {} features and {} classes."
      .format(DATASET_NAME, N_features, N_classes))


# Fully connected model for the baseline comparison
class FC(nn.Module):
    def __init__(self, num_features, num_classes):
        super(FC, self).__init__()
        self.num_classes = num_classes
        self.num_features = num_features
        self.fc_layer_size = FC_LAYER_SIZE

        # Network layers
        self.fc1 = nn.Linear(self.num_features, self.fc_layer_size)
        self.fc2 = nn.Linear(self.fc_layer_size, self.fc_layer_size)
        self.olayer = nn.Linear(self.fc_layer_size, self.num_classes)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.olayer(x)
        return x

    # Predict every class whose sigmoid activation clears the threshold.
    def pred(self, out, threshold=0.5):
        y = torch.sigmoid(out)
        v, i = y.sort(descending=True)
        ids = i[v >= threshold]
        ids = ids.tolist()
        return ids

    def loss(self, out, target):
        loss = nn.BCEWithLogitsLoss()(out, target.float())
        return loss
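
# For C classes the baseline's output layer alone holds FC_LAYER_SIZE * C
# weights, so the head grows linearly with the label space.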

# Modified version of the FC model that returns an HRRTensor with
# dim << the FC model's output size, giving the model fewer parameters.
class FCHRR(nn.Module):
    def __init__(self, num_features, num_classes, dim):
        super(FCHRR, self).__init__()
        self.num_classes = num_classes
        self.num_features = num_features
        self.fc_layer_size = FC_LAYER_SIZE
        self.dim = dim

        # Random HRR codebook with one hypervector per class, plus two shared
        # role vectors: p_vec marks positive (present) classes and n_vec
        # negative (absent) ones.
        self.classes_vec = embeddings.Random(num_classes, dim, vsa="HRR")
        n_vec, p_vec = torchhd.HRRTensor.random(2, dim)
        self.register_buffer("n_vec", n_vec)
        self.register_buffer("p_vec", p_vec)

        # Network layers
        self.fc1 = nn.Linear(self.num_features, self.fc_layer_size)
        self.fc2 = nn.Linear(self.fc_layer_size, self.fc_layer_size)
        self.olayer = nn.Linear(self.fc_layer_size, dim)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.olayer(x)
        return x.as_subclass(HRRTensor)

    # Unbind the positive role vector from the output and predict every class
    # whose codebook vector is sufficiently similar to the result.
    def pred(self, out, threshold=0.1):
        tmp_positive = self.p_vec.exact_inverse().bind(out)
        sims = tmp_positive.cosine_similarity(self.classes_vec.weight)

        v, i = sims.sort(descending=True)
        ids = i[v >= threshold]
        ids = ids.tolist()

        return ids

    def loss(self, out, target):
        loss = torch.tensor(0, dtype=torch.float32, device=out.device)

        tmp_positives = self.p_vec.exact_inverse().bind(out)
        tmp_negatives = self.n_vec.exact_inverse().bind(out)
        for i in range(target.shape[0]):
            # Codebook vectors of the classes present in this sample.
            cp = self.classes_vec.weight[target[i] == 1, :]

            # Pull the positive unbinding towards each present class and push
            # the negative unbinding away from their bundle.
            j_p = (1 - tmp_positives[i].cosine_similarity(cp)).sum()
            j_n = tmp_negatives[i].cosine_similarity(cp.multibundle())

            loss += j_p + j_n

        loss /= target.shape[0]

        return loss
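
# A minimal sketch of the encode/decode round trip behind FCHRR.pred(), using
# the same torchhd HRR ops as above (bind, exact_inverse, cosine_similarity):
#
#   classes = embeddings.Random(4, DIMENSIONS, vsa="HRR").weight
#   p, = torchhd.HRRTensor.random(1, DIMENSIONS)
#   s = p.bind(classes[0]) + p.bind(classes[2])  # ideal output for labels {0, 2}
#   sims = p.exact_inverse().bind(s).cosine_similarity(classes)
#   # sims peaks at indices 0 and 2, so thresholding recovers the label set.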


hrr_model = FCHRR(N_features, N_classes, DIMENSIONS)
hrr_model = hrr_model.to(device)

baseline_model = FC(N_features, N_classes)
baseline_model = baseline_model.to(device)
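
# The HRR head maps FC_LAYER_SIZE -> DIMENSIONS (512 -> 400) no matter how
# many classes there are; class identities live in the random codebook rather
# than the output layer, which is where the parameter count shrinks.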

for model_name, model in {"HRR-FC": hrr_model, "FC": baseline_model}.items():
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
    model.train()
    for epoch in tqdm(range(1, NUMBER_OF_EPOCHS + 1), desc=f"{model_name} epochs", leave=False):
        for samples, labels in tqdm(train_dataloader, desc="Training", leave=False):
            # Cast to float32 to match the model weights (mirrors the eval loop).
            samples = samples.to(device).float()
            labels = labels.to(device)
            optimizer.zero_grad()
            out = model(samples)
            loss = model.loss(out, labels)
            loss.backward()
            optimizer.step()

        scheduler.step()

    Y_pred = []
    model.eval()
    with torch.no_grad():
        for data, target in tqdm(test_dataloader, desc="Validating", leave=False):
            data, target = data.to(device).float(), target.to(device)
            out = model(data)
            ids = model.pred(out)
            Y_pred.append(ids)

    # Calculate the P@1 metric
    p_at_1 = precision_at_k(Y_test, Y_pred, k=1)[0]
    print("Result of {} model ----> P@1 = {}".format(model_name, p_at_1))