-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathutils.py
More file actions
85 lines (72 loc) · 2.42 KB
/
utils.py
File metadata and controls
85 lines (72 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import logging
import sys
from sklearn.manifold import TSNE
from sklearn.manifold import MDS
import seaborn as sns
import matplotlib.pyplot as plt
from models.spike import LIFSpike, AvgMeter
import torch.nn as nn
def _logger(logger_name, level=logging.DEBUG):
"""
Method to return a custom logger with the given name and level
"""
logger = logging.getLogger(logger_name)
logger.setLevel(level)
format_string = "%(message)s"
log_format = logging.Formatter(format_string)
# Creating and adding the console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(log_format)
logger.addHandler(console_handler)
# Creating and adding the file handler
file_handler = logging.FileHandler(logger_name, mode='a')
file_handler.setFormatter(log_format)
logger.addHandler(file_handler)
return logger
def tsne(latent, y_ground_truth, save_dir, perplexity=40, n_iter=300):
    """
    Plot a 2-D t-SNE embedding of the features, coloured by label, and save it.

    Args:
        latent: Torch tensor of features. It is detached, moved to CPU and
            converted to NumPy before fitting. Presumably shaped
            ``(n_samples, n_features)`` — TODO confirm with callers.
        y_ground_truth: Per-sample labels used as the scatter hue. Each distinct
            label gets one colour. A torch tensor is converted to NumPy first.
        save_dir: Destination passed to ``Figure.savefig`` (a file path,
            despite the name).
        perplexity: t-SNE perplexity (default 40, the previous hard-coded value).
        n_iter: Maximum optimisation iterations (default 300, as before).
    """
    import inspect

    features = latent.detach().cpu().numpy()
    if hasattr(y_ground_truth, "detach"):
        # Torch tensors hash by identity, so set() over one would count every
        # element as a distinct label; convert to NumPy values first.
        y_ground_truth = y_ground_truth.detach().cpu().numpy()

    # scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` and later removed the
    # old keyword; use whichever one the installed version accepts.
    tsne_params = inspect.signature(TSNE.__init__).parameters
    iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"
    embedder = TSNE(
        n_components=2, verbose=1, perplexity=perplexity, **{iter_kwarg: n_iter}
    )
    tsne_results = embedder.fit_transform(features)

    fig = plt.figure(figsize=(16, 10))
    num_labels = len(set(y_ground_truth))
    try:
        sns_plot = sns.scatterplot(
            x=tsne_results[:, 0],
            y=tsne_results[:, 1],
            hue=y_ground_truth,
            palette=sns.color_palette("hls", num_labels),
            legend="full",
            alpha=0.5,
            ax=fig.gca(),
        )
        sns_plot.get_figure().savefig(save_dir)
    finally:
        # pyplot keeps every figure alive until closed; without this each call
        # leaks a 16x10 figure.
        plt.close(fig)
def mds(latent, y_ground_truth, save_dir, random_state=None):
    """
    Plot a 2-D MDS embedding of the features, coloured by label, and save it.

    Args:
        latent: Torch tensor of features. It is detached, moved to CPU and
            converted to NumPy before fitting. Presumably shaped
            ``(n_samples, n_features)`` — TODO confirm with callers.
        y_ground_truth: Per-sample labels used as the scatter hue. Each distinct
            label gets one colour. A torch tensor is converted to NumPy first.
        save_dir: Destination passed to ``Figure.savefig`` (a file path,
            despite the name).
        random_state: Forwarded to ``MDS``. The default ``None`` matches the
            previous behaviour; pass an int for reproducible plots.
    """
    features = latent.detach().cpu().numpy()
    if hasattr(y_ground_truth, "detach"):
        # Torch tensors hash by identity, so set() over one would count every
        # element as a distinct label; convert to NumPy values first.
        y_ground_truth = y_ground_truth.detach().cpu().numpy()

    embedder = MDS(n_components=2, random_state=random_state)
    mds_results = embedder.fit_transform(features)

    fig = plt.figure(figsize=(16, 10))
    num_labels = len(set(y_ground_truth))
    try:
        sns_plot = sns.scatterplot(
            x=mds_results[:, 0],
            y=mds_results[:, 1],
            hue=y_ground_truth,
            palette=sns.color_palette("hls", num_labels),
            legend="full",
            alpha=0.5,
            ax=fig.gca(),
        )
        sns_plot.get_figure().savefig(save_dir)
    finally:
        # pyplot keeps every figure alive until closed; without this each call
        # leaks a 16x10 figure.
        plt.close(fig)
def hook_layers(model: nn.Module):
    """
    Attach forward hooks that tally non-zero outputs of ReLU / LIFSpike layers.

    Each hooked layer, on every forward pass, reports two values to a shared
    ``AvgMeter``: the number of non-zero elements in its output and the
    output's total element count. This is presumably used as an
    activation/spike rate; see ``AvgMeter`` for how the values are combined.

    The hook handles are not kept, so the hooks stay registered for the
    lifetime of ``model``.

    Args:
        model: Module whose submodules (including itself) are scanned.

    Returns:
        The ``AvgMeter`` that the hooks feed.
    """
    meter = AvgMeter()
    tracked_types = (nn.ReLU, LIFSpike)

    def _record_activity(module, args, output):
        meter.add(output.count_nonzero(), output.numel())

    matching_layers = (
        layer for layer in model.modules() if isinstance(layer, tracked_types)
    )
    for layer in matching_layers:
        layer.register_forward_hook(_record_activity)
    return meter