import tensorflow as tf
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import time

from kgcnn.literature.GNNExplain import GNNExplainer, GNNInterface
# from kgcnn.utils.adj import precompute_adjacency_scaled, convert_scaled_adjacency_to_list, add_self_loops_to_edge_indices
from kgcnn.literature.GCN import make_gcn
from kgcnn.utils.data import ragged_tensor_from_nested_numpy
from kgcnn.utils.learning import lr_lin_reduction

from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import cdist
from sklearn.cluster import AgglomerativeClustering

from kgcnn.data.mutagen.mutagenicity import mutagenicity_graph

labels, nodes, edge_indices, edges, atoms = mutagenicity_graph()
for i in range(len(labels)):
    # edge_indices[i], edges[i] = add_self_loops_to_edge_indices(edge_indices[i], np.expand_dims(edges[i], axis=-1))
    edges[i] = np.expand_dims(edges[i], axis=-1).astype(np.float32)  # Make edge feature dimension
for i in range(len(labels)):
    nodes[i] = np.array(
        np.expand_dims(nodes[i], axis=-1) == np.array([[1, 3, 6, 7, 8, 9, 11, 15, 16, 17, 19, 20, 35, 53]]),
        dtype=np.int64)  # Make one-hot encoding (np.int is deprecated and was removed in NumPy 1.24)
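# The integers above are atomic numbers; each node becomes a 14-dimensional
# one-hot vector over the element set {H, Li, C, N, O, F, Na, P, S, Cl, K, Ca, Br, I}
# occurring in the Mutagenicity dataset (the same order is used for
# element_labels in present_explanation below).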

# Train-test split
labels_train, labels_test, nodes_train, nodes_test, edges_train, edges_test, edge_indices_train, edge_indices_test = train_test_split(
    labels, nodes, edges, edge_indices, train_size=0.8, random_state=1)

# Convert to tf.RaggedTensor or tf.Tensor
# A copy of the data is generated by ragged_tensor_from_nested_numpy()
nodes_train, edges_train, edge_indices_train = ragged_tensor_from_nested_numpy(
    nodes_train), ragged_tensor_from_nested_numpy(edges_train), ragged_tensor_from_nested_numpy(
    edge_indices_train)

nodes_test, edges_test, edge_indices_test = ragged_tensor_from_nested_numpy(
    nodes_test), ragged_tensor_from_nested_numpy(edges_test), ragged_tensor_from_nested_numpy(
    edge_indices_test)

xtrain = nodes_train, edges_train, edge_indices_train
xtest = nodes_test, edges_test, edge_indices_test
ytrain = np.expand_dims(labels_train, axis=-1)
ytest = np.expand_dims(labels_test, axis=-1)

model = make_gcn(
    input_node_shape=[None, 14],
    input_edge_shape=[None, 1],
    input_embedd={'input_node_vocab': 55, "input_node_embedd": 64},
    # Output
    output_embedd={"output_mode": 'graph', "output_type": 'padded'},
    output_mlp={"use_bias": [True, True, False], "units": [140, 70, 1], "activation": ['relu', 'relu', 'sigmoid']},
    # Model specs
    depth=3,
    gcn_args={"units": 64, "use_bias": True, "activation": "relu", "has_unconnected": True, "is_sorted": False,
              "pooling_method": 'segment_mean'}
)
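# The model stacks depth=3 GCN layers (64 units, mean pooling over neighbors),
# pools node states into a single graph state, and maps it through a
# [140, 70, 1] MLP with a sigmoid output for binary classification
# (mutagenic vs. non-mutagenic).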

# Set learning rate and epochs
learning_rate_start = 1e-3
learning_rate_stop = 1e-4
epo = 150
epomin = 100
epostep = 10

# Compile model with optimizer and loss
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_start)  # 'lr' is deprecated in tf.keras
cbks = tf.keras.callbacks.LearningRateScheduler(lr_lin_reduction(learning_rate_start, learning_rate_stop, epomin, epo))
model.compile(loss='binary_crossentropy',
              optimizer=optimizer,
              weighted_metrics=['accuracy'])
print(model.summary())

# Start and time training
start = time.process_time()
hist = model.fit(xtrain, ytrain,
                 epochs=epo,
                 batch_size=32,
                 callbacks=[cbks],
                 validation_freq=epostep,
                 validation_data=(xtest, ytest),
                 verbose=2
                 )
stop = time.process_time()
print("Time for training: ", stop - start)

# Get accuracy from the training history
trainaccall = np.array(hist.history['accuracy'])
testaccall = np.array(hist.history['val_accuracy'])
acc_valid = testaccall[-1]

# Plot accuracy vs. epochs
plt.figure()
plt.plot(np.arange(trainaccall.shape[0]), trainaccall, label='Training ACC', c='blue')
plt.plot(np.arange(epostep, epo + epostep, epostep), testaccall, label='Test ACC', c='red')
plt.scatter([trainaccall.shape[0]], [acc_valid], label="{0:0.4f} ".format(acc_valid), c='red')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('GCN Mutagenicity Accuracy')
plt.legend(loc='upper right', fontsize='x-large')
plt.savefig('gcn_explain_mutag_3.png')
plt.show()


# We need to implement an explainable GNN by deriving from GNNInterface
class ExplainableGCN(GNNInterface):

    def __init__(self, gnn_model, **kwargs):
        super(ExplainableGCN, self).__init__()
        self.gnn_model = gnn_model

    def predict(self, gnn_input, masking_info=None):
        return self.gnn_model(gnn_input, training=False)[0]

    def masked_predict(self, gnn_input, edge_mask, feature_mask, node_mask, training=False):
        node_input, edge_input, edge_index_input = gnn_input

        masked_edge_input = tf.ragged.map_flat_values(tf.math.multiply, edge_input, edge_mask)
        masked_feature_input = tf.ragged.map_flat_values(tf.math.multiply, tf.dtypes.cast(node_input, tf.float32),
                                                         tf.transpose(feature_mask))
        masked_node_feature_input = tf.ragged.map_flat_values(tf.math.multiply, masked_feature_input, node_mask)
        masked_pred = \
            self.gnn_model([masked_node_feature_input, masked_edge_input, edge_index_input], training=training)[0]

        return masked_pred
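    # masked_predict applies the three GNNExplainer masks by element-wise
    # multiplication on the flat values of the ragged inputs: the edge mask
    # scales edge features, the feature mask scales whole feature columns,
    # and the node mask scales each node's feature vector.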

    def get_number_of_nodes(self, gnn_input):
        node_input, _, _ = gnn_input
        return node_input[0].shape[0]

    def get_number_of_node_features(self, gnn_input):
        node_input, _, _ = gnn_input
        return node_input.shape[2]

    def get_number_of_edges(self, gnn_input):
        _, edge_input, _ = gnn_input
        return edge_input[0].shape[0]

    def get_explanation(self, gnn_input, edge_mask, feature_mask, node_mask):
        edge_relevance = np.array(edge_mask[:, 0])
        node_relevance = np.array(node_mask[:, 0])
        feature_relevance = np.array(feature_mask[:, 0])
        features = np.array(gnn_input[0][0])
        edges = np.array(gnn_input[2][0])
        graph = nx.Graph()
        for i, f in enumerate(features):
            graph.add_node(i, features=f, relevance=node_relevance[i])
        for i, e in enumerate(edges):
            if edge_relevance is None:
                graph.add_edge(e[0], e[1])
            else:
                graph.add_edge(e[0], e[1], relevance=edge_relevance[i])
        return graph, feature_relevance

    def present_explanation(self, explanation, threshold=0.5):
        graph = explanation[0]
        # Atomic numbers: [1, 3, 6, 7, 8, 9, 11, 15, 16, 17, 19, 20, 35, 53]
        element_labels = ['H', 'Li', 'C', 'N', 'O', 'F', 'Na', 'P', 'S', 'Cl', 'K', 'Ca', 'Br', 'I']
        important_edges = []
        color_map = []
        node_color_map = []
        node_labels = {}
        for (u, v, relevance) in graph.edges.data('relevance'):
            relevance = min(relevance + 0.1, 1.0)
            color_map.append((0, 0, 0, relevance))
        for n, f in graph.nodes.data('features'):
            element = np.argmax(f)
            r, g, b, a = plt.get_cmap('tab20')(element)
            node_color_map.append((r, g, b, graph.nodes[n]['relevance']))
            node_labels[n] = element_labels[element]
        if np.all(explanation[1] == 1):
            nx.draw_kamada_kawai(graph, edge_color=color_map, labels=node_labels, node_color=node_color_map)
        else:
            f, axs = plt.subplots(2, figsize=(8, 12))
            nx.draw_kamada_kawai(graph, ax=axs[0], edge_color=color_map, labels=node_labels, node_color=node_color_map)
            bar_colors = [plt.get_cmap('tab20')(element) for element in np.arange(14)]
            axs[1].bar(np.array(element_labels), explanation[1], color=bar_colors)


# Instantiate an explainable GNN:
explainable_gcn = ExplainableGCN(model)
compile_options = {'loss': 'binary_crossentropy', 'optimizer': tf.keras.optimizers.Adam(learning_rate=0.2)}
fit_options = {'epochs': 100, 'batch_size': 1, 'verbose': 0}
gnnexplaineroptimizer_options = {'edge_mask_loss_weight': 0.001,
                                 'edge_mask_norm_ord': 1,
                                 'feature_mask_loss_weight': 0,
                                 'feature_mask_norm_ord': 1,
                                 'node_mask_loss_weight': 0,
                                 'node_mask_norm_ord': 1}
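# With feature_mask_loss_weight and node_mask_loss_weight set to 0, only the
# edge mask is sparsified (L1 norm, weight 0.001); the node and feature masks
# are left unregularized, so the explanation is driven by edge relevance.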

explainer = GNNExplainer(explainable_gcn,
                         compile_options=compile_options,
                         fit_options=fit_options,
                         gnnexplaineroptimizer_options=gnnexplaineroptimizer_options)

inspection_result = explainer.explain([tensor[776:777] for tensor in xtest], inspection=True)
# inspection_result = explainer.explain([tensor[776:777] for tensor in xtest], output_to_explain=tf.Variable([0.]), inspection=True)

explainer.present_explanation(explainer.get_explanation(), threshold=0.5)

# Plot prediction
plt.figure()
plt.plot(inspection_result['predictions'])
plt.xlabel('Iterations')
plt.ylabel('GNN output')
plt.show()

# Plot loss
plt.figure()
plt.plot(inspection_result['total_loss'])
plt.xlabel('Iterations')
plt.ylabel('Total Loss')
plt.show()

# Plot edge mask loss
plt.figure()
plt.plot(inspection_result['edge_mask_loss'])
plt.xlabel('Iterations')
plt.ylabel('Edge Mask Loss')
plt.show()

# Sample 200 mutagenic molecules:
pred = model.predict(xtest)[:, 0]
sampled_mutagenic_molecules = np.random.choice(np.argwhere(pred < 0.5)[:, 0], 200)
print(sampled_mutagenic_molecules)
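# Note: np.random.choice samples with replacement by default, so the 200
# indices may contain duplicates; pred < 0.5 selects molecules the model
# assigns to the mutagenic class under the label convention used here.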

# Generate explanations for these 200 sampled molecules (this will take a while):
explanations = []
for i, mol_index in enumerate(sampled_mutagenic_molecules):
    explainer.explain([tensor[mol_index:mol_index + 1] for tensor in xtest])
    print(i, end=',')
    explanations.append(explainer.get_explanation())


# Transform the explanation graphs to vectors, in order to apply a clustering algorithm to the explanation vectors:
def explanation_to_vector(explanation):
    graph = explanation[0]
    bond_matrix = np.zeros((14, 14))
    for (u, v, relevance) in graph.edges.data('relevance'):
        atom1 = np.argwhere(graph.nodes[u]['features'] == 1)[0]
        atom2 = np.argwhere(graph.nodes[v]['features'] == 1)[0]
        bond_matrix[atom1, atom2] += relevance
        bond_matrix[atom2, atom1] += relevance
    bond_vector = bond_matrix[np.triu_indices(bond_matrix.shape[0])]
    bond_vector = bond_vector / np.sum(bond_vector)
    return bond_vector
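# Each explanation becomes a normalized vector of accumulated edge relevance
# per element pair; with 14 elements, the upper triangle (including the
# diagonal) has 14 * 15 / 2 = 105 entries. An illustrative sanity check:
assert explanation_to_vector(explanations[0]).shape[0] == 105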
explanation_vectors = [explanation_to_vector(expl) for expl in explanations]


# A dendrogram of the explanation vectors:
plt.figure()
linked = linkage(explanation_vectors, 'complete', metric='cityblock')
dendrogram(linked,
           orientation='top',
           distance_sort='descending',
           show_leaf_counts=True)
plt.show()
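# Complete linkage with the cityblock (Manhattan) metric matches the
# AgglomerativeClustering settings below, so the dendrogram can be used to
# choose num_clusters before the flat clustering is computed.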

num_clusters = 7
# Print one representative graph explanation for each cluster:
# (in scikit-learn >= 1.2 the 'affinity' parameter is called 'metric')
db = AgglomerativeClustering(n_clusters=num_clusters, affinity='manhattan', linkage='complete').fit(explanation_vectors)
vector_clusters = []
explanation_clusters = []
for cluster_ind in range(num_clusters):
    plt.figure()
    vector_cluster = np.array([explanation_vectors[i] for i in np.argwhere(db.labels_ == cluster_ind)[:, 0]])
    vector_clusters.append(vector_cluster)
    explanation_cluster = [explanations[i] for i in np.argwhere(db.labels_ == cluster_ind)[:, 0]]
    explanation_clusters.append(explanation_cluster)
    cluster_mean = np.mean(vector_cluster, axis=0)
    dist = cdist(np.array([cluster_mean]), vector_cluster)[0]
    print(vector_cluster.shape)
    ax = plt.subplot()
    explainer.present_explanation(explanation_cluster[np.argmin(dist)])
    plt.show()