diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7f426ff
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+# cache
+*.vscode
+*__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..987b84d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,50 @@
+## Improving Molecular Contrastive Learning via Faulty Negative Mitigation and Decomposed Fragment Contrast ##
+
+This is the official implementation of iMolCLR: ["Improving Molecular Contrastive Learning via Faulty Negative Mitigation and Decomposed Fragment Contrast"](https://arxiv.org/abs/2202.09346).
+
+## Getting Started
+
+### Installation
+
+Set up the conda environment and clone the GitHub repo:
+
+```
+# create a new environment
+$ conda create --name imolclr python=3.7
+$ conda activate imolclr
+
+# install requirements
+$ pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
+$ pip install torch-geometric==1.6.3 torch-sparse==0.6.9 torch-scatter==2.0.6 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
+$ pip install PyYAML
+$ conda install -c conda-forge rdkit=2021.09.1
+$ conda install -c conda-forge tensorboard
+
+# clone the source code of iMolCLR
+$ git clone https://github.com/yuyangw/iMolCLR.git
+$ cd iMolCLR
+```
+
+### Dataset
+
+You can download the pre-training data and the benchmarks used in the paper [here](https://drive.google.com/file/d/1aDtN6Qqddwwn2x612kWz9g0xQcuAtzDE/view?usp=sharing) and extract the zip file into the `./data` folder. The data for pre-training can be found in `pubchem-10m-clean.txt`. Each fine-tuning dataset is saved in a folder named after the benchmark. You can also find the benchmarks at [MoleculeNet](https://moleculenet.org/).
+
+### Pre-training
+
+To pre-train iMolCLR with the configurations defined in `config.yaml`:
+```
+$ python imolclr.py
+```
+
+To monitor training via TensorBoard, run `tensorboard --logdir ckpt/{PATH}` and open the URL http://127.0.0.1:6006/.
+
+### Fine-tuning
+
+To fine-tune the pre-trained iMolCLR model on downstream molecular benchmarks with the configurations defined in `config_finetune.yaml`:
+```
+$ python finetune.py
+```
+
+### Pre-trained model
+
+We also provide a pre-trained model, which can be found in `ckpt/pretrained`. You can load the model by changing the `fine_tune_from` variable in `config_finetune.yaml` to `pretrained`. 
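The `fine_tune_from` directory is resolved by `finetune.py` as `{fine_tune_from}/checkpoints/model.pth`, and the pre-training script saves the weights of a `DistributedDataParallel`-wrapped model, so the parameter names typically carry a `module.` prefix. The snippet below is a minimal, illustrative sketch for locating and inspecting a pre-trained checkpoint before fine-tuning; the relative path and the prefix handling are assumptions and may need to be adapted to your local checkout.

```
import os
import torch

# Assumption: fine_tune_from points at a directory containing checkpoints/model.pth,
# which is the layout finetune.py expects when loading pre-trained weights.
fine_tune_from = 'pretrained'
ckpt_path = os.path.join(fine_tune_from, 'checkpoints', 'model.pth')

state_dict = torch.load(ckpt_path, map_location='cpu')
# Weights saved from the DistributedDataParallel model carry a 'module.' prefix;
# strip it if you want to load them into a bare (non-DDP) GINet encoder.
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
print('loaded {} tensors, e.g. {}'.format(len(state_dict), next(iter(state_dict))))
```

If the checkpoint lives elsewhere (e.g., under a pre-training run directory in `ckpt/`), point `fine_tune_from` at that directory instead.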
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..e247e2d
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,31 @@
+batch_size: 512 # batch size
+world_size: 3 # total number of GPUs
+backend: nccl # backends of PyTorch
+epochs: 50 # total number of epochs
+warmup: 10 # warm-up epochs
+
+eval_every_n_epochs: 1 # validation frequency
+resume_from: None # checkpoint to resume training from ('None' trains from scratch)
+log_every_n_steps: 200 # print training log frequency
+
+optim:
+  lr: 0.0005 # initial learning rate for Adam optimizer
+  weight_decay: 0.00001 # weight decay for Adam optimizer
+
+model:
+  num_layer: 5 # number of graph conv layers
+  emb_dim: 300 # embedding dimension in graph conv layers
+  feat_dim: 512 # output feature dimension
+  dropout: 0 # dropout ratio
+  pool: mean # readout pooling (i.e., mean/max/add)
+
+dataset:
+  num_workers: 12 # dataloader number of workers
+  valid_size: 0.05 # ratio of validation data
+  data_path: data/pubchem-10m-clean.txt # path of pre-training data
+
+loss:
+  temperature: 0.1 # temperature of (weighted) NT-Xent loss
+  use_cosine_similarity: True # whether to use cosine similarity in (weighted) NT-Xent loss (i.e. True/False)
+  lambda_1: 0.5 # $\lambda_1$ to control faulty negative mitigation
+  lambda_2: 0.5 # $\lambda_2$ to control fragment contrast
diff --git a/config_finetune.yaml b/config_finetune.yaml
new file mode 100644
index 0000000..4d6d35d
--- /dev/null
+++ b/config_finetune.yaml
@@ -0,0 +1,24 @@
+batch_size: 32 # batch size
+epochs: 100 # total number of epochs
+eval_every_n_epochs: 1 # validation frequency
+fine_tune_from: pretrained # directory of pre-trained model
+log_every_n_steps: 50 # print training log frequency
+gpu: cuda:0 # training GPU
+task_name: BBBP # name of fine-tuning benchmark, including
+                # classifications: BBBP/BACE/ClinTox/Tox21/HIV/SIDER/MUV
+                # regressions: FreeSolv/ESOL/Lipo/qm7/qm8
+
+optim:
+  type: Adam # optimizer type (Adam or SGD), read by finetune.py
+  lr: 0.0005 # initial learning rate for the prediction head
+  weight_decay: 0.000001 # weight decay of Adam
+  base_ratio: 0.4 # ratio of learning rate for the base GNN encoder
+
+model: # notice that other 'model' variables are defined from the config of pretrained model
+  dropout: 0.3 # dropout ratio
+  pool: mean # readout pooling (i.e., mean/max/add)
+
+dataset:
+  num_workers: 4 # dataloader number of workers
+  valid_size: 0.1 # ratio of validation data
+  test_size: 0.1 # ratio of test data
diff --git a/data_aug/dataset.py b/data_aug/dataset.py
new file mode 100644
index 0000000..1b08c60
--- /dev/null
+++ b/data_aug/dataset.py
@@ -0,0 +1,174 @@
+import os
+import csv
+import math
+import time
+import random
+import networkx as nx
+import numpy as np
+from copy import deepcopy
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from torch.utils.data.sampler import SubsetRandomSampler
+import torchvision.transforms as transforms
+
+from torch_scatter import scatter
+from torch_geometric.data import Data, Batch
+
+import rdkit
+from rdkit import Chem
+from rdkit.Chem.rdchem import HybridizationType
+from rdkit.Chem.rdchem import BondType as BT
+from rdkit.Chem import AllChem
+
+
+ATOM_LIST = list(range(1,119))
+CHIRALITY_LIST = [
+    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
+    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
+    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
+    Chem.rdchem.ChiralType.CHI_OTHER
+]
+BOND_LIST = [
+    BT.SINGLE,
+    BT.DOUBLE,
+    BT.TRIPLE,
+    BT.AROMATIC
+]
+BONDDIR_LIST = [
+    Chem.rdchem.BondDir.NONE,
+    Chem.rdchem.BondDir.ENDUPRIGHT,
+    Chem.rdchem.BondDir.ENDDOWNRIGHT
+]
+
+
+def 
read_smiles(data_path): + smiles_data = [] + with open(data_path) as csv_file: + csv_reader = csv.reader(csv_file, delimiter=',') + for i, row in enumerate(csv_reader): + smiles = row[-1] + smiles_data.append(smiles) + # mol = Chem.MolFromSmiles(smiles) + # if mol != None: + # smiles_data.append(smiles) + return smiles_data + + +def removeSubgraph(Graph, center, percent=0.2): + assert percent <= 1 + G = Graph.copy() + num = int(np.floor(len(G.nodes)*percent)) + removed = [] + temp = [center] + + while len(removed) < num: + neighbors = [] + for n in temp: + neighbors.extend([i for i in G.neighbors(n) if i not in temp]) + for n in temp: + if len(removed) < num: + G.remove_node(n) + removed.append(n) + else: + break + temp = list(set(neighbors)) + return G, removed + + +class MoleculeDataset(Dataset): + def __init__(self, smiles_data): + super(Dataset, self).__init__() + self.smiles_data = smiles_data + + def __getitem__(self, index): + mol = Chem.MolFromSmiles(self.smiles_data[index]) + mol = Chem.AddHs(mol) + + N = mol.GetNumAtoms() + M = mol.GetNumBonds() + + type_idx = [] + chirality_idx = [] + atomic_number = [] + atoms = mol.GetAtoms() + bonds = mol.GetBonds() + # Sample 2 different centers to start for i and j + start_i, start_j = random.sample(list(range(N)), 2) + + # Construct the original molecular graph from edges (bonds) + edges = [] + for bond in bonds: + edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) + molGraph = nx.Graph(edges) + + # Get the graph for i and j after removing subgraphs + percent_i, percent_j = 0.25, 0.25 + G_i, removed_i = removeSubgraph(molGraph, start_i, percent_i) + G_j, removed_j = removeSubgraph(molGraph, start_j, percent_j) + + for atom in atoms: + type_idx.append(ATOM_LIST.index(atom.GetAtomicNum())) + chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag())) + atomic_number.append(atom.GetAtomicNum()) + + x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1) + x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1) + x = torch.cat([x1, x2], dim=-1) + # x shape (N, 2) [type, chirality] + + # Mask the atoms in the removed list + x_i = deepcopy(x) + for atom_idx in removed_i: + # Change atom type to 118, and chirality to 0 + x_i[atom_idx,:] = torch.tensor([len(ATOM_LIST), 0]) + x_j = deepcopy(x) + for atom_idx in removed_j: + # Change atom type to 118, and chirality to 0 + x_j[atom_idx,:] = torch.tensor([len(ATOM_LIST), 0]) + + # Only consider bond still exist after removing subgraph + row_i, col_i, row_j, col_j = [], [], [], [] + edge_feat_i, edge_feat_j = [], [] + G_i_edges = list(G_i.edges) + G_j_edges = list(G_j.edges) + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + feature = [ + BOND_LIST.index(bond.GetBondType()), + BONDDIR_LIST.index(bond.GetBondDir()) + ] + if (start, end) in G_i_edges: + row_i += [start, end] + col_i += [end, start] + edge_feat_i.append(feature) + edge_feat_i.append(feature) + if (start, end) in G_j_edges: + row_j += [start, end] + col_j += [end, start] + edge_feat_j.append(feature) + edge_feat_j.append(feature) + + edge_index_i = torch.tensor([row_i, col_i], dtype=torch.long) + edge_attr_i = torch.tensor(np.array(edge_feat_i), dtype=torch.long) + edge_index_j = torch.tensor([row_j, col_j], dtype=torch.long) + edge_attr_j = torch.tensor(np.array(edge_feat_j), dtype=torch.long) + + data_i = Data(x=x_i, edge_index=edge_index_i, edge_attr=edge_attr_i) + data_j = Data(x=x_j, edge_index=edge_index_j, edge_attr=edge_attr_j) + + return data_i, data_j, mol + + def 
__len__(self): + return len(self.smiles_data) + + +def collate_fn(batch): + gis, gjs, mols = zip(*batch) + + gis = Batch().from_data_list(gis) + gjs = Batch().from_data_list(gjs) + + return gis, gjs, mols + diff --git a/data_aug/dataset_test.py b/data_aug/dataset_test.py new file mode 100644 index 0000000..2c1a51d --- /dev/null +++ b/data_aug/dataset_test.py @@ -0,0 +1,454 @@ +import os +import csv +import random +import numpy as np +from copy import deepcopy + +import torch +import torch.nn.functional as F +from torch.utils.data.sampler import SubsetRandomSampler +import torchvision.transforms as transforms + +from torch_scatter import scatter +from torch_geometric.data import Data, Dataset, DataLoader + +import rdkit +from rdkit import Chem +from rdkit.Chem.rdchem import HybridizationType +from rdkit.Chem.rdchem import BondType as BT +from rdkit.Chem import AllChem +from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles +from rdkit import RDLogger +RDLogger.DisableLog('rdApp.*') + + +ATOM_LIST = list(range(1,119)) +CHIRALITY_LIST = [ + Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + Chem.rdchem.ChiralType.CHI_OTHER +] +BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC] +BONDDIR_LIST = [ + Chem.rdchem.BondDir.NONE, + Chem.rdchem.BondDir.ENDUPRIGHT, + Chem.rdchem.BondDir.ENDDOWNRIGHT +] + + +def _generate_scaffold(smiles, include_chirality=False): + mol = Chem.MolFromSmiles(smiles) + scaffold = MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality) + return scaffold + + +def generate_scaffolds(dataset, log_every_n=1000): + scaffolds = {} + data_len = len(dataset) + + print("About to generate scaffolds") + for ind, smiles in enumerate(dataset.smiles_data): + if ind % log_every_n == 0: + print("Generating scaffold %d/%d" % (ind, data_len)) + scaffold = _generate_scaffold(smiles) + if scaffold not in scaffolds: + scaffolds[scaffold] = [ind] + else: + scaffolds[scaffold].append(ind) + + # Sort from largest to smallest scaffold sets + scaffolds = {key: sorted(value) for key, value in scaffolds.items()} + scaffold_sets = [ + scaffold_set for (scaffold, scaffold_set) in sorted( + scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True) + ] + return scaffold_sets + + +def scaffold_split(dataset, valid_size, test_size, seed=None, log_every_n=1000): + train_size = 1.0 - valid_size - test_size + scaffold_sets = generate_scaffolds(dataset) + + train_cutoff = train_size * len(dataset) + valid_cutoff = (train_size + valid_size) * len(dataset) + train_inds: List[int] = [] + valid_inds: List[int] = [] + test_inds: List[int] = [] + + print("About to sort in scaffold sets") + for scaffold_set in scaffold_sets: + if len(train_inds) + len(scaffold_set) > train_cutoff: + if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff: + test_inds += scaffold_set + else: + valid_inds += scaffold_set + else: + train_inds += scaffold_set + + print('train: {}, valid: {}, test: {}'.format( + len(train_inds), len(valid_inds), len(test_inds))) + return train_inds, valid_inds, test_inds + + +def read_smiles(data_path, target, task): + smiles_data, labels = [], [] + with open(data_path) as csv_file: + # csv_reader = csv.reader(csv_file, delimiter=',') + csv_reader = csv.DictReader(csv_file, delimiter=',') + for i, row in enumerate(csv_reader): + if i != 0: + # smiles = row[3] + smiles = row['smiles'] + label = row[target] + mol = Chem.MolFromSmiles(smiles) + if mol != None and label 
!= '': + smiles_data.append(smiles) + if task == 'classification': + labels.append(int(label)) + elif task == 'regression': + labels.append(float(label)) + else: + ValueError('task must be either regression or classification') + print('Number of data:', len(smiles_data)) + return smiles_data, labels + + +class MolTestDataset(Dataset): + def __init__(self, data_path, target='p_np', task='classification'): + super(Dataset, self).__init__() + self.smiles_data, self.labels = read_smiles(data_path, target, task) + self.task = task + + def __getitem__(self, index): + mol = Chem.MolFromSmiles(self.smiles_data[index]) + mol = Chem.AddHs(mol) + + N = mol.GetNumAtoms() + M = mol.GetNumBonds() + + type_idx = [] + chirality_idx = [] + atomic_number = [] + for atom in mol.GetAtoms(): + type_idx.append(ATOM_LIST.index(atom.GetAtomicNum())) + chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag())) + atomic_number.append(atom.GetAtomicNum()) + + x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1) + x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1) + x = torch.cat([x1, x2], dim=-1) + + row, col, edge_feat = [], [], [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + row += [start, end] + col += [end, start] + # edge_type += 2 * [MOL_BONDS[bond.GetBondType()]] + edge_feat.append([ + BOND_LIST.index(bond.GetBondType()), + BONDDIR_LIST.index(bond.GetBondDir()) + ]) + edge_feat.append([ + BOND_LIST.index(bond.GetBondType()), + BONDDIR_LIST.index(bond.GetBondDir()) + ]) + + edge_index = torch.tensor([row, col], dtype=torch.long) + edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.long) + if self.task == 'classification': + y = torch.tensor(self.labels[index], dtype=torch.long).view(1,-1) + elif self.task == 'regression': + y = torch.tensor(self.labels[index], dtype=torch.float).view(1,-1) + data = Data(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr) + + return data + + + def __len__(self): + return len(self.smiles_data) + + +class MolTestDatasetWrapper(object): + def __init__(self, batch_size, num_workers, valid_size, test_size, data_path, target, task): + super(object, self).__init__() + self.data_path = data_path + self.batch_size = batch_size + self.num_workers = num_workers + self.valid_size = valid_size + self.test_size = test_size + self.target = target + self.task = task + + def get_data_loaders(self): + train_dataset = MolTestDataset(data_path=self.data_path, target=self.target, task=self.task) + train_loader, valid_loader, test_loader = self.get_train_validation_data_loaders(train_dataset) + return train_loader, valid_loader, test_loader + + def get_train_validation_data_loaders(self, train_dataset): + train_idx, valid_idx, test_idx = scaffold_split(train_dataset, self.valid_size, self.test_size) + + # define samplers for obtaining training and validation batches + train_sampler = SubsetRandomSampler(train_idx) + valid_sampler = SubsetRandomSampler(valid_idx) + test_sampler = SubsetRandomSampler(test_idx) + + train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler, + num_workers=self.num_workers, drop_last=False, shuffle=False) + + valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler, + num_workers=self.num_workers, drop_last=False) + + test_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=test_sampler, + num_workers=self.num_workers, drop_last=False) + + return train_loader, valid_loader, test_loader + + +# import os +# import 
csv +# import math +# import time +# import random +# import networkx as nx +# import numpy as np +# from copy import deepcopy + +# import torch +# import torch.nn.functional as F +# import torchvision.transforms as transforms + +# from torch_scatter import scatter +# from torch_geometric.data import Data, Dataset, DataLoader + +# import rdkit +# from rdkit import Chem +# from rdkit.Chem.rdchem import HybridizationType as HT +# from rdkit.Chem.rdchem import BondType as BT +# from rdkit.Chem.rdchem import BondStereo +# from rdkit.Chem import AllChem +# from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles +# from rdkit import RDLogger +# RDLogger.DisableLog('rdApp.*') + +# # node feature lists +# ATOM_LIST = list(range(0,119)) +# CHIRALITY_LIST = [ +# Chem.rdchem.ChiralType.CHI_UNSPECIFIED, +# Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, +# Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, +# Chem.rdchem.ChiralType.CHI_OTHER +# ] +# CHARGE_LIST = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5] +# HYBRIDIZATION_LIST = [ +# HT.S, HT.SP, HT.SP2, HT.SP3, HT.SP3D, +# HT.SP3D2, HT.UNSPECIFIED +# ] +# NUM_H_LIST = [0, 1, 2, 3, 4, 5, 6, 7, 8] +# VALENCE_LIST = [0, 1, 2, 3, 4, 5, 6] +# DEGREE_LIST = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + +# # edge feature lists +# BOND_LIST = [0, BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC] +# BONDDIR_LIST = [ +# Chem.rdchem.BondDir.NONE, +# Chem.rdchem.BondDir.ENDUPRIGHT, +# Chem.rdchem.BondDir.ENDDOWNRIGHT +# ] +# STEREO_LIST = [ +# BondStereo.STEREONONE, BondStereo.STEREOANY, +# BondStereo.STEREOZ, BondStereo.STEREOE, +# BondStereo.STEREOCIS, BondStereo.STEREOTRANS +# ] + + +# def _generate_scaffold(smiles, include_chirality=False): +# mol = Chem.MolFromSmiles(smiles) +# scaffold = MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality) +# return scaffold + + +# def generate_scaffolds(smiles_data, log_every_n=1000): +# scaffolds = {} +# data_len = len(smiles_data) +# print(data_len) + +# print("About to generate scaffolds") +# for ind, smiles in enumerate(smiles_data): +# if ind % log_every_n == 0: +# print("Generating scaffold %d/%d" % (ind, data_len)) +# scaffold = _generate_scaffold(smiles) +# if scaffold not in scaffolds: +# scaffolds[scaffold] = [ind] +# else: +# scaffolds[scaffold].append(ind) + +# # Sort from largest to smallest scaffold sets +# scaffolds = {key: sorted(value) for key, value in scaffolds.items()} +# scaffold_sets = [ +# scaffold_set for (scaffold, scaffold_set) in sorted( +# scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True) +# ] +# # print(scaffold_sets) +# return scaffold_sets + + +# def scaffold_split(smiles_data, valid_size, test_size, seed=None, log_every_n=1000): +# train_size = 1.0 - valid_size - test_size +# scaffold_sets = generate_scaffolds(smiles_data) + +# train_cutoff = train_size * len(smiles_data) +# valid_cutoff = (train_size + valid_size) * len(smiles_data) +# train_inds: List[int] = [] +# valid_inds: List[int] = [] +# test_inds: List[int] = [] + +# print("About to sort in scaffold sets") +# for scaffold_set in scaffold_sets: +# if len(train_inds) + len(scaffold_set) > train_cutoff: +# if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff: +# test_inds += scaffold_set +# else: +# valid_inds += scaffold_set +# else: +# train_inds += scaffold_set +# return train_inds, valid_inds, test_inds + + +# def read_smiles(data_path, target, task): +# smiles_data, labels = [], [] +# with open(data_path) as csv_file: +# # csv_reader = csv.reader(csv_file, delimiter=',') +# csv_reader = 
csv.DictReader(csv_file, delimiter=',') +# for i, row in enumerate(csv_reader): +# if i != 0: +# # smiles = row[3] +# smiles = row['smiles'] +# label = row[target] +# mol = Chem.MolFromSmiles(smiles) +# if mol != None and label != '': +# smiles_data.append(smiles) +# if task == 'classification': +# labels.append(int(label)) +# elif task == 'regression': +# labels.append(float(label)) +# else: +# ValueError('task must be either regression or classification') +# print(len(smiles_data)) +# return smiles_data, labels + + +# class MolTestDataset(Dataset): +# def __init__(self, smiles_data, labels, task): +# super(Dataset, self).__init__() +# # self.smiles_data, self.labels = read_smiles(data_path, target, task) +# # self.task = task +# self.smiles_data = smiles_data +# self.labels = labels +# self.task = task + +# def __getitem__(self, index): +# mol = Chem.MolFromSmiles(self.smiles_data[index]) +# # mol = Chem.AddHs(mol) + +# N = mol.GetNumAtoms() +# M = mol.GetNumBonds() + +# atomic = [] +# degree, charge, hybrization = [], [], [] +# aromatic, num_hs, chirality = [], [], [] +# atoms = mol.GetAtoms() +# bonds = mol.GetBonds() + +# for atom in atoms: +# atomic.append(ATOM_LIST.index(atom.GetAtomicNum())) +# degree.append(DEGREE_LIST.index(atom.GetDegree())) +# charge.append(CHARGE_LIST.index(atom.GetFormalCharge())) +# hybrization.append(HYBRIDIZATION_LIST.index(atom.GetHybridization())) +# aromatic.append(1 if atom.GetIsAromatic() else 0) +# num_hs.append(NUM_H_LIST.index(atom.GetTotalNumHs())) +# chirality.append(CHIRALITY_LIST.index(atom.GetChiralTag())) + +# atomic = F.one_hot(torch.tensor(atomic, dtype=torch.long), num_classes=len(ATOM_LIST)) +# degree = F.one_hot(torch.tensor(degree, dtype=torch.long), num_classes=len(DEGREE_LIST)) +# charge = F.one_hot(torch.tensor(charge, dtype=torch.long), num_classes=len(CHARGE_LIST)) +# hybrization = F.one_hot(torch.tensor(hybrization, dtype=torch.long), num_classes=len(HYBRIDIZATION_LIST)) +# aromatic = F.one_hot(torch.tensor(aromatic, dtype=torch.long), num_classes=2) +# num_hs = F.one_hot(torch.tensor(num_hs, dtype=torch.long), num_classes=len(NUM_H_LIST)) +# chirality = F.one_hot(torch.tensor(chirality, dtype=torch.long), num_classes=len(CHIRALITY_LIST)) +# x = torch.cat([atomic, degree, charge, hybrization, aromatic, num_hs, chirality], dim=-1).type(torch.FloatTensor) +# node_feat_dim = x.shape[1] + +# # Only consider bond still exist after removing subgraph +# row_i, col_i = [], [] +# bond_i, bond_dir_i, stereo_i = [], [], [] +# for bond in mol.GetBonds(): +# start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() +# row_i += [start, end] +# col_i += [end, start] + +# b = BOND_LIST.index(bond.GetBondType()) +# bd = BONDDIR_LIST.index(bond.GetBondDir()) +# s = STEREO_LIST.index(bond.GetStereo()) +# bond_i.append(b) +# bond_i.append(b) +# bond_dir_i.append(bd) +# bond_dir_i.append(bd) +# stereo_i.append(s) +# stereo_i.append(s) + +# edge_index_i = torch.tensor([row_i, col_i], dtype=torch.long) + +# bond_i = F.one_hot(torch.tensor(bond_i, dtype=torch.long), num_classes=len(BOND_LIST)) +# bond_dir_i = F.one_hot(torch.tensor(bond_dir_i, dtype=torch.long), num_classes=len(BONDDIR_LIST)) +# stereo_i = F.one_hot(torch.tensor(stereo_i, dtype=torch.long), num_classes=len(STEREO_LIST)) +# edge_attr_i = torch.cat([bond_i, bond_dir_i, stereo_i], dim=-1).type(torch.FloatTensor) + +# if self.task == 'classification': +# y = torch.tensor(self.labels[index], dtype=torch.long).view(1,-1) +# elif self.task == 'regression': +# y = 
torch.tensor(self.labels[index], dtype=torch.float).view(1,-1) +# data = Data(x=x, y=y, edge_index=edge_index_i, edge_attr=edge_attr_i) + +# return data + +# def __len__(self): +# return len(self.smiles_data) + + +# class MolTestDatasetWrapper(object): +# def __init__(self, batch_size, num_workers, valid_size, test_size, data_path, target, task): +# super(object, self).__init__() +# self.data_path = data_path +# self.batch_size = batch_size +# self.num_workers = num_workers +# self.valid_size = valid_size +# self.test_size = test_size +# self.target = target +# self.task = task +# self.smiles_data, self.labels = read_smiles(data_path, target, task) +# self.smiles_data = np.asarray(self.smiles_data) +# self.labels = np.asarray(self.labels) + +# def get_data_loaders(self): +# train_idx, valid_idx, test_idx = scaffold_split(self.smiles_data, self.valid_size, self.test_size) + +# # define dataset +# train_set = MolTestDataset(self.smiles_data[train_idx], self.labels[train_idx], task=self.task) +# valid_set = MolTestDataset(self.smiles_data[valid_idx], self.labels[valid_idx], task=self.task) +# test_set = MolTestDataset(self.smiles_data[test_idx], self.labels[test_idx], task=self.task) + +# train_loader = DataLoader( +# train_set, batch_size=self.batch_size, +# num_workers=self.num_workers, drop_last=True, shuffle=True +# ) +# valid_loader = DataLoader( +# valid_set, batch_size=self.batch_size, +# num_workers=self.num_workers, drop_last=False +# ) +# test_loader = DataLoader( +# test_set, batch_size=self.batch_size, +# num_workers=self.num_workers, drop_last=False +# ) + +# return train_loader, valid_loader, test_loader diff --git a/finetune.py b/finetune.py new file mode 100644 index 0000000..dbf3b26 --- /dev/null +++ b/finetune.py @@ -0,0 +1,454 @@ +import os +import shutil +import yaml +import torch +import pandas as pd +import numpy as np +from datetime import datetime + +from torch import nn +import torch.nn.functional as F +from sklearn.metrics import mean_squared_error, mean_absolute_error +from sklearn.metrics import roc_auc_score + +from data_aug.dataset_test import MolTestDatasetWrapper +from models.ginet_finetune import GINet + + +def _save_config_file(log_dir, config): + if not os.path.exists(log_dir): + os.makedirs(log_dir) + with open(os.path.join(log_dir, 'config_finetune.yaml'), 'w') as config_file: + yaml.dump(config, config_file) + + +class Normalizer(object): + """Normalize a Tensor and restore it later. 
""" + + def __init__(self, tensor): + """tensor is taken as a sample to calculate the mean and std""" + self.mean = torch.mean(tensor) + self.std = torch.std(tensor) + + def norm(self, tensor): + return (tensor - self.mean) / self.std + + def denorm(self, normed_tensor): + return normed_tensor * self.std + self.mean + + def state_dict(self): + return {'mean': self.mean, + 'std': self.std} + + def load_state_dict(self, state_dict): + self.mean = state_dict['mean'] + self.std = state_dict['std'] + + +class FineTune(object): + def __init__(self, dataset, config): + self.config = config + self.device = self._get_device() + self.dataset = dataset + + current_time = datetime.now().strftime('%b%d_%H-%M-%S') + dir_name = config['fine_tune_from'].split('/')[0] + '-' + \ + config['fine_tune_from'].split('/')[-1] + '-' + config['task_name'] + subdir_name = current_time + '-' + config['dataset']['target'] + self.log_dir = os.path.join('experiments', dir_name, subdir_name) + + model_yaml_dir = os.path.join(config['fine_tune_from'], 'checkpoints') + for fn in os.listdir(model_yaml_dir): + if fn.endswith(".yaml"): + model_yaml_fn = fn + break + model_yaml = os.path.join(model_yaml_dir, model_yaml_fn) + model_config = yaml.load(open(model_yaml, "r"), Loader=yaml.FullLoader) + self.model_config = model_config['model'] + self.model_config['dropout'] = self.config['model']['dropout'] + self.model_config['pool'] = self.config['model']['pool'] + + if config['dataset']['task'] == 'classification': + self.criterion = nn.CrossEntropyLoss() + elif config['dataset']['task'] == 'regression': + if self.config["task_name"] in ['qm7', 'qm8']: + # self.criterion = nn.L1Loss() + self.criterion = nn.SmoothL1Loss() + else: + self.criterion = nn.MSELoss() + + # save config file + _save_config_file(self.log_dir, self.config) + + def _get_device(self): + if torch.cuda.is_available() and self.config['gpu'] != 'cpu': + device = self.config['gpu'] + torch.cuda.set_device(device) + else: + device = 'cpu' + print("Running on:", device) + + return device + + def _step(self, model, data): + pred = model(data) + + if self.config['dataset']['task'] == 'classification': + loss = self.criterion(pred, data.y.view(-1)) + elif self.config['dataset']['task'] == 'regression': + # loss = self.criterion(pred, data.y) + if self.normalizer: + loss = self.criterion(pred, self.normalizer.norm(data.y)) + else: + loss = self.criterion(pred, data.y) + + return loss + + def train(self): + train_loader, valid_loader, test_loader = self.dataset.get_data_loaders() + + self.normalizer = None + if self.config["task_name"] in ['qm7']: + labels = [] + for d in train_loader: + labels.append(d.y) + labels = torch.cat(labels) + self.normalizer = Normalizer(labels) + print(self.normalizer.mean, self.normalizer.std, labels.shape) + + n_batches = len(train_loader) + if n_batches < self.config['log_every_n_steps']: + self.config['log_every_n_steps'] = n_batches + + model = GINet(self.config['dataset']['task'], **self.model_config).to(self.device) + model = self._load_pre_trained_weights(model) + + layer_list = [] + for name, param in model.named_parameters(): + if 'output_layers' in name: + print(name) + layer_list.append(name) + + params = list(map(lambda x: x[1],list(filter(lambda kv: kv[0] in layer_list, model.named_parameters())))) + base_params = list(map(lambda x: x[1],list(filter(lambda kv: kv[0] not in layer_list, model.named_parameters())))) + + if self.config['optim']['type'] == 'SGD': + init_lr = self.config['optim']['base_lr'] * 
self.config['batch_size'] / 256 + optimizer = torch.optim.SGD( + [ {'params': params, 'lr': init_lr}, + {'params': base_params, 'lr': init_lr * self.config['optim']['base_ratio']} + ], + momentum=self.config['optim']['momentum'], + weight_decay=self.config['optim']['weight_decay'] + ) + elif self.config['optim']['type'] == 'Adam': + optimizer = torch.optim.Adam( + [ {'params': params, 'lr': self.config['optim']['lr']}, + {'params': base_params, 'lr': self.config['optim']['lr'] * self.config['optim']['base_ratio']} + ], + weight_decay=self.config['optim']['weight_decay'] + ) + else: + raise ValueError('Not defined optimizer type!') + + n_iter = 0 + valid_n_iter = 0 + best_valid_loss = np.inf + best_valid_rmse = np.inf + best_valid_mae = np.inf + best_valid_roc_auc = 0 + + for epoch_counter in range(self.config['epochs']): + for bn, data in enumerate(train_loader): + data = data.to(self.device) + loss = self._step(model, data) + + if n_iter % self.config['log_every_n_steps'] == 0: + print(epoch_counter, bn, loss.item()) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + n_iter += 1 + + # validate the model if requested + if epoch_counter % self.config['eval_every_n_epochs'] == 0: + if self.config['dataset']['task'] == 'classification': + valid_loss, valid_roc_auc = self._validate(model, valid_loader) + if valid_roc_auc > best_valid_roc_auc: + best_valid_roc_auc = valid_roc_auc + # save the model weights + torch.save(model.state_dict(), os.path.join(self.log_dir, 'model.pth')) + elif self.config['dataset']['task'] == 'regression': + valid_loss, valid_rmse, valid_mae = self._validate(model, valid_loader) + if self.config["task_name"] in ['qm7', 'qm8'] and valid_mae < best_valid_mae: + best_valid_mae = valid_mae + # save the model weights + torch.save(model.state_dict(), os.path.join(self.log_dir, 'model.pth')) + elif valid_rmse < best_valid_rmse: + best_valid_rmse = valid_rmse + # save the model weights + torch.save(model.state_dict(), os.path.join(self.log_dir, 'model.pth')) + + valid_n_iter += 1 + + return self._test(model, test_loader) + + def _load_pre_trained_weights(self, model): + try: + checkpoints_folder = os.path.join(self.config['fine_tune_from'], 'checkpoints') + ckp_path = os.path.join(checkpoints_folder, 'model.pth') + state_dict = torch.load(ckp_path, map_location=self.device) + model.load_my_state_dict(state_dict) + print("Loaded pre-trained model {} with success.".format(ckp_path)) + + except FileNotFoundError: + print("Pre-trained weights not found. 
Training from scratch.") + + return model + + def _validate(self, model, valid_loader): + # test steps + predictions = [] + labels = [] + with torch.no_grad(): + model.eval() + + valid_loss = 0.0 + num_data = 0 + for bn, data in enumerate(valid_loader): + data = data.to(self.device) + + pred = model(data) + loss = self._step(model, data) + + valid_loss += loss.item() * data.y.size(0) + num_data += data.y.size(0) + + if self.normalizer: + pred = self.normalizer.denorm(pred) + + if self.config['dataset']['task'] == 'classification': + pred = F.softmax(pred, dim=-1) + + if self.device == 'cpu': + predictions.extend(pred.detach().numpy()) + labels.extend(data.y.flatten().numpy()) + else: + predictions.extend(pred.cpu().detach().numpy()) + labels.extend(data.y.cpu().flatten().numpy()) + + valid_loss /= num_data + + model.train() + + if self.config['dataset']['task'] == 'regression': + predictions = np.array(predictions) + labels = np.array(labels) + rmse = mean_squared_error(labels, predictions, squared=False) + mae = mean_absolute_error(labels, predictions) + print('Validation loss:', valid_loss, 'RMSE:', rmse, 'MAE:', mae) + return valid_loss, rmse, mae + + elif self.config['dataset']['task'] == 'classification': + predictions = np.array(predictions) + labels = np.array(labels) + roc_auc = roc_auc_score(labels, predictions[:,1]) + print('Validation loss:', valid_loss, 'ROC AUC:', roc_auc) + return valid_loss, roc_auc + + def _test(self, model, test_loader): + model_path = os.path.join(self.log_dir, 'model.pth') + state_dict = torch.load(model_path, map_location=self.device) + model.load_state_dict(state_dict) + print("Loaded {} with success.".format(model_path)) + + # test steps + predictions = [] + labels = [] + with torch.no_grad(): + model.eval() + + test_loss = 0.0 + num_data = 0 + for bn, data in enumerate(test_loader): + data = data.to(self.device) + + pred = model(data) + loss = self._step(model, data) + + test_loss += loss.item() * data.y.size(0) + num_data += data.y.size(0) + + if self.normalizer: + pred = self.normalizer.denorm(pred) + + if self.config['dataset']['task'] == 'classification': + pred = F.softmax(pred, dim=-1) + + if self.device == 'cpu': + predictions.extend(pred.detach().numpy()) + labels.extend(data.y.flatten().numpy()) + else: + predictions.extend(pred.cpu().detach().numpy()) + labels.extend(data.y.cpu().flatten().numpy()) + + test_loss /= num_data + + model.train() + + if self.config['dataset']['task'] == 'regression': + predictions = np.array(predictions) + labels = np.array(labels) + rmse = mean_squared_error(labels, predictions, squared=False) + mae = mean_absolute_error(labels, predictions) + print('Test loss:', test_loss, 'RMSE:', rmse, 'MAE:', mae) + return test_loss, rmse, mae + + elif self.config['dataset']['task'] == 'classification': + predictions = np.array(predictions) + labels = np.array(labels) + roc_auc = roc_auc_score(labels, predictions[:,1]) + print('Test loss:', test_loss, 'ROC AUC:', roc_auc) + return test_loss, roc_auc + + +def run(config): + dataset = MolTestDatasetWrapper(config['batch_size'], **config['dataset']) + fine_tune = FineTune(dataset, config) + return fine_tune.train() + + +def get_config(): + config = yaml.load(open("config_finetune.yaml", "r"), Loader=yaml.FullLoader) + + if config['task_name'] == 'BBBP': + config['dataset']['task'] = 'classification' + config['dataset']['data_path'] = './data/bbbp/raw/BBBP.csv' + target_list = ["p_np"] + + elif config['task_name'] == 'Tox21': + config['dataset']['task'] = 'classification' + 
config['dataset']['data_path'] = './data/tox21/raw/tox21.csv' + target_list = [ + "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", + "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53" + ] + + elif config['task_name'] == 'ClinTox': + config['dataset']['task'] = 'classification' + config['dataset']['data_path'] = './data/clintox/raw/clintox.csv' + target_list = ['CT_TOX', 'FDA_APPROVED'] + + elif config['task_name'] == 'HIV': + config['dataset']['task'] = 'classification' + config['dataset']['data_path'] = './data/hiv/raw/HIV.csv' + target_list = ["HIV_active"] + + elif config['task_name'] == 'BACE': + config['dataset']['task'] = 'classification' + config['dataset']['data_path'] = './data/bace/raw/bace.csv' + target_list = ["Class"] + + elif config['task_name'] == 'SIDER': + config['dataset']['task'] = 'classification' + config['dataset']['data_path'] = './data/sider/raw/sider.csv' + target_list = [ + "Hepatobiliary disorders", "Metabolism and nutrition disorders", "Product issues", "Eye disorders", "Investigations", + "Musculoskeletal and connective tissue disorders", "Gastrointestinal disorders", "Social circumstances", + "Immune system disorders", "Reproductive system and breast disorders", + "Neoplasms benign, malignant and unspecified (incl cysts and polyps)", + "General disorders and administration site conditions", + "Endocrine disorders", "Surgical and medical procedures", "Vascular disorders", "Blood and lymphatic system disorders", + "Skin and subcutaneous tissue disorders", "Congenital, familial and genetic disorders", "Infections and infestations", + "Respiratory, thoracic and mediastinal disorders", "Psychiatric disorders", "Renal and urinary disorders", + "Pregnancy, puerperium and perinatal conditions", "Ear and labyrinth disorders", "Cardiac disorders", + "Nervous system disorders", "Injury, poisoning and procedural complications" + ] + + elif config['task_name'] == 'MUV': + config['dataset']['task'] = 'classification' + config['dataset']['data_path'] = './data/muv/raw/muv.csv' + target_list = [ + "MUV-466", "MUV-548", "MUV-600", "MUV-644", "MUV-652", "MUV-692", "MUV-712", "MUV-713", + "MUV-733", "MUV-737", "MUV-810", "MUV-832", "MUV-846", "MUV-852", "MUV-858", "MUV-859" + ] + + elif config['task_name'] == 'FreeSolv': + config['dataset']['task'] = 'regression' + config['dataset']['data_path'] = './data/freesolv/raw/SAMPL.csv' + target_list = ["expt"] + + elif config["task_name"] == 'ESOL': + config['dataset']['task'] = 'regression' + config['dataset']['data_path'] = './data/esol/raw/delaney-processed.csv' + target_list = ["measured log solubility in mols per litre"] + + elif config["task_name"] == 'Lipo': + config['dataset']['task'] = 'regression' + config['dataset']['data_path'] = './data/lipophilicity/raw/Lipophilicity.csv' + target_list = ["exp"] + + elif config["task_name"] == 'qm7': + config['dataset']['task'] = 'regression' + config['dataset']['data_path'] = './data/qm7/qm7.csv' + target_list = ["u0_atom"] + + elif config["task_name"] == 'qm8': + config['dataset']['task'] = 'regression' + config['dataset']['data_path'] = './data/qm8/qm8.csv' + target_list = [ + "E1-CC2", "E2-CC2", "f1-CC2", "f2-CC2", "E1-PBE0", "E2-PBE0", "f1-PBE0", "f2-PBE0", + "E1-CAM", "E2-CAM", "f1-CAM","f2-CAM" + ] + + else: + raise ValueError('Unspecified dataset!') + + print(config) + return config, target_list + + +if __name__ == '__main__': + config, target_list = get_config() + + os.makedirs('experiments', exist_ok=True) + dir_name = 
config['fine_tune_from'].split('/')[0] + '-' + \ + config['fine_tune_from'].split('/')[-1] + '-' + config['task_name'] + save_dir = os.path.join('experiments', dir_name) + + current_time = datetime.now().strftime('%b%d_%H-%M-%S') + + if config['dataset']['task'] == 'classification': + save_list = [] + for target in target_list: + config['dataset']['target'] = target + roc_list = [target] + test_loss, roc_auc = run(config) + roc_list.append(roc_auc) + save_list.append(roc_list) + + df = pd.DataFrame(save_list) + fn = '{}_{}_ROC.csv'.format(config["task_name"], current_time) + df.to_csv(os.path.join(save_dir, fn), index=False, header=['label', 'ROC-AUC']) + + elif config['dataset']['task'] == 'regression': + save_rmse_list, save_mae_list = [], [] + for target in target_list: + config['dataset']['target'] = target + rmse_list, mae_list = [target], [target] + test_loss, rmse, mae = run(config) + rmse_list.append(rmse) + mae_list.append(mae) + + save_rmse_list.append(rmse_list) + save_mae_list.append(mae_list) + + df = pd.DataFrame(save_rmse_list) + fn = '{}_{}_RMSE.csv'.format(config["task_name"], current_time) + df.to_csv(os.path.join(save_dir, fn), index=False, header=['label', 'RMSE']) + + df = pd.DataFrame(save_mae_list) + fn = '{}_{}_MAE.csv'.format(config["task_name"], current_time) + df.to_csv(os.path.join(save_dir, fn), index=False, header=['label', 'MAE']) \ No newline at end of file diff --git a/imolclr.py b/imolclr.py new file mode 100644 index 0000000..088a806 --- /dev/null +++ b/imolclr.py @@ -0,0 +1,265 @@ +import os +import shutil +import builtins +import yaml +import numpy as np +from copy import deepcopy +from datetime import datetime + +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from torch.optim.lr_scheduler import CosineAnnealingLR + +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.utils.data.distributed import DistributedSampler + +from utils.nt_xent import NTXentLoss +from utils.weighted_nt_xent import WeightedNTXentLoss +from models.ginet import GINet +from data_aug.dataset import read_smiles, collate_fn, MoleculeDataset + + +def _save_config_file(model_checkpoints_folder): + if not os.path.exists(model_checkpoints_folder): + os.makedirs(model_checkpoints_folder) + shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml')) + + +def get_dataset(batch_size, num_workers, valid_size, data_path): + data_path = data_path + batch_size = batch_size + num_workers = num_workers + valid_size = valid_size + + smiles_data = read_smiles(data_path) + + # obtain training indices that will be used for validation + num_train = len(smiles_data) + indices = list(range(num_train)) + + np.random.shuffle(indices) + + split = int(np.floor(valid_size * num_train)) + train_idx, valid_idx = indices[split:], indices[:split] + + train_smiles = [smiles_data[i] for i in train_idx] + valid_smiles = [smiles_data[i] for i in valid_idx] + + del smiles_data + + train_dataset = MoleculeDataset(train_smiles) + valid_dataset = MoleculeDataset(valid_smiles) + + return train_dataset, valid_dataset + + +def main(): + config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) + mp.spawn(main_worker, nprocs=config['world_size'], args=(config['world_size'], config)) + + +def main_worker(rank, world_size, config): + gpu = deepcopy(rank) + print("Use GPU: {} for 
training".format(gpu)) + + if rank == 0: + dir_name = datetime.now().strftime('%b%d_%H-%M-%S') + log_dir = os.path.join('runs', dir_name) + log_writer = SummaryWriter(log_dir=log_dir) + model_checkpoints_folder = os.path.join(log_writer.log_dir, 'checkpoints') + _save_config_file(model_checkpoints_folder) + else: + def print_pass(*args): + pass + builtins.print = print_pass + + dist.init_process_group( + backend=config['backend'], world_size=world_size, rank=rank) + torch.distributed.barrier() + + model = GINet(**config["model"]) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + torch.cuda.set_device(gpu) + + # clean up the cache in GPU + torch.cuda.empty_cache() + + model.cuda(gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + config['batch_size'] = int(config['batch_size'] / world_size) + config['dataset']['num_workers'] = int((config['dataset']['num_workers'] + world_size - 1) / world_size) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) + + nt_xent_criterion = NTXentLoss(gpu, **config['loss']) + weighted_nt_xent_criterion = WeightedNTXentLoss(gpu, **config['loss']) + + optimizer = torch.optim.Adam( + model.parameters(), lr=config['optim']['lr'], + weight_decay=config['optim']['weight_decay'], + ) + + start_epoch = 0 + if config['resume_from'] == 'None': + print("=> train from scratch, no resume checkpoint") + else: + if os.path.isfile(config['resume_from']): + print("=> loading checkpoint '{}'".format(config['resume_from'])) + # Map model to be loaded to specified single gpu. 
+ loc = 'cuda:{}'.format(gpu) + checkpoint = torch.load(config['resume_from'], map_location=loc) + start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(config['resume_from'], checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(config['resume_from'])) + + cudnn.benchmark = True + + train_dataset, valid_dataset = get_dataset(config['batch_size'], **config['dataset']) + + # define samplers for obtaining training batches + train_sampler = DistributedSampler(train_dataset, shuffle=True) + + train_loader = DataLoader( + train_dataset, batch_size=config['batch_size'], sampler=train_sampler, + num_workers=config['dataset']['num_workers'], drop_last=True, + pin_memory=True, collate_fn=collate_fn + ) + + # validation loader on a single GPU + valid_loader = DataLoader( + valid_dataset, batch_size=256, shuffle=True, + num_workers=config['dataset']['num_workers'], drop_last=True, + pin_memory=True, collate_fn=collate_fn + ) + + scheduler = CosineAnnealingLR(optimizer, + T_max=len(train_loader)-config['warmup']+1, eta_min=0, last_epoch=-1 + ) + + n_iter = 0 + valid_n_iter = 0 + best_valid_loss = np.inf + + for epoch_counter in range(start_epoch, config['epochs']): + + train_sampler.set_epoch(epoch_counter) + + for bn, (g1, g2, mols) in enumerate(train_loader): + g1 = g1.cuda(gpu, non_blocking=True) + g2 = g2.cuda(gpu, non_blocking=True) + + # get the representations and the projections + __, z1_global, z1_sub = model(g1) # [N,C] + __, z2_global, z2_sub = model(g2) # [N,C] + + # normalize projection feature vectors + z1_global = F.normalize(z1_global, dim=1) + z2_global = F.normalize(z2_global, dim=1) + loss_global = weighted_nt_xent_criterion(z1_global, z2_global, mols) + + # normalize projection feature vectors + z1_sub = F.normalize(z1_sub, dim=1) + z2_sub = F.normalize(z2_sub, dim=1) + loss_sub = nt_xent_criterion(z1_sub, z2_sub) + + loss = loss_global + config['loss']['lambda_2'] * loss_sub + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if rank == 0 and n_iter % config['log_every_n_steps'] == 0: + log_writer.add_scalar('train_loss', loss, global_step=n_iter) + log_writer.add_scalar('train_loss_global', loss_global, global_step=n_iter) + log_writer.add_scalar('train_loss_sub', loss_sub, global_step=n_iter) + log_writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter) + print(epoch_counter, bn, loss_global.item(), loss_sub.item(), loss.item()) + + n_iter += 1 + + if rank == 0: + valid_loss_global, valid_loss_sub = validate( + gpu, valid_loader, [weighted_nt_xent_criterion, nt_xent_criterion], model + ) + valid_loss = valid_loss_global + config['loss']['lambda_2'] * valid_loss_sub + print('Valid |', epoch_counter, valid_loss_global, valid_loss_sub, valid_loss) + if valid_loss < best_valid_loss: + # save the best model weights + best_valid_loss = valid_loss + torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth')) + + # save the model weights at each epoch + torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model_{}.pth'.format(str(epoch_counter)))) + + log_writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter) + log_writer.add_scalar('validation_loss_global', valid_loss_global, global_step=valid_n_iter) + log_writer.add_scalar('validation_loss_sub', valid_loss_sub, global_step=valid_n_iter) + + valid_n_iter += 1 + + # 
warmup for the first few epochs + if epoch_counter >= config['warmup'] - 1: + scheduler.step() + + +def validate(gpu, valid_loader, criterion, model): + model.eval() + + global_criterion, sub_criterion = criterion + # valid_sampler.set_epoch(0) + + valid_loss_global, valid_loss_sub = 0, 0 + counter = 0 + + for bn, (g1, g2, mols) in enumerate(valid_loader): + g1 = g1.cuda(gpu, non_blocking=True) + g2 = g2.cuda(gpu, non_blocking=True) + + # get the representations and the projections + __, z1_global, z1_sub = model(g1) # [N,C] + __, z2_global, z2_sub = model(g2) # [N,C] + + # normalize projection feature vectors + z1_global = F.normalize(z1_global, dim=1) + z2_global = F.normalize(z2_global, dim=1) + loss_global = global_criterion(z1_global, z2_global, mols) + + # normalize projection feature vectors + z1_sub = F.normalize(z1_sub, dim=1) + z2_sub = F.normalize(z2_sub, dim=1) + loss_sub = sub_criterion(z1_sub, z2_sub) + + valid_loss_global += loss_global.item() + valid_loss_sub += loss_sub.item() + + if counter % 1 == 0: + print('validation bn:', counter) + + counter += 1 + + valid_loss_global /= counter + valid_loss_sub /= counter + + model.train() + + return valid_loss_global, valid_loss_sub + + +if __name__ == '__main__': + main() + diff --git a/models/ginet.py b/models/ginet.py new file mode 100644 index 0000000..ad1de4c --- /dev/null +++ b/models/ginet.py @@ -0,0 +1,124 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn import Linear, LayerNorm, ReLU + +from torch_scatter import scatter +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import add_self_loops +from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool + +import rdkit +from rdkit import Chem +from rdkit.Chem.rdchem import HybridizationType as HT +from rdkit.Chem.rdchem import BondType as BT +from rdkit.Chem.rdchem import BondStereo +from rdkit.Chem import AllChem + + +num_atom_type = 119 # including the extra mask tokens +num_chirality_tag = 3 + +num_bond_type = 5 # including aromatic and self-loop edge +num_bond_direction = 3 + + +class GINEConv(MessagePassing): + def __init__(self, emb_dim, aggr="add"): + super(GINEConv, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(emb_dim, 2*emb_dim), + nn.ReLU(), + nn.Linear(2*emb_dim, emb_dim) + ) + + self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim) + self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim) + nn.init.xavier_uniform_(self.edge_embedding1.weight.data) + nn.init.xavier_uniform_(self.edge_embedding2.weight.data) + self.aggr = aggr + + def forward(self, x, edge_index, edge_attr): + # add self loops in the edge space + edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0] + + # add features corresponding to self-loop edges. 
+ self_loop_attr = torch.zeros(x.size(0), 2) + self_loop_attr[:,0] = 4 #bond type for self-loop edge + self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) + edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0) + + edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1]) + + return self.propagate(edge_index, x=x, edge_attr=edge_embeddings) + + def message(self, x_j, edge_attr): + return x_j + edge_attr + + def update(self, aggr_out): + return self.mlp(aggr_out) + + +class GINet(nn.Module): + def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, dropout=0, pool='mean'): + super(GINet, self).__init__() + self.num_layer = num_layer + self.emb_dim = emb_dim + self.feat_dim = feat_dim + self.dropout = dropout + + self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim) + self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim) + nn.init.xavier_uniform_(self.x_embedding1.weight.data) + nn.init.xavier_uniform_(self.x_embedding2.weight.data) + + # List of MLPs + self.gnns = nn.ModuleList() + for layer in range(num_layer): + self.gnns.append(GINEConv(emb_dim, aggr="add")) + + # List of batchnorms + self.batch_norms = nn.ModuleList() + for layer in range(num_layer): + self.batch_norms.append(nn.BatchNorm1d(emb_dim)) + + if pool == 'mean': + self.pool = global_mean_pool + elif pool == 'max': + self.pool = global_max_pool + elif pool == 'add': + self.pool = global_add_pool + + self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim) + + self.out_lin = nn.Sequential( + nn.Linear(self.feat_dim, self.feat_dim), + nn.ReLU(inplace=True), + nn.Linear(self.feat_dim, self.feat_dim//2) + ) + + def forward(self, data): + h = self.x_embedding1(data.x[:,0]) + self.x_embedding2(data.x[:,1]) + + for layer in range(self.num_layer): + h = self.gnns[layer](h, data.edge_index, data.edge_attr) + h = self.batch_norms[layer](h) + if layer == self.num_layer - 1: + h = F.dropout(h, self.dropout, training=self.training) + else: + h = F.dropout(F.relu(h), self.dropout, training=self.training) + + h_global = self.pool(h, data.batch) + h_global = self.feat_lin(h_global) + out_global = self.out_lin(h_global) + + h_sub = self.pool(h, data.motif_batch)[1:,:] + h_sub = self.feat_lin(h_sub) + out_sub = self.out_lin(h_sub) + + return h_global, out_global, out_sub + + +if __name__ == "__main__": + model = GINConv() + print(model) \ No newline at end of file diff --git a/models/ginet_finetune.py b/models/ginet_finetune.py new file mode 100644 index 0000000..79443ba --- /dev/null +++ b/models/ginet_finetune.py @@ -0,0 +1,145 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn import Linear, LayerNorm, ReLU + +from torch_scatter import scatter +from torch_geometric.nn import MessagePassing +from torch_geometric.utils import add_self_loops, degree, softmax +from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool + + +num_atom_type = 119 # including the extra mask tokens +num_chirality_tag = 4 + +num_bond_type = 5 # including aromatic and self-loop edge +num_bond_direction = 3 + + +class GINEConv(MessagePassing): + def __init__(self, embed_dim, aggr="add"): + super(GINEConv, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(embed_dim, 2*embed_dim), + nn.ReLU(), + nn.Linear(2*embed_dim, embed_dim) + ) + self.edge_embedding1 = nn.Embedding(num_bond_type, embed_dim) + self.edge_embedding2 = nn.Embedding(num_bond_direction, embed_dim) + + nn.init.xavier_uniform_(self.edge_embedding1.weight.data) + 
nn.init.xavier_uniform_(self.edge_embedding2.weight.data) + self.aggr = aggr + + def forward(self, x, edge_index, edge_attr): + # add self loops in the edge space + edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0] + + # add features corresponding to self-loop edges. + self_loop_attr = torch.zeros(x.size(0), 2) + self_loop_attr[:,0] = 4 #bond type for self-loop edge + self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) + edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0) + + edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1]) + + return self.propagate(edge_index, x=x, edge_attr=edge_embeddings) + + def message(self, x_j, edge_attr): + return x_j + edge_attr + + def update(self, aggr_out): + return self.mlp(aggr_out) + + +class GINet(nn.Module): + def __init__(self, task='classification', num_layer=5, embed_dim=256, dropout=0, pooling='mean'): + super(GINet, self).__init__() + self.task = task + self.num_layer = num_layer + self.embed_dim = embed_dim + self.dropout = dropout + + self.x_embedding1 = nn.Embedding(num_atom_type, embed_dim) + self.x_embedding2 = nn.Embedding(num_chirality_tag, embed_dim) + + nn.init.xavier_uniform_(self.x_embedding1.weight.data) + nn.init.xavier_uniform_(self.x_embedding2.weight.data) + + # List of MLPs + self.gnns = nn.ModuleList() + for layer in range(num_layer): + self.gnns.append(GINEConv(embed_dim, aggr="add")) + + # List of batchnorms + self.batch_norms = nn.ModuleList() + for layer in range(num_layer): + self.batch_norms.append(nn.BatchNorm1d(embed_dim)) + + if pooling == 'mean': + self.pool = global_mean_pool + elif pooling == 'max': + self.pool = global_max_pool + elif pooling == 'add': + self.pool = global_add_pool + else: + raise ValueError('Pooling operation not defined!') + + # projection head + self.proj_head = nn.Sequential( + nn.Linear(embed_dim, embed_dim, bias=False), + nn.BatchNorm1d(embed_dim), + nn.ReLU(inplace=True), # first layer + nn.Linear(embed_dim, embed_dim, bias=False), + nn.BatchNorm1d(embed_dim), + nn.ReLU(inplace=True), # second layer + nn.Linear(embed_dim, embed_dim, bias=False), + nn.BatchNorm1d(embed_dim) + ) + + # fine-tune prediction layers + if self.task == 'classification': + self.output_layers = nn.Sequential( + nn.Linear(embed_dim, embed_dim//2), + nn.Softplus(), + nn.Linear(embed_dim//2, 2) + ) + elif self.task == 'regression': + self.output_layers = nn.Sequential( + nn.Linear(embed_dim, embed_dim//2), + nn.Softplus(), + nn.Linear(embed_dim//2, 1) + ) + else: + raise ValueError('Undefined task type!') + + def forward(self, data): + h = self.x_embedding1(data.x[:,0]) + self.x_embedding2(data.x[:,1]) + + for layer in range(self.num_layer): + h = self.gnns[layer](h, data.edge_index, data.edge_attr) + h = self.batch_norms[layer](h) + if layer == self.num_layer: + h = F.dropout(h, self.dropout, training=self.training) + else: + h = F.dropout(F.relu(h), self.dropout, training=self.training) + + if self.pool == None: + h = h[data.pool_mask] + else: + h = self.pool(h, data.batch) + + h = self.proj_head(h) + + return self.output_layers(h) + + def load_my_state_dict(self, state_dict): + own_state = self.state_dict() + for name, param in state_dict.items(): + if name not in own_state: + print('NOT LOADED:', name) + continue + if isinstance(param, nn.parameter.Parameter): + # backwards compatibility for serialized parameters + param = param.data + own_state[name].copy_(param) diff --git a/pretrained/ckpt/config.yaml b/pretrained/ckpt/config.yaml new 
diff --git a/pretrained/ckpt/config.yaml b/pretrained/ckpt/config.yaml
new file mode 100644
index 0000000..338270b
--- /dev/null
+++ b/pretrained/ckpt/config.yaml
@@ -0,0 +1,30 @@
+batch_size: 512
+world_size: 3
+backend: nccl
+epochs: 50
+eval_every_n_epochs: 1
+resume_from: None
+log_every_n_steps: 200
+warmup: 10
+sub_coeff: 0.5
+
+optim:
+  type: Adam
+  lr: 0.0005
+  weight_decay: 0.00001
+
+model:
+  num_layer: 5
+  emb_dim: 300
+  feat_dim: 512
+  dropout: 0
+  pool: mean
+
+dataset:
+  num_workers: 24
+  valid_size: 0
+  data_path: data/pubchem-10m-clean.txt
+
+loss:
+  temperature: 0.1
+  use_cosine_similarity: True
diff --git a/pretrained/ckpt/model.pth b/pretrained/ckpt/model.pth
new file mode 100644
index 0000000..d858497
Binary files /dev/null and b/pretrained/ckpt/model.pth differ
diff --git a/utils/nt_xent.py b/utils/nt_xent.py
new file mode 100644
index 0000000..adfdc31
--- /dev/null
+++ b/utils/nt_xent.py
@@ -0,0 +1,134 @@
+import torch
+import numpy as np
+
+
+class NTXentLoss(torch.nn.Module):
+
+    def __init__(self, device, temperature, use_cosine_similarity, **kwargs):
+        super(NTXentLoss, self).__init__()
+        self.temperature = temperature
+        self.device = device
+        self.softmax = torch.nn.Softmax(dim=-1)
+        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
+
+    def _get_similarity_function(self, use_cosine_similarity):
+        if use_cosine_similarity:
+            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
+            return self._cosine_simililarity
+        else:
+            return self._dot_simililarity
+
+    def _get_correlated_mask(self, batch_size):
+        diag = np.eye(2 * batch_size)
+        l1 = np.eye((2 * batch_size), 2 * batch_size, k=-batch_size)
+        l2 = np.eye((2 * batch_size), 2 * batch_size, k=batch_size)
+        mask = torch.from_numpy((diag + l1 + l2))
+        mask = (1 - mask).type(torch.bool)
+        return mask.to(self.device)
+
+    @staticmethod
+    def _dot_simililarity(x, y):
+        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
+        # x shape: (N, 1, C)
+        # y shape: (1, C, 2N)
+        # v shape: (N, 2N)
+        return v
+
+    def _cosine_simililarity(self, x, y):
+        # x shape: (N, 1, C)
+        # y shape: (1, 2N, C)
+        # v shape: (N, 2N)
+        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
+        return v
+
+    def forward(self, zis, zjs):
+        assert zis.size(0) == zjs.size(0)
+        batch_size = zis.size(0)
+
+        representations = torch.cat([zjs, zis], dim=0)
+
+        similarity_matrix = self.similarity_function(representations, representations)
+
+        # filter out the scores from the positive samples
+        l_pos = torch.diag(similarity_matrix, batch_size)
+        r_pos = torch.diag(similarity_matrix, -batch_size)
+        positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)
+
+        mask_samples_from_same_repr = self._get_correlated_mask(batch_size).type(torch.bool)
+        negatives = similarity_matrix[mask_samples_from_same_repr].view(2 * batch_size, -1)
+
+        logits = torch.cat((positives, negatives), dim=1)
+        logits /= self.temperature
+
+        labels = torch.zeros(2 * batch_size).to(self.device).long()
+        loss = self.criterion(logits, labels)
+
+        return loss / (2 * batch_size)
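As a quick sanity check, `NTXentLoss` can be exercised on random tensors standing in for the projection-head outputs; the only contract is that `zis[k]` and `zjs[k]` are the two augmented views of the same molecule. This is an illustrative sketch, not code from the repository.

```
import torch
from utils.nt_xent import NTXentLoss  # path as in this diff

device = torch.device('cpu')
criterion = NTXentLoss(device, temperature=0.1, use_cosine_similarity=True)

zis = torch.randn(8, 512)   # view-1 projections, shape (batch_size, feat_dim)
zjs = torch.randn(8, 512)   # view-2 projections, aligned row-by-row with zis
loss = criterion(zis, zjs)  # scalar NT-Xent loss, averaged over 2 * batch_size anchors
print(loss.item())
```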
diff --git a/utils/weighted_nt_xent.py b/utils/weighted_nt_xent.py
new file mode 100644
index 0000000..0d44a28
--- /dev/null
+++ b/utils/weighted_nt_xent.py
@@ -0,0 +1,83 @@
+import torch
+from torch import nn
+import numpy as np
+from rdkit import DataStructs, Chem
+import torch.nn.functional as F
+from rdkit.Chem import AllChem
+
+
+class WeightedNTXentLoss(torch.nn.Module):
+    def __init__(self, device, temperature=0.1, use_cosine_similarity=True, lambda_1=0.5, **kwargs):
+        super(WeightedNTXentLoss, self).__init__()
+        self.temperature = temperature
+        self.device = device
+        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
+        self.lambda_1 = lambda_1
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
+
+    def _get_similarity_function(self, use_cosine_similarity):
+        if use_cosine_similarity:
+            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
+            return self._cosine_simililarity
+        else:
+            return self._dot_simililarity
+
+    def _get_correlated_mask(self, batch_size):
+        diag = np.eye(2 * batch_size)
+        l1 = np.eye((2 * batch_size), 2 * batch_size, k=-batch_size)
+        l2 = np.eye((2 * batch_size), 2 * batch_size, k=batch_size)
+        mask = torch.from_numpy((diag + l1 + l2))
+        mask = (1 - mask).type(torch.bool)
+        return mask.to(self.device)
+
+    @staticmethod
+    def _dot_simililarity(x, y):
+        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
+        # x shape: (N, 1, C)
+        # y shape: (1, C, 2N)
+        # v shape: (N, 2N)
+        return v
+
+    def _cosine_simililarity(self, x, y):
+        # x shape: (N, 1, C)
+        # y shape: (1, 2N, C)
+        # v shape: (N, 2N)
+        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
+        return v
+
+    def forward(self, x1, x2, mols):
+        assert x1.size(0) == x2.size(0)
+        batch_size = x1.size(0)
+
+        # pairwise Tanimoto similarity between Morgan fingerprints of the molecules
+        # in the batch, stored without the diagonal (batch_size x (batch_size - 1))
+        fp_score = np.zeros((batch_size, batch_size-1))
+        fps = [AllChem.GetMorganFingerprint(Chem.AddHs(x), 2, useFeatures=True) for x in mols]
+
+        for i in range(len(mols)):
+            for j in range(i+1, len(mols)):
+                fp_sim = DataStructs.TanimotoSimilarity(fps[i], fps[j])
+                fp_score[i, j-1] = fp_sim
+                fp_score[j, i] = fp_sim
+
+        # faulty negative mitigation: scale each negative pair by 1 - lambda_1 * similarity
+        fp_score = 1 - self.lambda_1 * torch.tensor(fp_score, dtype=torch.float).to(x1.device)
+        fp_score = fp_score.repeat(2, 2)
+
+        representations = torch.cat([x2, x1], dim=0)
+
+        similarity_matrix = self.similarity_function(representations, representations)
+
+        # filter out the scores from the positive samples
+        l_pos = torch.diag(similarity_matrix, batch_size)
+        r_pos = torch.diag(similarity_matrix, -batch_size)
+        positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)
+
+        mask_samples_from_same_repr = self._get_correlated_mask(batch_size).type(torch.bool)
+        negatives = similarity_matrix[mask_samples_from_same_repr].view(2 * batch_size, -1)
+        negatives *= fp_score
+
+        logits = torch.cat((positives, negatives), dim=1)
+        logits /= self.temperature
+
+        labels = torch.zeros(2 * batch_size).to(self.device).long()
+        loss = self.criterion(logits, labels)
+
+        return loss / (2 * batch_size)
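To make the contract of the weighted loss concrete, here is an illustrative call (not from the repository): alongside the two sets of projections, the RDKit `Mol` objects of the batch are passed so that Tanimoto similarities between Morgan fingerprints can shrink each negative by the factor `1 - lambda_1 * similarity`, i.e. the faulty negative mitigation controlled by `lambda_1` in `config.yaml`.

```
import torch
from rdkit import Chem
from utils.weighted_nt_xent import WeightedNTXentLoss  # path as in this diff

device = torch.device('cpu')
criterion = WeightedNTXentLoss(device, temperature=0.1,
                               use_cosine_similarity=True, lambda_1=0.5)

smiles = ['CCO', 'c1ccccc1', 'CC(=O)O', 'CCNC']
mols = [Chem.MolFromSmiles(s) for s in smiles]   # RDKit Mol objects of the batch

x1 = torch.randn(4, 512)         # projections of augmented view 1
x2 = torch.randn(4, 512)         # projections of augmented view 2, same molecule order
loss = criterion(x1, x2, mols)   # NT-Xent with similarity-weighted negatives
print(loss.item())
```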