Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
benedekrozemberczki authored Dec 9, 2019
1 parent f87c28b commit 39ddb53
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 37 deletions.
35 changes: 26 additions & 9 deletions src/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""NGCN and DenseNGCN layers."""

import math
import torch
from torch_sparse import spmm
Expand Down Expand Up @@ -50,11 +52,21 @@ def forward(self, normalized_adjacency_matrix, features):
:param features: Feature matrix.
:return base_features: Convolved features.
"""
base_features = spmm(features["indices"], features["values"], features["dimensions"][0], self.weight_matrix) + self.bias
base_features = torch.nn.functional.dropout(base_features, p = self.dropout_rate, training = self.training)
base_features = spmm(features["indices"], features["values"],
features["dimensions"][0], self.weight_matrix)

base_features = base_features + self.bias

base_features = torch.nn.functional.dropout(base_features,
p=self.dropout_rate,
training=self.training)

base_features = torch.nn.functional.relu(base_features)
for iteration in range(self.iterations-1):
base_features = spmm(normalized_adjacency_matrix["indices"], normalized_adjacency_matrix["values"], base_features.shape[0], base_features)
for _ in range(self.iterations-1):
base_features = spmm(normalized_adjacency_matrix["indices"],
normalized_adjacency_matrix["values"],
base_features.shape[0],
base_features)
return base_features

class DenseNGCNLayer(torch.nn.Module):
Expand Down Expand Up @@ -95,11 +107,16 @@ def forward(self, normalized_adjacency_matrix, features):
:param features: Feature matrix.
:return base_features: Convolved features.
"""
base_features = torch.mm(features, self.weight_matrix)
base_features = torch.nn.functional.dropout(base_features, p = self.dropout_rate, training = self.training)
for iteration in range(self.iterations-1):
base_features = spmm(normalized_adjacency_matrix["indices"], normalized_adjacency_matrix["values"], base_features.shape[0], base_features)
base_features = base_features + self.bias
base_features = torch.mm(features, self.weight_matrix)
base_features = torch.nn.functional.dropout(base_features,
p=self.dropout_rate,
training=self.training)
for _ in range(self.iterations-1):
base_features = spmm(normalized_adjacency_matrix["indices"],
normalized_adjacency_matrix["values"],
base_features.shape[0],
base_features)
base_features = base_features + self.bias
return base_features

class ListModule(torch.nn.Module):
Expand Down
7 changes: 5 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Running MixHop or N-GCN."""

import torch
from parser import parameter_parser
from param_parser import parameter_parser
from trainer_and_networks import Trainer
from utils import tab_printer, graph_reader, feature_reader, target_reader

def main():
"""
Parsing command line parameters, reading data, fitting an NGCN and scoring the model.
Parsing command line parameters, reading data.
Fitting an NGCN and scoring the model.
"""
args = parameter_parser()
torch.manual_seed(args.seed)
Expand Down
95 changes: 95 additions & 0 deletions src/param_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Parameter parsing."""

import argparse

def parameter_parser():
"""
A method to parse up command line parameters. By default it trains on the Cora dataset.
The default hyperparameters give a good quality representation without grid search.
"""
parser = argparse.ArgumentParser(description="Run MixHop/N-GCN.")

parser.add_argument("--edge-path",
nargs="?",
default="./input/cora_edges.csv",
help="Edge list csv.")

parser.add_argument("--features-path",
nargs="?",
default="./input/cora_features.json",
help="Features json.")

parser.add_argument("--target-path",
nargs="?",
default="./input/cora_target.csv",
help="Target classes csv.")

parser.add_argument("--model",
nargs="?",
default="mixhop",
help="Target classes csv.")

parser.add_argument("--epochs",
type=int,
default=2000,
help="Number of training epochs. Default is 2000.")

parser.add_argument("--seed",
type=int,
default=42,
help="Random seed for train-test split. Default is 42.")

parser.add_argument("--early-stopping",
type=int,
default=10,
help="Number of early stopping rounds. Default is 10.")

parser.add_argument("--training-size",
type=int,
default=1500,
help="Training set size. Default is 1500.")

parser.add_argument("--validation-size",
type=int,
default=500,
help="Validation set size. Default is 500.")

parser.add_argument("--dropout",
type=float,
default=0.5,
help="Dropout parameter. Default is 0.5.")

parser.add_argument("--learning-rate",
type=float,
default=0.01,
help="Learning rate. Default is 0.01.")

parser.add_argument("--cut-off",
type=float,
default=0.1,
help="Weight cut-off. Default is 0.1.")

parser.add_argument("--lambd",
type=float,
default=0.0005,
help="L2 regularization coefficient. Default is 0.0005.")

parser.add_argument("--layers-1",
nargs="+",
type=int,
help="Layer dimensions separated by space (top). E.g. 200 20.")

parser.add_argument("--layers-2",
nargs="+",
type=int,
help="Layer dimensions separated by space (bottom). E.g. 200 200.")

parser.add_argument("--budget",
type=int,
default=60,
help="Architecture neuron allocation budget. Default is 60.")

parser.set_defaults(layers_1=[200, 200, 200])
parser.set_defaults(layers_2=[200, 200, 200])

return parser.parse_args()
27 changes: 14 additions & 13 deletions src/trainer_and_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ def forward(self, normalized_adjacency_matrix, features):
:param features: Feature matrix.
:return predictions: Label predictions.
"""
abstract_features = torch.cat([self.main_layers[i](normalized_adjacency_matrix, features) for i in range(self.order)],dim=1)
predictions = torch.nn.functional.log_softmax(self.fully_connected(abstract_features),dim=1)
abstract_features = [self.main_layers[i](normalized_adjacency_matrix, features) for i in range(self.order)]
abstract_features = torch.cat(abstract_features, dim=1)
predictions = torch.nn.functional.log_softmax(self.fully_connected(abstract_features), dim=1)
return predictions

