Skip to content

Commit

Permalink
added examples
Browse files Browse the repository at this point in the history
  • Loading branch information
aimat committed Apr 10, 2021
1 parent 2a2c102 commit c6d1562
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 8 deletions.
6 changes: 4 additions & 2 deletions examples/explain_GNNExplain_mutagenicity_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
labels, nodes, edge_indices, edges, atoms = mutagenicity_graph()
for i in range(len(labels)):
# edge_indices[i], edges[i] = add_self_loops_to_edge_indices(edge_indices[i], np.expand_dims(edges[i],axis=-1))
edges[i] = np.expand_dims(edges[i], axis=-1).astype(np.float32)# Make edge feature dimension
edges[i] = np.expand_dims(edges[i], axis=-1).astype(np.float32) # Make edge feature dimension
for i in range(len(labels)):
nodes[i] = np.array(
np.expand_dims(nodes[i],axis=-1) == np.array([[ 1, 3, 6, 7, 8, 9, 11, 15, 16, 17, 19, 20, 35, 53]])
Expand Down Expand Up @@ -210,16 +210,18 @@ def present_explanation(self, explanation, threshold=0.5):
# Plot the GNN output over the optimization iterations.
plt.plot(inspection_result['predictions'])
plt.xlabel('Iterations')
plt.ylabel('GNN output')
plt.show()

# Plot the total GNNExplainer loss over the optimization iterations.
plt.figure()
plt.plot(inspection_result['total_loss'])
plt.xlabel('Iterations')
plt.ylabel('Total Loss')
plt.show()

# Plot the edge mask loss over the optimization iterations.
plt.figure()
plt.plot(inspection_result['edge_mask_loss'])
plt.xlabel('Iterations')
# This figure shows the edge mask loss, so label it accordingly
# (previously mislabeled 'Node Mask Loss').
plt.ylabel('Edge Mask Loss')

plt.show()
6 changes: 4 additions & 2 deletions examples/explain_GNNExplain_mutagenicity_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def present_explanation(self, explanation, threshold=0.5):
#explainer.explain([tensor[instance_index:instance_index+1] for tensor in xtest], output_to_explain=tf.Variable([0.]))

# Let's look at the explanation the GNNExplainer found:
plt.figure()
explainer.present_explanation(explainer.get_explanation(), threshold=0.5)
plt.show()

Expand All @@ -235,6 +236,7 @@ def present_explanation(self, explanation, threshold=0.5):
# We can also specify output_to_explain to be tf.Variable([0.]).
# This way we can tell the Explainer to explain why a molecule could be mutagenic
# (even for molecules which are classified as most likely non-mutagenic by the GNN):

plt.figure()
explainer.explain([tensor[instance_index:instance_index+1] for tensor in xtest], output_to_explain=tf.Variable([0.]))
explainer.present_explanation(explainer.get_explanation(), threshold=0.5)
explainer.present_explanation(explainer.get_explanation(), threshold=0.5)
plt.show()
277 changes: 277 additions & 0 deletions examples/explain_GNNExplain_mutagenicity_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import time

from kgcnn.literature.GNNExplain import GNNExplainer, GNNInterface
# from kgcnn.utils.adj import precompute_adjacency_scaled, convert_scaled_adjacency_to_list, add_self_loops_to_edge_indices
from kgcnn.literature.GCN import make_gcn
from kgcnn.utils.data import ragged_tensor_from_nested_numpy
from kgcnn.utils.learning import lr_lin_reduction

from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import cdist
from sklearn.cluster import AgglomerativeClustering

from kgcnn.data.mutagen.mutagenicity import mutagenicity_graph

# Load the Mutagenicity dataset as individual graphs.
labels, nodes, edge_indices, edges, atoms = mutagenicity_graph()

# Atomic numbers of the 14 elements occurring in the dataset, used for one-hot encoding.
element_numbers = np.array([[1, 3, 6, 7, 8, 9, 11, 15, 16, 17, 19, 20, 35, 53]])
for i in range(len(labels)):
    # edge_indices[i], edges[i] = add_self_loops_to_edge_indices(edge_indices[i], np.expand_dims(edges[i], axis=-1))
    # Add a trailing feature axis so edges have shape (num_edges, 1).
    edges[i] = np.expand_dims(edges[i], axis=-1).astype(np.float32)  # Make edge feature dimension
    # One-hot encode the atomic numbers. Use the builtin `int` instead of the
    # deprecated alias `np.int` (removed in NumPy >= 1.24).
    nodes[i] = np.array(
        np.expand_dims(nodes[i], axis=-1) == element_numbers,
        dtype=int)  # Make One-Hot encoding

# Train/test split of graphs (80% train), fixed seed for reproducibility.
# Note the argument order: labels, nodes, EDGES, EDGE_INDICES.
labels_train, labels_test, nodes_train, nodes_test, edges_train, edges_test, edge_indices_train, edge_indices_test = train_test_split(
    labels, nodes, edges, edge_indices, train_size=0.8, random_state=1)

# Convert the nested numpy lists to tf.RaggedTensor
# (a copy of the data is generated by ragged_tensor_from_nested_numpy()).
nodes_train, edges_train, edge_indices_train = ragged_tensor_from_nested_numpy(
    nodes_train), ragged_tensor_from_nested_numpy(edges_train), ragged_tensor_from_nested_numpy(
    edge_indices_train)

nodes_test, edges_test, edge_indices_test = ragged_tensor_from_nested_numpy(
    nodes_test), ragged_tensor_from_nested_numpy(edges_test), ragged_tensor_from_nested_numpy(
    edge_indices_test)

# Model inputs are (node features, edge features, edge indices) tuples;
# targets get a trailing axis to match the model's (batch, 1) sigmoid output.
xtrain = nodes_train, edges_train, edge_indices_train
xtest = nodes_test, edges_test, edge_indices_test
ytrain = np.expand_dims(labels_train, axis=-1)
ytest = np.expand_dims(labels_test, axis=-1)

# Build a 3-layer GCN for binary graph classification (sigmoid output).
# NOTE(review): input_node_shape is [None, 14] (the one-hot element encoding),
# so 'input_node_vocab': 55 presumably only applies to an embedding path that
# is unused here — confirm against make_gcn.
model = make_gcn(
    input_node_shape=[None, 14],
    input_edge_shape=[None, 1],
    input_embedd={'input_node_vocab': 55, "input_node_embedd": 64},
    # Output: one value per graph, padded batch mode.
    output_embedd={"output_mode": 'graph', "output_type": 'padded'},
    output_mlp={"use_bias": [True, True, False], "units": [140, 70, 1], "activation": ['relu', 'relu', 'sigmoid']},
    # model specs
    depth=3,
    gcn_args={"units": 64, "use_bias": True, "activation": "relu", "has_unconnected": True, "is_sorted": False, "pooling_method": 'segment_mean'}
)

# Set learning rate schedule and epochs.
learning_rate_start = 1e-3
learning_rate_stop = 1e-4
epo = 150      # total training epochs
epomin = 100   # passed to lr_lin_reduction; presumably where the linear decay starts — confirm
epostep = 10   # run validation every `epostep` epochs

# Compile model with optimizer and loss.
# Use `learning_rate` instead of the deprecated `lr` keyword of tf.keras optimizers.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_start)
cbks = tf.keras.callbacks.LearningRateScheduler(lr_lin_reduction(learning_rate_start, learning_rate_stop, epomin, epo))
model.compile(loss='binary_crossentropy',
              optimizer=optimizer,
              weighted_metrics=['accuracy'])
print(model.summary())

# Start and time training.
start = time.process_time()
hist = model.fit(xtrain, ytrain,
                 epochs=epo,
                 batch_size=32,
                 callbacks=[cbks],
                 validation_freq=epostep,
                 validation_data=(xtest, ytest),
                 verbose=2
                 )
stop = time.process_time()
# Fixed typo in the printed message ("taining" -> "training").
print("Print Time for training: ", stop - start)

# Get accuracy curves from the training history.
trainlossall = np.array(hist.history['accuracy'])
testlossall = np.array(hist.history['val_accuracy'])
acc_valid = testlossall[-1]

# Plot accuracy vs epochs. Validation only ran every `epostep` epochs,
# so the test curve has fewer points than the train curve.
plt.figure()
plt.plot(np.arange(trainlossall.shape[0]), trainlossall, label='Training ACC', c='blue')
plt.plot(np.arange(epostep, epo + epostep, epostep), testlossall, label='Test ACC', c='red')
plt.scatter([trainlossall.shape[0]], [acc_valid], label="{0:0.4f} ".format(acc_valid), c='red')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
# These curves show accuracy, not loss, so title the figure accordingly
# (previously mislabeled 'Interaction Network Loss').
plt.title('Interaction Network Accuracy')
plt.legend(loc='upper right', fontsize='x-large')
plt.savefig('gcn_explain_mutag_3.png')
plt.show()

# We need to implement the ExplainableGNN from GNNInterface.

class ExplainableGCN(GNNInterface):
    """Adapter that exposes the trained GCN to the GNNExplainer.

    Implements the GNNInterface hooks: plain and masked forward passes,
    size queries, and conversion of the learned masks into a networkx
    explanation graph plus a feature-relevance vector.
    """

    def __init__(self, gnn_model, **kwargs):
        # Wrap an already-trained Keras GNN model.
        super(ExplainableGCN, self).__init__()
        self.gnn_model = gnn_model

    def predict(self, gnn_input, masking_info=None):
        # Unmasked forward pass; [0] selects the single graph in the batch.
        return self.gnn_model(gnn_input, training=False)[0]

    def masked_predict(self, gnn_input, edge_mask, feature_mask, node_mask, training=False):
        """Forward pass with edge, feature and node masks applied multiplicatively."""
        node_input, edge_input, edge_index_input = gnn_input

        # Scale each edge-feature row by its edge-mask value.
        masked_edge_input = tf.ragged.map_flat_values(tf.math.multiply, edge_input, edge_mask)
        # Scale feature columns by the (transposed, broadcastable) feature mask;
        # node features are cast to float32 since they are stored as integer one-hots.
        masked_feature_input = tf.ragged.map_flat_values(tf.math.multiply, tf.dtypes.cast(node_input, tf.float32),
                                                         tf.transpose(feature_mask))
        # Scale each node row by its node-mask value.
        masked_node_feature_input = tf.ragged.map_flat_values(tf.math.multiply, masked_feature_input, node_mask)
        masked_pred = \
            self.gnn_model([masked_node_feature_input, masked_edge_input, edge_index_input], training=training)[0]

        return masked_pred

    def get_number_of_nodes(self, gnn_input):
        # Node count of the first (and only) graph in the batch.
        node_input, _, _ = gnn_input
        return node_input[0].shape[0]

    def get_number_of_node_features(self, gnn_input):
        # Size of the node feature dimension (14 for the one-hot elements here).
        node_input, _, _ = gnn_input
        return node_input.shape[2]

    def get_number_of_edges(self, gnn_input):
        # Edge count of the first (and only) graph in the batch.
        _, edge_input, _ = gnn_input
        return edge_input[0].shape[0]

    def get_explanation(self, gnn_input, edge_mask, feature_mask, node_mask):
        """Build a networkx graph annotated with the learned mask relevances.

        Returns a tuple (graph, feature_relevance): nodes carry 'features'
        and 'relevance' attributes, edges carry 'relevance'.
        """
        edge_relevance = np.array(edge_mask[:, 0])
        node_relevance = np.array(node_mask[:, 0])
        feature_relevance = np.array(feature_mask[:, 0])
        features = np.array(gnn_input[0][0])
        edges = np.array(gnn_input[2][0])
        graph = nx.Graph()
        for i, f in enumerate(features):
            graph.add_node(i, features=f, relevance=node_relevance[i])
        for i, e in enumerate(edges):
            # NOTE(review): `edge_relevance` is never None here (np.array above),
            # so this branch is dead; the check was probably meant for `edge_mask`.
            if edge_relevance is None:
                graph.add_edge(e[0], e[1])
            else:
                graph.add_edge(e[0], e[1], relevance=edge_relevance[i])
        return graph, feature_relevance

    def present_explanation(self, explanation, threshold=0.5):
        """Draw the explanation graph; add a feature-relevance bar chart when
        the feature mask is not trivially all ones.

        NOTE(review): `threshold` and `important_edges` are currently unused.
        """
        graph = explanation[0]
        # element_labels = np.array([[ 1, 3, 6, 7, 8, 9, 11, 15, 16, 17, 19, 20, 35, 53]])
        element_labels = ['H', 'Li', 'C', 'N', 'O', 'F', 'Na', 'P', 'S', 'Cl', 'K', 'Ca', 'Br', 'I']
        important_edges = []
        color_map = []
        node_color_map = []
        node_labels = {}
        for (u, v, relevance) in graph.edges.data('relevance'):
            # Edge opacity encodes relevance; +0.1 keeps low-relevance edges faintly visible.
            relevance = min(relevance + 0.1, 1.0)
            color_map.append((0, 0, 0, relevance))
        for n, f in graph.nodes.data('features'):
            # Recover the element index from the one-hot feature vector.
            element = np.argmax(f)
            r, g, b, a = plt.get_cmap('tab20')(element)
            # Node opacity encodes node relevance; color encodes element.
            node_color_map.append((r, g, b, graph.nodes[n]['relevance']))
            node_labels[n] = (element_labels[element])
        if np.all(explanation[1] == 1):
            nx.draw_kamada_kawai(graph, edge_color=color_map, labels=node_labels, node_color=node_color_map)
        else:
            f, axs = plt.subplots(2, figsize=(8, 12))
            nx.draw_kamada_kawai(graph, ax=axs[0], edge_color=color_map, labels=node_labels, node_color=node_color_map)
            bar_colors = [plt.get_cmap('tab20')(element) for element in np.arange(14)]
            axs[1].bar(np.array(element_labels), explanation[1], color=bar_colors)


# Instantiate an explainable GNN:
explainable_gcn = ExplainableGCN(model)
# Use `learning_rate` instead of the deprecated `lr` keyword of tf.keras optimizers.
compile_options = {'loss': 'binary_crossentropy', 'optimizer': tf.keras.optimizers.Adam(learning_rate=0.2)}
fit_options = {'epochs': 100, 'batch_size': 1, 'verbose': 0}
# Only the edge mask is regularized (feature/node mask loss weights are 0).
gnnexplaineroptimizer_options = {'edge_mask_loss_weight': 0.001,
                                 'edge_mask_norm_ord': 1,
                                 'feature_mask_loss_weight': 0,
                                 'feature_mask_norm_ord': 1,
                                 'node_mask_loss_weight': 0,
                                 'node_mask_norm_ord': 1}

explainer = GNNExplainer(explainable_gcn,
                         compile_options=compile_options,
                         fit_options=fit_options,
                         gnnexplaineroptimizer_options=gnnexplaineroptimizer_options)

# Explain the prediction for test molecule 776; inspection=True records loss curves.
inspection_result = explainer.explain([tensor[776:777] for tensor in xtest], inspection=True)
# inspection_result = explainer.explain([tensor[776:777] for tensor in xtest], output_to_explain=tf.Variable([0.]), inspection=True)

explainer.present_explanation(explainer.get_explanation(), threshold=0.5)

# Plot the GNN prediction over the optimization iterations.
plt.figure()
plt.plot(inspection_result['predictions'])
plt.xlabel('Iterations')
plt.ylabel('GNN output')
plt.show()

# Plot the total GNNExplainer loss.
plt.figure()
plt.plot(inspection_result['total_loss'])
plt.xlabel('Iterations')
plt.ylabel('Total Loss')
plt.show()

# Plot the edge mask loss.
plt.figure()
plt.plot(inspection_result['edge_mask_loss'])
plt.xlabel('Iterations')
# This figure shows the edge mask loss, so label it accordingly
# (previously mislabeled 'Node Mask Loss').
plt.ylabel('Edge Mask Loss')
plt.show()

# Sample 200 molecules predicted in the < 0.5 output class:
# NOTE(review): the comment assumes output < 0.5 means "mutagenic" — confirm
# the label convention of the dataset.
pred = model.predict(xtest)[:,0]
# NOTE(review): np.random.choice samples WITH replacement and is unseeded,
# so duplicates are possible and runs are not reproducible.
sampled_mutagenic_molecules = np.random.choice(np.argwhere(pred < 0.5)[:,0], 200)
print(sampled_mutagenic_molecules)

# Generate explanations for all those 200 molecules (this will take a while):
explanations = []
for i,mol_index in enumerate(sampled_mutagenic_molecules):
    explainer.explain([tensor[mol_index:mol_index+1] for tensor in xtest])
    print(i, end=',')
    explanations.append(explainer.get_explanation())

# We transform the explanation graphs to vectors, in order to apply a cluster
# algorithm on the explanation vectors:
def explanation_to_vector(explanation):
    """Aggregate the edge relevances of an explanation graph into a flat vector.

    Builds a symmetric 14x14 element-pair matrix where entry (a, b) sums the
    relevance of all bonds between elements a and b, then returns its upper
    triangle normalized to sum to 1. If the total relevance is 0 the
    unnormalized (all-zero) vector is returned instead of dividing by zero,
    which previously produced a NaN vector.
    """
    graph = explanation[0]
    bond_matrix = np.zeros((14, 14))
    for (u, v, relevance) in graph.edges.data('relevance'):
        # Map each endpoint back to its element index via the one-hot features.
        atom1 = np.argwhere(graph.nodes[u]['features'] == 1)[0]
        atom2 = np.argwhere(graph.nodes[v]['features'] == 1)[0]
        # Add relevance symmetrically so the matrix is independent of edge order.
        bond_matrix[atom1, atom2] += relevance
        bond_matrix[atom2, atom1] += relevance
    bond_vector = bond_matrix[np.triu_indices(bond_matrix.shape[0])]
    total = np.sum(bond_vector)
    if total > 0:
        # Normalize so explanations of differently sized molecules are comparable.
        bond_vector = bond_vector / total
    return bond_vector

explanation_vectors = [explanation_to_vector(expl) for expl in explanations]


# A dendrogram of the explanation vectors (complete linkage, L1/cityblock distance):
plt.figure()
linked = linkage(explanation_vectors, 'complete', metric='cityblock')
dendrogram(linked,
           orientation='top',
           distance_sort='descending',
           show_leaf_counts=True)
plt.show()


num_clusters = 7
# Print one representative graph explanation for each cluster:
# NOTE(review): `affinity` is deprecated in newer scikit-learn in favor of
# `metric` — update the keyword when bumping the sklearn dependency.
db = AgglomerativeClustering(n_clusters=num_clusters, affinity='manhattan', linkage='complete').fit(explanation_vectors)
vector_clusters = []
explanation_clusters = []
for cluster_ind in range(num_clusters):
    plt.figure()
    # Collect the vectors and explanations assigned to this cluster.
    vector_cluster = np.array([explanation_vectors[i] for i in np.argwhere(db.labels_ == cluster_ind)[:,0]])
    vector_clusters.append(vector_cluster)
    explanation_cluster = [explanations[i] for i in np.argwhere(db.labels_ == cluster_ind)[:,0]]
    explanation_clusters.append(explanation_cluster)
    # The representative is the explanation closest to the cluster mean
    # (cdist defaults to Euclidean distance here).
    cluster_mean = np.mean(vector_cluster, axis=0)
    dist = cdist(np.array([cluster_mean]), vector_cluster)[0]
    print(vector_cluster.shape)
    # NOTE(review): `ax` is unused, but plt.subplot() creates the axes the
    # explanation is drawn into — removing it could change where drawing happens.
    ax = plt.subplot()
    explainer.present_explanation(explanation_cluster[np.argmin(dist)])
    plt.show()
8 changes: 4 additions & 4 deletions kgcnn/data/mutagen/mutagenicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ def mutagenicity_load(path):
test_cons = np.sort(np.unique(edge_indices[i].flatten()))
is_cons = np.zeros_like(cons, dtype=np.bool)
is_cons[test_cons] = True
is_cons[nats == 20] = True
is_cons[nats == 3] = True
is_cons[nats == 19] = True
is_cons[nats == 11] = True
is_cons[nats == 20] = True # Allow to be unconnected
is_cons[nats == 3] = True # Allow to be unconnected
is_cons[nats == 19] = True # Allow to be unconnected
is_cons[nats == 11] = True # Allow to be unconnected
if np.sum(is_cons) != len(cons):
info_list = nodes[i][is_cons == False]
info_list, info_cnt = np.unique(info_list, return_counts=True)
Expand Down

0 comments on commit c6d1562

Please sign in to comment.