# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple
from torchvision import transforms


class Buffer:
    """
    The memory buffer of a rehearsal method.
    """
    def __init__(self, buffer_size, device, minibatch_size, model=None):
        self.buffer_size = buffer_size
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ['examples', 'labels']
        self.model = model
        self.minibatch_size = minibatch_size
        # per-index cache of buffered samples' gradients, invalidated when a slot is overwritten
        self.cache = {}
        # cursor and permutation used for sequential (non-random) retrieval
        self.fathom = 0
        self.fathom_mask = None
        self.reset_fathom()

        self.conterone = 0  # unused

    def reset_fathom(self):
        """
        Resets the sequential-retrieval cursor and reshuffles the visiting order.
        """
        self.fathom = 0
        self.fathom_mask = torch.randperm(
            min(self.num_seen_examples,
                self.examples.shape[0] if hasattr(self, 'examples') else self.num_seen_examples))

    def get_grad_score(self, x, y, X, Y, indices):
        """
        Scores the candidate batch (x, y) by its maximum gradient cosine
        similarity (shifted by +1) against the buffered samples (X, Y).
        """
        g = self.model.get_grads(x, y)
        G = []
        for xi, yi, idx in zip(X, Y, indices):
            if idx in self.cache:
                grd = self.cache[idx]
            else:
                grd = self.model.get_grads(xi.unsqueeze(0), yi.unsqueeze(0))
                self.cache[idx] = grd
            G.append(grd)
        G = torch.cat(G).to(g.device)
        c_score = 0
        grads_at_a_time = 5
        # let's split this so your gpu does not melt. You're welcome.
        for it in range(int(np.ceil(G.shape[0] / grads_at_a_time))):
            tmp = F.cosine_similarity(g, G[it * grads_at_a_time: (it + 1) * grads_at_a_time],
                                      dim=1).max().item() + 1
            c_score = max(c_score, tmp)
        return c_score
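
    # NOTE (assumption): `self.model.get_grads(inputs, labels)` is expected to
    # return a 2-D tensor of shape (1, n_params) holding the flattened gradient
    # of the loss on `(inputs, labels)` w.r.t. all model parameters;
    # `get_grad_score` relies on this shape when computing row-wise cosine
    # similarities with `dim=1`.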

    def functional_reservoir(self, x, y, batch_c, bigX=None, bigY=None, indices=None):
        """
        Decides where (if anywhere) the candidate (x, y) should be stored.
        Returns the target index and its score, or (-1, 0) to discard it.
        """
        if self.num_seen_examples < self.buffer_size:
            # buffer not yet full: fill the next free slot
            return self.num_seen_examples, batch_c
        elif batch_c < 1:
            # the incoming batch is dissimilar enough from the buffer:
            # score the single candidate and try to evict a redundant sample
            single_c = self.get_grad_score(x.unsqueeze(0), y.unsqueeze(0), bigX, bigY, indices)
            s = self.scores.cpu().numpy()
            i = np.random.choice(np.arange(0, self.buffer_size), size=1, p=s / s.sum())[0]
            rand = np.random.rand(1)[0]
            if rand < s[i] / (s[i] + single_c):
                return i, single_c

        return -1, 0
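
    # Worked example: with a stored score s[i] = 3.0 and a candidate score
    # single_c = 1.0, slot i is evicted with probability 3.0 / (3.0 + 1.0) = 0.75;
    # samples whose gradients are most similar to the rest of the buffer (high
    # score) are the most likely to be replaced by more diverse candidates.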

    def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Initializes just the required tensors.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        """
        for attr_str in self.attributes:
            attr = eval(attr_str)  # looks up the argument with the same name
            if attr is not None and not hasattr(self, attr_str):
                typ = torch.int64 if attr_str.endswith('els') else torch.float32
                setattr(self, attr_str, torch.zeros((self.buffer_size,
                        *attr.shape[1:]), dtype=typ, device=self.device))
        # one relevance score per buffer slot
        self.scores = torch.zeros(self.buffer_size,
                                  dtype=torch.float32, device=self.device)

    def add_data(self, examples, labels=None):
        """
        Adds the data to the memory buffer according to the functional reservoir strategy.
        :param examples: tensor containing the images
        :param labels: tensor containing the labels
        """
        if not hasattr(self, 'examples'):
            self.init_tensors(examples, labels)

        # compute the buffer score of the incoming batch
        if self.num_seen_examples > 0:
            bigX, bigY, indices = self.get_data(min(self.minibatch_size, self.num_seen_examples),
                                                give_index=True, random=True)
            c = self.get_grad_score(examples, labels, bigX, bigY, indices)
        else:
            bigX, bigY, indices = None, None, None
            c = 0.1

        for i in range(examples.shape[0]):
            index, score = self.functional_reservoir(examples[i], labels[i], c, bigX, bigY, indices)
            self.num_seen_examples += 1
            if index >= 0:
                self.examples[index] = examples[i].to(self.device)
                if labels is not None:
                    self.labels[index] = labels[i].to(self.device)
                self.scores[index] = score
                # the cached gradient for this slot is now stale
                if index in self.cache:
                    del self.cache[index]

    def drop_cache(self):
        self.cache = {}

    def get_data(self, size: int, transform: transforms=None, give_index=False, random=False) -> Tuple:
        """
        Randomly samples a batch of size items.
        :param size: the number of requested items
        :param transform: the transformation to be applied (data augmentation)
        :param give_index: if True, also return the selected buffer indices
        :param random: if True, sample uniformly; otherwise walk the buffer sequentially
        :return: a tuple with the requested items (and optionally their indices)
        """
        if size > self.examples.shape[0]:
            size = self.examples.shape[0]

        if random:
            choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]),
                                      size=min(size, self.num_seen_examples),
                                      replace=False)
        else:
            # sequential retrieval: visit the buffer in the order given by fathom_mask
            choice = np.arange(self.fathom, min(self.fathom + size, self.examples.shape[0],
                                                self.num_seen_examples))
            choice = self.fathom_mask[choice]
            self.fathom += len(choice)
            if self.fathom >= self.examples.shape[0] or self.fathom >= self.num_seen_examples:
                self.fathom = 0
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                     for ee in self.examples[choice]]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice],)
        if give_index:
            ret_tuple += (choice,)

        return ret_tuple

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def get_all_data(self, transform: transforms=None) -> Tuple:
        """
        Returns all the items in the memory buffer.
        :param transform: the transformation to be applied (data augmentation)
        :return: a tuple with all the items in the memory buffer
        """
        if transform is None:
            transform = lambda x: x
        ret_tuple = (torch.stack([transform(ee.cpu())
                     for ee in self.examples]).to(self.device),)
        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr,)
        return ret_tuple

    def empty(self) -> None:
        """
        Removes all stored tensors and resets the example counter.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0
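

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the buffer itself).
# It assumes a model exposing `get_grads(inputs, labels)` that returns the
# flattened loss gradient as a (1, n_params) tensor, which is what
# `get_grad_score` expects; `ToyModel` below is a hypothetical stand-in.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(4, 2)
            self.loss_fn = nn.CrossEntropyLoss()

        def get_grads(self, inputs, labels):
            # flatten the gradients of the loss w.r.t. all parameters
            # into a single (1, n_params) row vector
            self.zero_grad()
            self.loss_fn(self.net(inputs), labels).backward()
            grads = torch.cat([p.grad.detach().flatten() for p in self.parameters()])
            self.zero_grad()
            return grads.unsqueeze(0)

    model = ToyModel()
    buffer = Buffer(buffer_size=10, device='cpu', minibatch_size=4, model=model)

    # feed a few small batches; once the buffer is full, candidates replace
    # stored samples according to the gradient-based scores
    for _ in range(5):
        x = torch.randn(8, 4)
        y = torch.randint(0, 2, (8,))
        buffer.add_data(examples=x, labels=y)

    if not buffer.is_empty():
        buf_x, buf_y = buffer.get_data(4, random=True)
        print(buf_x.shape, buf_y.shape)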