From f95182f5d71a584b70219ff38cd34c00568902ac Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Sat, 6 Aug 2022 23:58:42 -0400 Subject: [PATCH 1/2] update experiment & trainer and add example/ogb/molhiv --- cogdl/datasets/ogb.py | 611 +++++++++++++++++++++++----- cogdl/experiments.py | 11 +- cogdl/trainer/trainer.py | 68 ++-- examples/ogb/molhiv/datawrapper.py | 24 ++ examples/ogb/molhiv/gnn.py | 365 +++++++++++++++++ examples/ogb/molhiv/modelwrapper.py | 90 ++++ examples/ogb/molhiv/train.py | 65 +++ 7 files changed, 1088 insertions(+), 146 deletions(-) create mode 100644 examples/ogb/molhiv/datawrapper.py create mode 100644 examples/ogb/molhiv/gnn.py create mode 100644 examples/ogb/molhiv/modelwrapper.py create mode 100644 examples/ogb/molhiv/train.py diff --git a/cogdl/datasets/ogb.py b/cogdl/datasets/ogb.py index 1df6e55a..4b3a333a 100644 --- a/cogdl/datasets/ogb.py +++ b/cogdl/datasets/ogb.py @@ -1,10 +1,13 @@ +import copy import os import torch - +import numpy as np from ogb.nodeproppred import NodePropPredDataset from ogb.nodeproppred import Evaluator as NodeEvaluator from ogb.graphproppred import GraphPropPredDataset -from ogb.linkproppred import LinkPropPredDataset +from ogb.graphproppred import Evaluator as GraphEvaluator +from ogb.lsc import PCQM4MDataset, PCQM4Mv2Dataset, PCQM4MEvaluator, PCQM4Mv2Evaluator +from ogb.utils import smiles2graph from cogdl.data import Dataset, Graph, DataLoader from cogdl.utils import CrossEntropyLoss, Accuracy, remove_self_loops, coalesce, BCEWithLogitsLoss @@ -154,43 +157,234 @@ def __init__(self, data_path="data"): super(OGBPapers100MDataset, self).__init__(data_path, dataset) +class MultiGraph: + def __init__(self, x_l, y_l, edge_index_l, edge_attr_l, x_s, y_s, e_s, n_s): + self._x = torch.cat(x_l, dim=0) if x_l else None + self._y = torch.cat(y_l, dim=0) if y_l else None + self._edge_index = torch.cat(edge_index_l, dim=0) if edge_index_l else None + self._edge_attr = torch.cat(edge_attr_l, dim=0) if edge_attr_l else None + self.x_s, self.y_s, self.e_s, self.n_s = x_s, y_s, e_s, n_s + self._indices = None + + def __getitem__(self, n): + n = self.indices()[n] + x = self._x[self.x_s[n]:self.x_s[n+1]] if self._x is not None else None + y = self._y[self.y_s[n]:self.y_s[n+1]] if self._y is not None else None + edge_index = self._edge_index[self.e_s[n]:self.e_s[n+1]] if self._edge_index is not None else None + edge_attr = self._edge_attr[self.e_s[n]:self.e_s[n+1]] if self._edge_attr is not None else None + + data = Graph( + x=x, + edge_index = edge_index.t(), + edge_attr=edge_attr, + y=y, + ) + data.num_nodes = self.n_s[n] + return data + + def __len__(self): + return len(self._indices) if self._indices is not None else len(self.n_s) + + @property + def edge_index(self): + row, col = self._edge_index + return (row, col) + + def indices(self): + return range(len(self.n_s)) if self._indices is None else self._indices + +class MultiGraphCode: + def __init__(self, x_l, y_l, edge_index_l, edge_attr_l, x_s, y_s, e_s, n_s): + self._x = torch.cat(x_l, dim=0) if x_l else None + self._y = torch.cat(y_l, dim=0) if y_l else None + self._edge_index = torch.cat(edge_index_l, dim=0) if edge_index_l else None + self._edge_attr = torch.cat(edge_attr_l, dim=0) if edge_attr_l else None + self.x_s, self.y_s, self.e_s, self.n_s = x_s, y_s, e_s, n_s + self._indices = None + + def __getitem__(self, n): + n = self.indices()[n] + x = self._x[self.x_s[n]:self.x_s[n+1]] if self._x is not None else None + y = self._y[self.y_s[n]:self.y_s[n+1]] if self._y is not None else None + 
edge_index = self._edge_index[self.e_s[n]:self.e_s[n+1]] if self._edge_index is not None else None + edge_attr = self._edge_attr[self.e_s[n]:self.e_s[n+1]] if self._edge_attr is not None else None + + num_nodes = x.shape[0] + row, col = edge_index[:, 0], edge_index[:, 1] + zero, one = torch.zeros_like(row), torch.ones_like(row) + + edge_index_ast = torch.stack([row, col], dim=0) + edge_attr_ast = torch.stack([zero, zero], dim=0).t() + + edge_index_ast_inverse = torch.stack([col, row], dim=0) + edge_attr_ast_inverse = torch.stack([zero, one], dim=0).t() + + node_is_attributed = x[:, 2].clone() + node_dfs_order = x[:, 3].clone + attributed_node_idx_in_dfs_order = torch.where(node_is_attributed.view(-1,) == 1)[0] + + edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], dim = 0) + edge_attr_nextoken = torch.cat([torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim = 1) + + edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim = 0) + edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2)) + + edge_index = torch.cat([edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim=1) + edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse, edge_attr_nextoken, edge_attr_nextoken_inverse], dim=0).to(torch.float32) + + # edge_index = torch.cat([edge_index_ast, edge_index_ast_inverse], dim=1) + # edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse], dim=0).to(torch.float32) + + data = Graph( + x=x, + edge_index = edge_index, + edge_attr=edge_attr, + y=y, + ) + data.num_nodes = self.n_s[n] + + return data + + def __len__(self): + return len(self._indices) if self._indices is not None else len(self.n_s) + + @property + def edge_index(self): + row, col = self._edge_index + return (row, col) + + def indices(self): + return range(len(self.n_s)) if self._indices is None else self._indices + class OGBGDataset(Dataset): def __init__(self, root, name): + name = name.replace("-", "_") + root = os.path.join(root, name) + self.root, self.name = root, name super(OGBGDataset, self).__init__(root) - self.name = name - self.dataset = GraphPropPredDataset(self.name, root) - - self.data = [] - self.all_nodes = 0 - self.all_edges = 0 - for i in range(len(self.dataset.graphs)): - graph, label = self.dataset[i] - data = Graph( - x=torch.tensor(graph["node_feat"], dtype=torch.float), - edge_index=torch.tensor(graph["edge_index"]), - edge_attr=None if "edge_feat" not in graph else torch.tensor(graph["edge_feat"], dtype=torch.float), - y=torch.tensor(label), - ) - data.num_nodes = graph["num_nodes"] - self.data.append(data) - - self.all_nodes += graph["num_nodes"] - self.all_edges += graph["edge_index"].shape[1] + self.data, self.all_nodes, self.all_edges, self.transform, self.split_index = torch.load(self.processed_paths[0]) + self._indices = None + + def get_subset(self, subset): + # datalist = [] + # for idx in subset: + # datalist.append(self.data[idx]) + # return datalist + data = copy.copy(self.data) + data._indices = subset + return data - self.transform = None + def get(self, idx): + return self.data[idx] + + def _download(self): + pass + + def process(self): + name = self.name.replace("_", "-") + dataset = GraphPropPredDataset(name, self.root) + + if name == 'ogbg-molhiv': + x_dtype, attr_type = torch.long, torch.long + elif name == 'ogbg-molpcba': + x_dtype, attr_type = torch.long, torch.long + elif name == 
'ogbg-ppa': + x_dtype, attr_type = torch.long, torch.float32 + elif name == 'ogbg-code2': + x_dtype, attr_type = torch.long, torch.float32 + + all_nodes, all_edges = 0, 0 + + x_l, y_l, edge_index_l, edge_attr_l = [], [], [], [] + x_s, y_s, e_s, n_s = [0], [0], [0], [] + + label_l = [] + + for i in range(len(dataset.graphs)): + graph, label = dataset[i] + + if "node_feat" in graph and graph["node_feat"] is not None: + x = torch.tensor(graph["node_feat"], dtype=x_dtype) + else: + x = torch.zeros(graph["num_nodes"], dtype=x_dtype) + + if name == 'ogbg-code2': + x = torch.cat([x, torch.tensor(graph["node_is_attributed"], dtype=x_dtype)], dim=1) + x = torch.cat([x, torch.tensor(graph["node_dfs_order"], dtype=x_dtype)], dim=1) + x = torch.cat([x, torch.tensor(graph["node_depth"], dtype=x_dtype)], dim=1) + label_l.append(label) + y = torch.zeros((1, ), dtype=torch.long) + else: + y = torch.tensor(label) + + edge_index = torch.tensor(graph["edge_index"]).t() + edge_attr = torch.tensor(graph["edge_feat"], dtype=attr_type) if "edge_feat" in graph and graph["edge_feat"] is not None else None + + if x is None: + x_l = None + else: + x_l.append(x) + x_s.append(x_s[-1] + x.shape[0]) + + if y is None: + y_l = None + else: + y_l.append(y) + y_s.append(y_s[-1] + y.shape[0]) + + if edge_index is None: + edge_index_l = None + else: + edge_index_l.append(edge_index) + e_s.append(e_s[-1] + edge_index.shape[0]) + + if edge_attr is None: + edge_attr_l = None + else: + edge_attr_l.append(edge_attr) + + n_s.append(graph["num_nodes"]) + all_nodes += graph["num_nodes"] + all_edges += graph["edge_index"].shape[1] + + transform = None + if name == 'ogbg-code2': + data = MultiGraphCode(x_l, y_l, edge_index_l, edge_attr_l, x_s, y_s, e_s, n_s) + else: + data = MultiGraph(x_l, y_l, edge_index_l, edge_attr_l, x_s, y_s, e_s, n_s) + split_index = dataset.get_idx_split() + + if name == 'ogbg-code2': + data.label = label_l + + torch.save([data, all_nodes, all_edges, transform, split_index], self.processed_paths[0]) + + @property + def processed_file_names(self): + return "data.pt" + + @property + def num_classes(self): + return int(self.data[0].y.shape[-1]) - def get_loader(self, args): - split_index = self.dataset.get_idx_split() - train_loader = DataLoader(self.get_subset(split_index["train"]), batch_size=args.batch_size, shuffle=True) - valid_loader = DataLoader(self.get_subset(split_index["valid"]), batch_size=args.batch_size, shuffle=False) - test_loader = DataLoader(self.get_subset(split_index["test"]), batch_size=args.batch_size, shuffle=False) - return train_loader, valid_loader, test_loader +class OGBGLSCDataset(Dataset): + def __init__(self, root, name): + name = name.replace("-", "_") + root = os.path.join(root, name) + self.root, self.name = root, name + super(OGBGLSCDataset, self).__init__(root) + + self.data, self.all_nodes, self.all_edges, self.transform, self.split_index = torch.load(self.processed_paths[0]) + def get_subset(self, subset): - datalist = [] - for idx in subset: - datalist.append(self.data[idx]) - return datalist + # datalist = [] + # for idx in subset: + # datalist.append(self.data[idx]) + # return datalist + data = copy.copy(self.data) + data._indices = subset + return data def get(self, idx): return self.data[idx] @@ -198,122 +392,317 @@ def get(self, idx): def _download(self): pass - def _process(self): - pass + def process(self): + name = self.name.replace("_", "-") + if name == 'ogbg-pcqm4m': + dataset = PCQM4MDataset(root = self.root, smiles2graph = smiles2graph) + elif name == 
'ogbg-pcqm4mv2':
+            dataset = PCQM4Mv2Dataset(root = self.root, smiles2graph = smiles2graph)
 
-    @property
-    def num_classes(self):
-        return int(self.dataset.num_classes)
+        all_nodes, all_edges = 0, 0
+
+        x_l, y_l, edge_index_l, edge_attr_l = [], [], [], []
+        x_s, y_s, e_s, n_s = [0], [0], [0], []
+
+        for i in range(len(dataset.graphs)):
+            graph, label = dataset[i]
+
+            x_dtype, attr_type = torch.long, torch.long
+
+            if "node_feat" in graph and graph["node_feat"] is not None:
+                x = torch.tensor(graph["node_feat"], dtype=x_dtype)
+            else:
+                x = torch.zeros(graph["num_nodes"], dtype=x_dtype)
+            y = torch.tensor([label])
+            edge_index = torch.tensor(graph["edge_index"]).t()
+            edge_attr = torch.tensor(graph["edge_feat"], dtype=attr_type) if "edge_feat" in graph and graph["edge_feat"] is not None else None
+
+            if x is None:
+                x_l = None
+            else:
+                x_l.append(x)
+                x_s.append(x_s[-1] + x.shape[0])
+
+            if y is None:
+                y_l = None
+            else:
+                y_l.append(y)
+                y_s.append(y_s[-1] + y.shape[0])
+
+            if edge_index is None:
+                edge_index_l = None
+            else:
+                edge_index_l.append(edge_index)
+                e_s.append(e_s[-1] + edge_index.shape[0])
+
+            if edge_attr is None:
+                edge_attr_l = None
+            else:
+                edge_attr_l.append(edge_attr)
+
+            n_s.append(graph["num_nodes"])
+            all_nodes += graph["num_nodes"]
+            all_edges += graph["edge_index"].shape[1]
+
+        transform = None
+        data = MultiGraph(x_l, y_l, edge_index_l, edge_attr_l, x_s, y_s, e_s, n_s)
+        split_index = dataset.get_idx_split()
+        torch.save([data, all_nodes, all_edges, transform, split_index], self.processed_paths[0])
 
-class OGBMolbaceDataset(OGBGDataset):
-    def __init__(self, data_path="data"):
-        dataset = "ogbg-molbace"
-        super(OGBMolbaceDataset, self).__init__(data_path, dataset)
+    @property
+    def processed_file_names(self):
+        return "data.pt"
 
+    @property
+    def num_classes(self):
+        return int(self.data[0].y.shape[-1])
+
+class OGBGEvaluator(object):
+    def __init__(self, evaluator=None, metric=None, preprocess=None):
+        super(OGBGEvaluator, self).__init__()
+        self.evaluator = evaluator
+        self.metric = metric
+        self.preprocess = preprocess
+        self.pred = list()
+        self.true = list()
+
+    def __call__(self, y_pred, y_true):
+        self.pred.append(y_pred)
+        self.true.append(y_true)
+
+        return None
+
+    def evaluate(self):
+        if len(self.pred) > 0:
+            pred = torch.cat(self.pred, dim=0)
+            true = torch.cat(self.true, dim=0)
+            self.pred = list()
+            self.true = list()
+            if self.preprocess is not None:
+                pred, true = self.preprocess(pred, true)
+            if self.metric != 'F1':
+                input_dict = {'y_pred': pred, 'y_true': true}
+            else:
+                input_dict = {'seq_pred': pred, 'seq_ref': true}
+            result_dict = self.evaluator.eval(input_dict)
+            return result_dict[self.metric]
+
+        warnings.warn("pre-computing list is empty")
+        return 0
+
+    def clear(self):
+        self.pred = list()
+        self.true = list()
+
+
+# OGB Graph Property Prediction
 class OGBMolhivDataset(OGBGDataset):
     def __init__(self, data_path="data"):
         dataset = "ogbg-molhiv"
         super(OGBMolhivDataset, self).__init__(data_path, dataset)
+
+    def get_metric_name(self):
+        return 'rocauc'
+
+    def get_evaluator(self):
+        name = self.name.replace("_", "-")
+        return OGBGEvaluator(evaluator=GraphEvaluator(name), metric='rocauc', preprocess=None)
+
+    def get_loss_fn(self):
+        def loss(input, target):
+            input = input.to(torch.float32)
+            target = target.to(torch.float32)
+            return torch.nn.functional.binary_cross_entropy_with_logits(input, target)
+        return loss
 
 class OGBMolpcbaDataset(OGBGDataset):
     def __init__(self, data_path="data"):
         dataset = 
"ogbg-molpcba" super(OGBMolpcbaDataset, self).__init__(data_path, dataset) + + def get_metric_name(self): + return 'ap' + + def get_evaluator(self): + name = self.name.replace("_", "-") + return OGBGEvaluator(evaluator=GraphEvaluator(name), metric='ap', preprocess=None) + def get_loss_fn(self): + def loss(input, target): + is_labeled = target == target + input = input[is_labeled].to(torch.float32) + target = target[is_labeled].to(torch.float32) + return torch.nn.functional.binary_cross_entropy_with_logits(input, target) + return loss class OGBPpaDataset(OGBGDataset): def __init__(self): dataset = "ogbg-ppa" path = "data" super(OGBPpaDataset, self).__init__(path, dataset) + + def get_metric_name(self): + return 'acc' + + def get_evaluator(self): + name = self.name.replace("_", "-") + def preprocess(input, target): + return torch.argmax(input, dim=1).view(-1,1), target + return OGBGEvaluator(evaluator=GraphEvaluator(name), metric='acc', preprocess=preprocess) + + def get_loss_fn(self): + def loss(input, target): + input = input.to(torch.float32) + target = target.view(-1) + return torch.nn.functional.cross_entropy(input, target) + return loss + + @property + def num_classes(self): + return int(self.data._y.max()+1) class OGBCodeDataset(OGBGDataset): def __init__(self, data_path="data"): - dataset = "ogbg-code" + dataset = "ogbg-code2" super(OGBCodeDataset, self).__init__(data_path, dataset) - - -#This part is for ogbl datasets - -class OGBLDataset(Dataset): - def __init__(self, root, name): - """ - - name (str): name of the dataset - - root (str): root directory to store the dataset folder - """ - self.name = name - - dataset = LinkPropPredDataset(name, root) - graph= dataset[0] - x = torch.tensor(graph["node_feat"]).contiguous() if graph["node_feat"] is not None else None - row, col = graph["edge_index"][0], graph["edge_index"][1] - row = torch.from_numpy(row) - col = torch.from_numpy(col) - edge_index = torch.stack([row, col], dim=0) - edge_attr = torch.as_tensor(graph["edge_feat"]) if graph["edge_feat"] is not None else graph["edge_feat"] - edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) - row = torch.cat([edge_index[0], edge_index[1]]) - col = torch.cat([edge_index[1], edge_index[0]]) - - row, col, _ = coalesce(row, col) - edge_index = torch.stack([row, col], dim=0) - - self.data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=None) - self.data.num_nodes = graph["num_nodes"] + self.generate_y() + + def generate_y(self, num_vocab=5000, max_seq_len=5): + seq_list = self.data.label + vocab_cnt = {} + vocab_list = [] + for seq in seq_list: + for w in seq: + if w in vocab_cnt: + vocab_cnt[w] += 1 + else: + vocab_cnt[w] = 1 + vocab_list.append(w) + + cnt_list = np.array([vocab_cnt[w] for w in vocab_list]) + topvocab = np.argsort(-cnt_list, kind = 'stable')[:num_vocab] + + print('Coverage of top {} vocabulary:'.format(num_vocab)) + print(float(np.sum(cnt_list[topvocab]))/np.sum(cnt_list)) + + vocab2idx = {vocab_list[vocab_idx]: idx for idx, vocab_idx in enumerate(topvocab)} + idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab] + + vocab2idx['__UNK__'] = num_vocab + idx2vocab.append('__UNK__') + + vocab2idx['__EOS__'] = num_vocab + 1 + idx2vocab.append('__EOS__') + + for idx, vocab in enumerate(idx2vocab): + assert(idx == vocab2idx[vocab]) + + # test that the idx of '__EOS__' is len(idx2vocab) - 1. 
+ # This fact will be used in decode_arr_to_seq, when finding __EOS__ + assert(vocab2idx['__EOS__'] == len(idx2vocab) - 1) + + self.vocab2idx, self.idx2vocab, self.num_vocab, self.max_seq_len = vocab2idx, idx2vocab, num_vocab, max_seq_len + + ys = [] + for seq in seq_list: + y = [] + for i in range(max_seq_len): + w = seq[i] if i < len(seq) else '__EOS__' + v = vocab2idx[w] if w in vocab2idx else vocab2idx['__UNK__'] + y.append(v) + ys.append(y) + + self.data._y = torch.tensor(ys) - def get(self, idx): - assert idx == 0 - return self.data - - def get_loss_fn(self): - return CrossEntropyLoss() + self.train_label = [self.data.label[i] for i in self.split_index["train"]] + self.valid_label = [self.data.label[i] for i in self.split_index["valid"]] + self.test_label = [self.data.label[i] for i in self.split_index["test"]] + def get_metric_name(self): + return 'F1' + def get_evaluator(self): - return Accuracy() + name = self.name.replace("_", "-") + def preprocess(input, target, idx2vocab, train_label, valid_label, test_label): + input = torch.argmax(input, dim=1).view(-1, self.max_seq_len).cpu().numpy().tolist() + preds = [] + for items in input: + pred = [] + for item in items: + if item == self.num_vocab + 1: + break + else: + pred.append(idx2vocab[item]) + preds.append(pred) + + if len(preds) == len(train_label): + label = train_label + elif len(preds) == len(valid_label): + label = valid_label + elif len(preds) == len(test_label): + label = test_label + else: + label = None + + return preds, label + return OGBGEvaluator(evaluator=GraphEvaluator(name), metric='F1', preprocess=lambda x,y : preprocess(x, y, self.idx2vocab, self.train_label, self.valid_label, self.test_label)) - def _download(self): - pass + def get_loss_fn(self): + def loss(input, target): + input = input.to(torch.float32) + target = target.view(-1) + return torch.nn.functional.cross_entropy(input, target) + return loss @property - def processed_file_names(self): - return "data_cogdl.pt" + def num_classes(self): + return len(self.vocab2idx) + +# OGB Large-Scale Challenge + +class OGBPCQM4MDataset(OGBGLSCDataset): + def __init__(self, data_path="data"): + dataset = "ogbg-pcqm4m" + super(OGBPCQM4MDataset, self).__init__(data_path, dataset) - def _process(self): - pass + def get_metric_name(self): + return 'mae' - def get_edge_split(self): - idx = self.dataset.get_edge_split() - train_edge = torch.from_numpy(idx['train']['edge'].T) - val_edge = torch.from_numpy(idx['valid']['edge'].T) - test_edge = torch.from_numpy(idx['test']['edge'].T) - return train_edge, val_edge, test_edge - -class OGBLPpaDataset(OGBLDataset): - def __init__(self, data_path="data"): - dataset = "ogbl-ppa" - super(OGBLPpaDataset, self).__init__(data_path, dataset) - - -class OGBLCollabDataset(OGBLDataset): - def __init__(self, data_path="data"): - dataset = "ogbl-collab" - super(OGBLCollabDataset, self).__init__(data_path, dataset) - + def get_evaluator(self): + def preprocess(input, target): + return input.view(-1), target.view(-1) + return OGBGEvaluator(evaluator=PCQM4MEvaluator(), metric='mae', preprocess=preprocess) -class OGBLDdiDataset(OGBLDataset): - def __init__(self, data_path="data"): - dataset = "ogbl-ddi" - super(OGBLDdiDataset, self).__init__(data_path, dataset) + def get_loss_fn(self): + def loss(input, target): + input = input.to(torch.float32) + target = target.to(torch.float32) + return torch.nn.functional.l1_loss(input, target) + return loss - -class OGBLCitation2Dataset(OGBLDataset): + +class OGBPCQM4Mv2Dataset(OGBGLSCDataset): def 
__init__(self, data_path="data"): - dataset = "ogbl-citation2" - super(OGBLCitation2Dataset, self).__init__(data_path, dataset) - + dataset = "ogbg-pcqm4mv2" + super(OGBPCQM4Mv2Dataset, self).__init__(data_path, dataset) + self.split_index['test'] = self.split_index['test-dev'] + + def get_metric_name(self): + return 'mae' + + def get_evaluator(self): + def preprocess(input, target): + return input.view(-1), target.view(-1) + return OGBGEvaluator(evaluator=PCQM4Mv2Evaluator(), metric='mae', preprocess=preprocess) + def get_loss_fn(self): + def loss(input, target): + input = input.to(torch.float32) + target = target.to(torch.float32) + return torch.nn.functional.l1_loss(input, target) + return loss diff --git a/cogdl/experiments.py b/cogdl/experiments.py index 585852ea..dbbb43e3 100644 --- a/cogdl/experiments.py +++ b/cogdl/experiments.py @@ -135,14 +135,19 @@ def train(args): # noqa: C901 for key in inspect.signature(dw_class).parameters.keys(): if hasattr(args, key) and key != "dataset": data_wrapper_args[key] = getattr(args, key) + + # setup data_wrapper + dataset_wrapper = dw_class(dataset, **data_wrapper_args) + if hasattr(args, 'scheduler_round') and args.scheduler_round == 'iteration': + args.num_iterations = dataset_wrapper.num_iterations() + else: + args.num_iterations = None + # unworthy code: share `args` between model and model_wrapper for key in inspect.signature(mw_class).parameters.keys(): if hasattr(args, key) and key != "model": model_wrapper_args[key] = getattr(args, key) - # setup data_wrapper - dataset_wrapper = dw_class(dataset, **data_wrapper_args) - args.num_features = dataset.num_features if hasattr(dataset, "num_nodes"): args.num_nodes = dataset.num_nodes diff --git a/cogdl/trainer/trainer.py b/cogdl/trainer/trainer.py index bc6d0f25..b7912d9f 100644 --- a/cogdl/trainer/trainer.py +++ b/cogdl/trainer/trainer.py @@ -19,7 +19,6 @@ from cogdl.trainer.controller import DataController from cogdl.loggers import build_logger from cogdl.data import Graph -from cogdl.utils.grb_utils import adj_preprocess, updateGraph, adj_to_tensor def move_to_device(batch, device): @@ -75,8 +74,6 @@ def __init__( actnn: bool = False, fp16: bool = False, rp_ratio: int = 1, - attack=None, - attack_mode="injection", ): self.epochs = epochs self.nstage = nstage @@ -128,8 +125,6 @@ def __init__( self.eval_data_back_to_cpu = False self.fp16 = fp16 - self.attack = attack - self.attack_mode = attack_mode if actnn: try: @@ -326,11 +321,6 @@ def train(self, rank, model_w, dataset_w): # noqa: C901 self.logger.start() print_str_dict = dict() - if self.attack is not None: - graph = dataset_w.dataset.data - graph_backup = copy.deepcopy(graph) - graph0 = copy.deepcopy(graph) - num_train = torch.sum(graph.train_mask).item() for epoch in epoch_iter: for hook in self.pre_epoch_hooks: hook(self) @@ -343,23 +333,6 @@ def train(self, rank, model_w, dataset_w): # noqa: C901 train_dataset.shuffle() training_loss = self.train_step(model_w, train_loader, optimizers, lr_schedulers, rank, scaler) - if self.attack is not None: - if self.attack_mode == "injection": - graph0.test_mask = graph0.train_mask - else: - graph0.test_mask[torch.where(graph0.train_mask)[0].multinomial(int(num_train * 0.01))] = True - graph_attack = self.attack.attack(model=model_w.model, graph=graph0, adj_norm_func=None) # todo - adj_attack = graph_attack.to_scipy_csr() - features_attack = graph_attack.x - adj_train = adj_preprocess(adj=adj_attack, adj_norm_func=None, device=rank) # todo - n_inject = graph_attack.num_nodes - graph.num_nodes - 
updateGraph(graph, adj_train, features_attack)
-                graph.edge_weight = torch.ones(graph.num_edges, device=rank)
-                graph.train_mask = torch.cat((graph.train_mask, torch.zeros(n_inject, dtype=bool, device=rank)), 0)
-                graph.val_mask = torch.cat((graph.val_mask, torch.zeros(n_inject, dtype=bool, device=rank)), 0)
-                graph.test_mask = torch.cat((graph.test_mask, torch.zeros(n_inject, dtype=bool, device=rank)), 0)
-                graph.y = torch.cat((graph.y, torch.zeros(n_inject, device=rank)), 0)
-                graph.grb_adj = adj_to_tensor(adj_train).to(rank)
 
             print_str_dict["Epoch"] = epoch
             print_str_dict["train_loss"] = training_loss
@@ -380,8 +353,24 @@ def train(self, rank, model_w, dataset_w):  # noqa: C901
                         patience += 1
                         if self.early_stopping and patience >= self.patience:
                             break
+                    if 'val_loss' in val_result:
+                        val_loss = val_result['val_loss']
+                        if isinstance(val_loss, torch.Tensor):
+                            val_loss = val_loss.item()
+                        if val_loss > 1e-4:
+                            val_loss = round(val_loss, 4)
+                        print_str_dict['val_loss'] = val_loss
                 print_str_dict[f"val_{self.evaluation_metric}"] = monitoring
-
+
+            test_loader = dataset_w.on_test_wrapper()
+            if test_loader is not None and epoch % self.eval_step == 0:
+                # inductive setting
+                dataset_w.eval()
+                # run the test split on the inference device
+                test_result = self.test(model_w, dataset_w, rank)
+                if test_result is not None and ('test'+self.monitor[3:]) in test_result:
+                    print_str_dict[f"test_{self.evaluation_metric}"] = test_result['test'+self.monitor[3:]]
+
             if self.distributed_training:
                 if rank == 0:
                     epoch_printer(print_str_dict)
@@ -400,8 +389,6 @@ def train(self, rank, model_w, dataset_w):  # noqa: C901
 
         if best_model_w is None:
             best_model_w = copy.deepcopy(model_w)
-        if self.attack is not None:
-            dataset_w.dataset.data = graph_backup
 
         if self.distributed_training:
             if rank == 0:
@@ -482,6 +469,8 @@ def distributed_test(self, model_w: ModelWrapper, loader, rank, fn):
     def train_step(self, model_w, train_loader, optimizers, lr_schedulers, device, scaler):
         model_w.train()
         losses = []
+        losses_sum = 0.0
+        losses_cnt = 0
 
         if self.progress_bar == "iteration":
             train_loader = tqdm(train_loader)
@@ -515,9 +504,24 @@ def train_step(self, model_w, train_loader, optimizers, lr_schedulers, device, s
                     scaler.update()
 
             losses.append(loss.item())
+
+            if self.progress_bar == "iteration":
+                losses_sum += losses[-1]
+                losses_cnt += 1
+                train_loader.set_description("Train loss: %0.4f" % (losses_sum / losses_cnt))
+
+            if lr_schedulers is not None:
+                for lr_scheduler in lr_schedulers:
+                    if hasattr(lr_scheduler, 'scheduler_round') and lr_scheduler.scheduler_round == 'iteration':
+                        lr_scheduler.step()
+
         if lr_schedulers is not None:
-            for lr_schedular in lr_schedulers:
-                lr_schedular.step()
+            for lr_scheduler in lr_schedulers:
+                if getattr(lr_scheduler, 'scheduler_round', None) in (None, 'epoch'):
+                    lr_scheduler.step()
 
         return np.mean(losses)
 
diff --git a/examples/ogb/molhiv/datawrapper.py b/examples/ogb/molhiv/datawrapper.py
new file mode 100644
index 00000000..668c4bdb
--- /dev/null
+++ b/examples/ogb/molhiv/datawrapper.py
@@ -0,0 +1,24 @@
+from cogdl.wrappers.data_wrapper import DataWrapper
+from cogdl.wrappers.tools.wrapper_utils import node_degree_as_feature, split_dataset
+from cogdl.data import DataLoader
+
+
+class GraphClassificationDataWrapper(DataWrapper):
+    def __init__(self, dataset, batch_size, num_workers, collate_fn=None):
+        super(GraphClassificationDataWrapper, self).__init__(dataset)
+        self.dataset = dataset
+        
self.batch_size = batch_size + self.num_workers = num_workers + self.collate_fn = collate_fn + + def train_wrapper(self): + return DataLoader(self.dataset.get_subset(self.dataset.split_index["train"]), batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, shuffle=True, collate_fn=self.collate_fn) + + def val_wrapper(self): + return DataLoader(self.dataset.get_subset(self.dataset.split_index["valid"]), batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, shuffle=False, collate_fn=self.collate_fn) + + def test_wrapper(self): + return DataLoader(self.dataset.get_subset(self.dataset.split_index["test"]), batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, shuffle=False, collate_fn=self.collate_fn) + + def num_iterations(self): + return len(self.train_wrapper()) diff --git a/examples/ogb/molhiv/gnn.py b/examples/ogb/molhiv/gnn.py new file mode 100644 index 00000000..1af39868 --- /dev/null +++ b/examples/ogb/molhiv/gnn.py @@ -0,0 +1,365 @@ +import torch +import torch.nn.functional as F +from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import degree + +from cogdl.models import BaseModel +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder + +import math + +### node encoder and edge encoder +class ASTNodeEncoder(torch.nn.Module): + ''' + Input: + x: default node feature. the first and second column represents node type and node attributes. + node_feat [0, 1] + node_is_attributed [2] + node_dfs_order [3] + node_depth [4] + + Output: + emb_dim-dimensional vector + + ''' + def __init__(self, emb_dim, num_nodetypes=100, num_nodeattributes=10100, max_depth=20): + super(ASTNodeEncoder, self).__init__() + self.max_depth = max_depth + self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim) + self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim) + self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim) + + def forward(self, x): + depth = x[:, 4].clone() + depth[depth > self.max_depth] = self.max_depth + return self.type_encoder(x[:,0]) + self.attribute_encoder(x[:,1]) + self.depth_encoder(depth) + +def get_node_encoder(emb_dim, dataset_name): + if dataset_name in ['ogbg-molhiv', 'ogbg-molpcba', 'ogbg-pcqm4m', 'ogbg-pcqm4mv2']: + return AtomEncoder(emb_dim = emb_dim) + elif dataset_name == 'ogbg-ppa': + return torch.nn.Embedding(1, emb_dim) + elif dataset_name == 'ogbg-code2': + return ASTNodeEncoder(emb_dim = emb_dim) + +def get_edge_encoder(emb_dim, dataset_name): + if dataset_name in ['ogbg-molhiv', 'ogbg-molpcba', 'ogbg-pcqm4m', 'ogbg-pcqm4mv2']: + return BondEncoder(emb_dim = emb_dim) + elif dataset_name == 'ogbg-ppa': + return torch.nn.Linear(7, emb_dim) + elif dataset_name == 'ogbg-code2': + return torch.nn.Linear(2, emb_dim) + +### GCN convolution along the graph structure +class GCNConv(MessagePassing): + def __init__(self, emb_dim, dataset_name = None): + super(GCNConv, self).__init__(aggr='add') + + self.linear = torch.nn.Linear(emb_dim, emb_dim) + self.root_emb = torch.nn.Embedding(1, emb_dim) + self.edge_encoder = get_edge_encoder(emb_dim, dataset_name) + + + def forward(self, x, edge_index, edge_attr): + x = self.linear(x) + edge_embedding = self.edge_encoder(edge_attr) + + row, col = edge_index + + deg = degree(row, x.size(0), dtype = x.dtype) + 1 + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + norm = 
deg_inv_sqrt[row] * deg_inv_sqrt[col] + + return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1) + + def message(self, x_j, edge_attr, norm): + return norm.view(-1, 1) * F.relu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out + + +### GIN convolution along the graph structure +class GINConv(MessagePassing): + def __init__(self, emb_dim, dataset_name = None): + super(GINConv, self).__init__(aggr = "add") + + self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim)) + self.eps = torch.nn.Parameter(torch.Tensor([0])) + self.edge_encoder = get_edge_encoder(emb_dim, dataset_name) + + def forward(self, x, edge_index, edge_attr): + edge_embedding = self.edge_encoder(edge_attr) + out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) + + return out + + def message(self, x_j, edge_attr): + return F.relu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out + + +### GNN to generate node embedding +class GNN_node(torch.nn.Module): + def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin', dataset_name = None): + ''' + emb_dim (int): node embedding dimensionality + num_layer (int): number of GNN message passing layers + + ''' + super(GNN_node, self).__init__() + self.num_layer = num_layer + self.drop_ratio = drop_ratio + self.JK = JK + ### add residual connection or not + self.residual = residual + self.dataset_name = dataset_name + + if self.num_layer < 2: + raise ValueError("Number of GNN layers must be greater than 1.") + + self.node_encoder = get_node_encoder(emb_dim, dataset_name) + + ###List of GNNs + self.convs = torch.nn.ModuleList() + self.batch_norms = torch.nn.ModuleList() + + for layer in range(num_layer): + if gnn_type == 'gin': + self.convs.append(GINConv(emb_dim, dataset_name)) + elif gnn_type == 'gcn': + self.convs.append(GCNConv(emb_dim, dataset_name)) + else: + raise ValueError('Undefined GNN type called {}'.format(gnn_type)) + + self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) + + def forward(self, batched_data): + x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch + edge_index = torch.cat([edge_index[0].view(1, -1), edge_index[1].view(1, -1)], dim=0) + + ### computing input node embedding + + h_list = [self.node_encoder(x)] + for layer in range(self.num_layer): + + h = self.convs[layer](h_list[layer], edge_index, edge_attr) + h = self.batch_norms[layer](h) + + if layer == self.num_layer - 1: + #remove relu for the last layer + h = F.dropout(h, self.drop_ratio, training = self.training) + else: + h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) + + if self.residual: + h += h_list[layer] + + h_list.append(h) + + ### Different implementations of Jk-concat + if self.JK == "last": + node_representation = h_list[-1] + elif self.JK == "sum": + node_representation = 0 + for layer in range(self.num_layer + 1): + node_representation += h_list[layer] + + return node_representation + + +### Virtual GNN to generate node embedding +class GNN_node_Virtualnode(torch.nn.Module): + """ + Output: + node representations + """ + def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin', dataset_name = None): + ''' + emb_dim (int): node embedding dimensionality + ''' + + 
super(GNN_node_Virtualnode, self).__init__() + self.num_layer = num_layer + self.drop_ratio = drop_ratio + self.JK = JK + ### add residual connection or not + self.residual = residual + + if self.num_layer < 2: + raise ValueError("Number of GNN layers must be greater than 1.") + + self.node_encoder = get_node_encoder(emb_dim, dataset_name) + + ### set the initial virtual node embedding to 0. + self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim) + torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) + + ### List of GNNs + self.convs = torch.nn.ModuleList() + ### batch norms applied to node embeddings + self.batch_norms = torch.nn.ModuleList() + + ### List of MLPs to transform virtual node at every layer + self.mlp_virtualnode_list = torch.nn.ModuleList() + + for layer in range(num_layer): + if gnn_type == 'gin': + self.convs.append(GINConv(emb_dim, dataset_name)) + elif gnn_type == 'gcn': + self.convs.append(GCNConv(emb_dim, dataset_name)) + else: + raise ValueError('Undefined GNN type called {}'.format(gnn_type)) + + self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) + + for layer in range(num_layer - 1): + self.mlp_virtualnode_list.append( + torch.nn.Sequential( + torch.nn.Linear(emb_dim, 2*emb_dim), + torch.nn.BatchNorm1d(2*emb_dim), + torch.nn.ReLU(), + torch.nn.Linear(2*emb_dim, emb_dim), + torch.nn.BatchNorm1d(emb_dim), + torch.nn.ReLU() + ) + ) + + def forward(self, batched_data): + x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch + edge_index = torch.cat([edge_index[0].view(1, -1), edge_index[1].view(1, -1)], dim=0) + + + + ### virtual node embeddings for graphs + virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) + + h_list = [self.node_encoder(x)] + for layer in range(self.num_layer): + ### add message from virtual nodes to graph nodes + h_list[layer] = h_list[layer] + virtualnode_embedding[batch] + + ### Message passing among graph nodes + h = self.convs[layer](h_list[layer], edge_index, edge_attr) + + h = self.batch_norms[layer](h) + if layer == self.num_layer - 1: + #remove relu for the last layer + h = F.dropout(h, self.drop_ratio, training = self.training) + else: + h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) + + if self.residual: + h = h + h_list[layer] + + h_list.append(h) + + ### update the virtual nodes + if layer < self.num_layer - 1: + ### add message from graph nodes to virtual nodes + virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding + ### transform virtual nodes using MLP + + if self.residual: + virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) + else: + virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training) + + ### Different implementations of Jk-concat + if self.JK == "last": + node_representation = h_list[-1] + elif self.JK == "sum": + node_representation = 0 + for layer in range(self.num_layer + 1): + node_representation += h_list[layer] + + return node_representation + + +class GNN(BaseModel): + + def __init__(self, num_tasks, num_layer = 5, emb_dim = 300, gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean", dataset_name = None, num_vocab = 5000, max_seq_len = 5): + ''' + 
num_tasks (int): number of labels to be predicted + virtual_node (bool): whether to add virtual node or not + ''' + + super(GNN, self).__init__() + + self.num_layer = num_layer + self.drop_ratio = drop_ratio + self.JK = JK + self.emb_dim = emb_dim + self.num_tasks = num_tasks + self.graph_pooling = graph_pooling + self.dataset_name = dataset_name + + # only for code2 + self.num_vocab = num_vocab + self.max_seq_len = max_seq_len + + if self.num_layer < 2: + raise ValueError("Number of GNN layers must be greater than 1.") + + ### GNN to generate node embeddings + if virtual_node: + self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type, dataset_name = dataset_name) + else: + self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type, dataset_name = dataset_name) + + + ### Pooling function to generate whole-graph embeddings + if self.graph_pooling == "sum": + self.pool = global_add_pool + elif self.graph_pooling == "mean": + self.pool = global_mean_pool + elif self.graph_pooling == "max": + self.pool = global_max_pool + elif self.graph_pooling == "attention": + self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))) + elif self.graph_pooling == "set2set": + self.pool = Set2Set(emb_dim, processing_steps = 2) + else: + raise ValueError("Invalid graph pooling type.") + + if dataset_name != 'ogbg-code2': + if graph_pooling == "set2set": + self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_tasks) + else: + self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks) + else: + self.graph_pred_linear_list = torch.nn.ModuleList() + if graph_pooling == "set2set": + for i in range(max_seq_len): + self.graph_pred_linear_list.append(torch.nn.Linear(2*emb_dim, self.num_tasks)) + else: + for i in range(max_seq_len): + self.graph_pred_linear_list.append(torch.nn.Linear(emb_dim, self.num_tasks)) + + def forward(self, batched_data): + h_node = self.gnn_node(batched_data) + h_graph = self.pool(h_node, batched_data.batch) + + if self.dataset_name == 'ogbg-code2': + pred_list = [] + for i in range(self.max_seq_len): + pred_list.append(self.graph_pred_linear_list[i](h_graph)) + output = torch.cat(pred_list, dim=1).view(h_graph.shape[0] * self.max_seq_len, -1) + else: + output = self.graph_pred_linear(h_graph) + + if self.dataset_name == 'ogbg-pcqm4m' or self.dataset_name == 'ogbg-pcqm4mv2': + if self.training: + return output + else: + # At inference time, we clamp the value between 0 and 20 + return torch.clamp(output, min=0, max=20) + else: + return output + + diff --git a/examples/ogb/molhiv/modelwrapper.py b/examples/ogb/molhiv/modelwrapper.py new file mode 100644 index 00000000..f6e2b7bd --- /dev/null +++ b/examples/ogb/molhiv/modelwrapper.py @@ -0,0 +1,90 @@ +import torch + +from cogdl.wrappers.model_wrapper import ModelWrapper + +from torch.optim.lr_scheduler import StepLR, LambdaLR + +def PolynomialDecayLR(step_count, warmup_updates, tot_updates, begin_lr, end_lr, power): + # print ('step_count, warmup_updates, tot_updates, begin_lr, end_lr, power', step_count, warmup_updates, tot_updates, begin_lr, end_lr, power) + step_count += 1 + if step_count <= warmup_updates: + warmup_factor = step_count / float(warmup_updates) + lr = warmup_factor * begin_lr + elif step_count >= tot_updates: + lr = end_lr + else: + pct_remaining = 1 - 
(step_count - warmup_updates) / (tot_updates - warmup_updates) + lr = (begin_lr - end_lr) * pct_remaining ** power + end_lr + return lr + + +class GraphClassificationModelWrapper(ModelWrapper): + def __init__(self, model, optimizer_cfg, metric_name, scheduler_type, scheduler_round, num_iterations, warmup_epochs, epochs, lr, end_lr): + super(GraphClassificationModelWrapper, self).__init__() + self.model = model + self.optimizer_cfg = optimizer_cfg + self.metric_name = metric_name + + self.scheduler_type = scheduler_type + self.scheduler_round = scheduler_round + self.num_iterations = num_iterations + self.warmup_epochs, self.epochs = warmup_epochs, epochs + self.begin_lr, self.end_lr = lr, end_lr + + def train_step(self, batch): + pred = self.model(batch) + target = batch.y.view(pred.shape[0], -1) + loss = self.default_loss_fn(pred, target) + + return loss + + def val_step(self, batch): + pred = self.model(batch) + target = batch.y.view(pred.shape[0], -1) + val_loss = self.default_loss_fn(pred, target) + metric = self.evaluate(pred, target, metric="auto") + + self.note("val_loss", val_loss) + self.note("val_metric", metric) + + def test_step(self, batch): + pred = self.model(batch) + target = batch.y.view(pred.shape[0], -1) + test_loss = self.default_loss_fn(pred, target) + metric = self.evaluate(pred, target, metric="auto") + + self.note("test_loss", test_loss) + self.note("test_metric", metric) + + def setup_optimizer(self): + cfg = self.optimizer_cfg + + if self.scheduler_type == 'StepLR': + optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"]) + scheduler = StepLR(optimizer, step_size=30, gamma=0.25) + return optimizer, scheduler + + if self.scheduler_type == 'PolynomialDecayLR': + if self.scheduler_round == 'iteration': + + optimizer = torch.optim.Adam(self.model.parameters(), lr=1.0, weight_decay=cfg["weight_decay"]) + warmup_updates = self.warmup_epochs * self.num_iterations + tot_updates = self.epochs * self.num_iterations + begin_lr, end_lr, power = self.begin_lr, self.end_lr, 1.0 + scheduler = LambdaLR( + optimizer, + lr_lambda=lambda x: PolynomialDecayLR( + x, warmup_updates, tot_updates, begin_lr, end_lr, power + ) + ) + scheduler.scheduler_round = 'iteration' + return optimizer, scheduler + + optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"]) + return optimizer + + + def set_early_stopping(self): + if self.metric_name == 'mae': + return "val_metric", "<" + return "val_metric", ">" diff --git a/examples/ogb/molhiv/train.py b/examples/ogb/molhiv/train.py new file mode 100644 index 00000000..51bcfea7 --- /dev/null +++ b/examples/ogb/molhiv/train.py @@ -0,0 +1,65 @@ +import argparse +from cogdl import experiment +from cogdl.options import get_parser + +from cogdl.datasets.ogb import OGBMolhivDataset + +from modelwrapper import GraphClassificationModelWrapper +from datawrapper import GraphClassificationDataWrapper + +from gnn import GNN + +parser = get_parser() + +parser.add_argument('--dataset', type=str, default="ogbg-molhiv", help='dataset name (default: ogbg-molhiv)') +parser.add_argument('--gnn', type=str, default='gin-virtual', help='GNN gin, gin-virtual, or gcn, or gcn-virtual (default: gin-virtual)') +parser.add_argument('--num_layer', type=int, default=5, help='number of GNN message passing layers (default: 5)') +parser.add_argument('--emb_dim', type=int, default=300, help='dimensionality of hidden units in GNNs (default: 300)') +parser.add_argument('--drop_ratio', type=float, 
default=0.5, help='dropout ratio (default: 0.5)') + +parser.add_argument("--scheduler-type", type=str, default=None, choices=[None, 'StepLR', 'PolynomialDecayLR']) +parser.add_argument("--scheduler-round", type=str, default='epoch', choices=['epoch', 'iteration']) +parser.add_argument("--lr", type=float, default=0.001) +parser.add_argument('--end_lr', type=float, default=1e-9) +parser.add_argument("--weight-decay", type=float, default=0.0) + +parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)') +parser.add_argument('--warmup_epochs', type=int, default=6, help='number of epochs to warmup (default: 6)') +parser.add_argument('--batch_size', type=int, default=256, help='input batch size for training (default: 256)') +parser.add_argument('--num_workers', type=int, default=4, help='number of workers (default: 4)') +parser.add_argument("--progress-bar", type=str, default='iteration', choices=['epoch', 'iteration']) + +args = parser.parse_args() + +args.mw = GraphClassificationModelWrapper +args.dw = GraphClassificationDataWrapper + +dataset = OGBMolhivDataset() + +args.metric_name = dataset.get_metric_name() + +if args.gnn == 'gin': + model = GNN(gnn_type = 'gin', num_tasks = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, dataset_name=args.dataset) +elif args.gnn == 'gin-virtual': + model = GNN(gnn_type = 'gin', num_tasks = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True, dataset_name=args.dataset) +elif args.gnn == 'gcn': + model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = False, dataset_name=args.dataset) +elif args.gnn == 'gcn-virtual': + model = GNN(gnn_type = 'gcn', num_tasks = dataset.num_classes, num_layer = args.num_layer, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, virtual_node = True, dataset_name=args.dataset) +else: + raise ValueError('Invalid GNN type') + +dataset_name = args.dataset + +experiment( + dataset = dataset, + model = model, + args = args, +) + +""" +Result: +| Variant | test__metric | val__metric | +|---------------------------|----------------|---------------| +| (OGBMolhivDataset, 'GNN') | 0.7706±0.0000 | 0.8409±0.0000 | +""" \ No newline at end of file From e601b2f5e0bf03276561e31033e67565f9734f31 Mon Sep 17 00:00:00 2001 From: Kun Zhang Date: Sun, 7 Aug 2022 07:12:41 -0400 Subject: [PATCH 2/2] fix bugs in ogb.py --- cogdl/datasets/ogb.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cogdl/datasets/ogb.py b/cogdl/datasets/ogb.py index 4b3a333a..fd4d9637 100644 --- a/cogdl/datasets/ogb.py +++ b/cogdl/datasets/ogb.py @@ -12,6 +12,7 @@ from cogdl.data import Dataset, Graph, DataLoader from cogdl.utils import CrossEntropyLoss, Accuracy, remove_self_loops, coalesce, BCEWithLogitsLoss +import warnings class OGBNDataset(Dataset): def __init__(self, root, name, transform=None): @@ -209,7 +210,6 @@ def __getitem__(self, n): edge_index = self._edge_index[self.e_s[n]:self.e_s[n+1]] if self._edge_index is not None else None edge_attr = self._edge_attr[self.e_s[n]:self.e_s[n+1]] if self._edge_attr is not None else None - num_nodes = x.shape[0] row, col = edge_index[:, 0], edge_index[:, 1] zero, one = torch.zeros_like(row), torch.ones_like(row) @@ -220,7 +220,6 @@ def __getitem__(self, n): edge_attr_ast_inverse = torch.stack([zero, one], 
dim=0).t()
 
         node_is_attributed = x[:, 2].clone()
-        node_dfs_order = x[:, 3].clone
         attributed_node_idx_in_dfs_order = torch.where(node_is_attributed.view(-1,) == 1)[0]
 
         edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]], dim = 0)
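
A minimal usage sketch of the dataset API introduced in this patch series; the batch size is illustrative, `model` stands in for any graph-level predictor such as the GNN defined in examples/ogb/molhiv/gnn.py, and it assumes cogdl's Dataset base class triggers process() on first use to build the cached data.pt:

from cogdl.data import DataLoader
from cogdl.datasets.ogb import OGBMolhivDataset

dataset = OGBMolhivDataset()            # processes and caches data.pt on first use
split = dataset.split_index             # OGB-provided train/valid/test indices

# get_subset() no longer copies graphs into a Python list; it returns a shallow
# copy of the shared MultiGraph with _indices restricted to the requested split
train_loader = DataLoader(dataset.get_subset(split["train"]), batch_size=32, shuffle=True)

evaluator = dataset.get_evaluator()     # OGBGEvaluator wrapping ogb's rocauc metric
loss_fn = dataset.get_loss_fn()         # binary cross-entropy with logits

for batch in train_loader:
    pred = model(batch)                 # `model` is a placeholder, not part of this patch
    loss = loss_fn(pred, batch.y.view(pred.shape[0], -1))
    evaluator(pred, batch.y)            # __call__ only accumulates predictions and targets
score = evaluator.evaluate()            # concatenates, preprocesses, and scores in one call;
                                        # evaluate() also resets the accumulated lists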