class MixHopNetwork(torch.nn.Module):
Expand Down Expand Up @@ -104,9 +105,9 @@ def forward(self, normalized_adjacency_matrix, features):
:param features: Feature matrix.
:return predictions: Label predictions.
"""
abstract_features_1 = torch.cat([self.upper_layers[i](normalized_adjacency_matrix, features) for i in range(self.order_1)],dim=1)
abstract_features_2 = torch.cat([self.bottom_layers[i](normalized_adjacency_matrix, abstract_features_1) for i in range(self.order_2)],dim=1)
predictions = torch.nn.functional.log_softmax(self.fully_connected(abstract_features_2),dim=1)
abstract_features_1 = torch.cat([self.upper_layers[i](normalized_adjacency_matrix, features) for i in range(self.order_1)], dim=1)
abstract_features_2 = torch.cat([self.bottom_layers[i](normalized_adjacency_matrix, abstract_features_1) for i in range(self.order_2)], dim=1)
predictions = torch.nn.functional.log_softmax(self.fully_connected(abstract_features_2), dim=1)
return predictions

class Trainer(object):
Expand Down Expand Up @@ -156,7 +157,7 @@ def setup_model(self):
self.model = MixHopNetwork(self.args, self.feature_number, self.class_number)
else:
self.model = NGCNNetwork(self.args, self.feature_number, self.class_number)

def fit(self):
"""
Fitting a neural network with early stopping.
Expand All @@ -166,7 +167,7 @@ def fit(self):
epochs = trange(self.args.epochs, desc="Accuracy")
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
self.model.train()
for epoch in epochs:
for _ in epochs:
self.optimizer.zero_grad()
prediction = self.model(self.propagation_matrix, self.features)
loss = torch.nn.functional.nll_loss(prediction[self.train_nodes], self.target[self.train_nodes])
Expand All @@ -177,18 +178,18 @@ def fit(self):
loss.backward()
self.optimizer.step()
new_accuracy = self.score(self.validation_nodes)
epochs.set_description("Validation Accuracy: %g" % round(new_accuracy,4))
epochs.set_description("Validation Accuracy: %g" % round(new_accuracy, 4))
if new_accuracy < accuracy:
no_improvement = no_improvement + 1
if no_improvement == self.args.early_stopping:
epochs.close()
break
else:
no_improvement = 0
accuracy = new_accuracy
accuracy = new_accuracy
acc = self.score(self.test_nodes)
print("\nTest accuracy: " + str(round(acc,4)) +"\n")
print("\nTest accuracy: " + str(round(acc, 4)) +"\n")

def score(self, indices):
"""
Scoring a neural network.
Expand All @@ -212,14 +213,14 @@ def evaluate_architecture(self):

for layer in self.model.upper_layers:
norms = torch.norm(layer.weight_matrix**2, dim=0)
norms = norms[norms<self.args.cut_off]
norms = norms[norms < self.args.cut_off]
self.layer_sizes["upper"].append(norms.shape[0])

self.layer_sizes["bottom"] = []

for layer in self.model.bottom_layers:
norms = torch.norm(layer.weight_matrix**2, dim=0)
norms = norms[norms<self.args.cut_off]
norms = norms[norms < self.args.cut_off]
self.layer_sizes["bottom"].append(norms.shape[0])

self.layer_sizes["upper"] = [int(self.args.budget*layer_size/sum(self.layer_sizes["upper"])) for layer_size in self.layer_sizes["upper"]]
Expand Down
35 changes: 22 additions & 13 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Data reading tools."""

import json
import torch
import numpy as np
Expand All @@ -13,8 +15,9 @@ def tab_printer(args):
"""
args = vars(args)
keys = sorted(args.keys())
t = Texttable()
t.add_rows([["Parameter", "Value"]] + [[k.replace("_"," ").capitalize(),args[k]] for k in keys])
t = Texttable()
t.add_rows([["Parameter", "Value"]])
t.add_rows([[k.replace("_", " ").capitalize(), args[k]] for k in keys])
print(t.draw())

def graph_reader(path):
Expand All @@ -24,7 +27,7 @@ def graph_reader(path):
:return graph: NetworkX object returned.
"""
graph = nx.from_edgelist(pd.read_csv(path).values.tolist())
graph.remove_edges_from(graph.selfloop_edges())
graph.remove_edges_from(list(nx.selfloop_edges(graph)))
return graph

def feature_reader(path):
Expand All @@ -34,15 +37,18 @@ def feature_reader(path):
:return out_features: Dict with index and value tensor.
"""
features = json.load(open(path))
index_1 = [int(k) for k,v in features.items() for fet in v]
index_2 = [int(fet) for k,v in features.items() for fet in v]
values = [1.0]*len(index_1)
nodes = [int(k) for k,v in features.items()]
index_1 = [int(k) for k, v in features.items() for fet in v]
index_2 = [int(fet) for k, v in features.items() for fet in v]
values = [1.0]*len(index_1)
nodes = [int(k) for k, v in features.items()]
node_count = max(nodes)+1
feature_count = max(index_2)+1
features = sparse.coo_matrix((values,(index_1,index_2)), shape=(node_count, feature_count),dtype=np.float32)
features = sparse.coo_matrix((values, (index_1, index_2)),
shape=(node_count, feature_count),
dtype=np.float32)
out_features = dict()
out_features["indices"] = torch.LongTensor(np.concatenate([features.row.reshape(-1,1), features.col.reshape(-1,1)],axis=1).T)
ind = np.concatenate([features.row.reshape(-1, 1), features.col.reshape(-1, 1)], axis=1)
out_features["indices"] = torch.LongTensor(ind.T)
out_features["values"] = torch.FloatTensor(features.data)
out_features["dimensions"] = features.shape
return out_features
Expand All @@ -65,8 +71,10 @@ def create_adjacency_matrix(graph):
index_1 = [edge[0] for edge in graph.edges()] + [edge[1] for edge in graph.edges()]
index_2 = [edge[1] for edge in graph.edges()] + [edge[0] for edge in graph.edges()]
values = [1 for index in index_1]
node_count = max(max(index_1)+1,max(index_2)+1)
A = sparse.coo_matrix((values, (index_1,index_2)),shape=(node_count,node_count),dtype=np.float32)
node_count = max(max(index_1)+1, max(index_2)+1)
A = sparse.coo_matrix((values, (index_1, index_2)),
shape=(node_count, node_count),
dtype=np.float32)
return A

def normalize_adjacency_matrix(A, I):
Expand All @@ -87,13 +95,14 @@ def create_propagator_matrix(graph):
"""
Creating a propagator matrix.
:param graph: NetworkX graph.
:return propagator: Dictionary of matrix indices and values.
:return propagator: Dictionary of matrix indices and values.
"""
A = create_adjacency_matrix(graph)
I = sparse.eye(A.shape[0])
A_tilde_hat = normalize_adjacency_matrix(A, I)
propagator = dict()
A_tilde_hat = sparse.coo_matrix(A_tilde_hat)
propagator["indices"] = torch.LongTensor(np.concatenate([A_tilde_hat.row.reshape(-1,1), A_tilde_hat.col.reshape(-1,1)],axis=1).T)
ind = np.concatenate([A_tilde_hat.row.reshape(-1, 1), A_tilde_hat.col.reshape(-1, 1)], axis=1)
propagator["indices"] = torch.LongTensor(ind.T)
propagator["values"] = torch.FloatTensor(A_tilde_hat.data)
return propagator

0 comments on commit 39ddb53

Please sign in to comment.