Showing 14 changed files with 1,970 additions and 0 deletions.

@@ -0,0 +1,3 @@
# cache
*.vscode
*__pycache__

@@ -0,0 +1,50 @@
## Improving Molecular Contrastive Learning via Faulty Negative Mitigation and Decomposed Fragment Contrast

This is the official implementation of <strong><em>iMolCLR</em></strong>: ["Improving Molecular Contrastive Learning via Faulty Negative Mitigation and Decomposed Fragment Contrast"](https://arxiv.org/abs/2202.09346).

## Getting Started

### Installation

Set up the conda environment and clone the GitHub repo:
```
# create a new environment
$ conda create --name imolclr python=3.7
$ conda activate imolclr

# install requirements
$ pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install torch-geometric==1.6.3 torch-sparse==0.6.9 torch-scatter==2.0.6 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
$ pip install PyYAML
$ conda install -c conda-forge rdkit=2021.09.1
$ conda install -c conda-forge tensorboard

# clone the source code of iMolCLR
$ git clone https://github.com/yuyangw/iMolCLR.git
$ cd iMolCLR
```

### Dataset

You can download the pre-training data and benchmarks used in the paper [here](https://drive.google.com/file/d/1aDtN6Qqddwwn2x612kWz9g0xQcuAtzDE/view?usp=sharing) and extract the zip file into the `./data` folder. The pre-training data is in `pubchem-10m-clean.txt`, and each fine-tuning benchmark is saved in a folder named after the benchmark. The benchmarks are also available from [MoleculeNet](https://moleculenet.org/).
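
As a quick sanity check after extraction, you can peek at the pre-training data from Python (a minimal sketch; it assumes the one-SMILES-per-line layout that `read_smiles` in `dataset.py` expects, with the SMILES string in the last comma-separated column):

```
from rdkit import Chem

# count the pre-training SMILES and confirm the first one parses
with open('data/pubchem-10m-clean.txt') as f:
    smiles = [line.strip().split(',')[-1] for line in f if line.strip()]

print(f'{len(smiles)} SMILES loaded')
print('first entry parses:', Chem.MolFromSmiles(smiles[0]) is not None)
```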

### Pre-training

To pre-train iMolCLR, where the configurations are defined in `config.yaml`:
```
$ python imolclr.py
```

To monitor training via TensorBoard, run `tensorboard --logdir ckpt/{PATH}` and open http://127.0.0.1:6006/ in a browser.

### Fine-tuning

To fine-tune the pre-trained iMolCLR model on downstream molecular benchmarks, where the configurations are defined in `config_finetune.yaml`:
```
$ python finetune.py
```

### Pre-trained model

We also provide a pre-trained model in `ckpt/pretrained`. You can load it by changing the `fine_tune_from` variable in `config_finetune.yaml` to `pretrained`; a rough sketch of what that loading looks like follows.
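
For illustration only, loading such a checkpoint in PyTorch looks roughly like this (a hedged sketch: the file name `model.pth` and the `strict=False` handling are assumptions, not the repository's confirmed layout; `finetune.py` performs the actual loading):

```
import torch

# hypothetical checkpoint path, for illustration only
state_dict = torch.load('ckpt/pretrained/model.pth', map_location='cpu')

# model = ...                                      # GNN encoder built with the pre-training 'model' config
# model.load_state_dict(state_dict, strict=False)  # strict=False tolerates a missing prediction head
```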

@@ -0,0 +1,31 @@
batch_size: 512                           # batch size
world_size: 3                             # total number of GPUs
backend: nccl                             # distributed backend of PyTorch
epochs: 50                                # total number of epochs
warmup: 10                                # warm-up epochs

eval_every_n_epochs: 1                    # validation frequency
resume_from: None                         # resume training from a checkpoint
log_every_n_steps: 200                    # training log print frequency

optim:
  lr: 0.0005                              # initial learning rate for the Adam optimizer
  weight_decay: 0.00001                   # weight decay for the Adam optimizer

model:
  num_layer: 5                            # number of graph conv layers
  emb_dim: 300                            # embedding dimension in graph conv layers
  feat_dim: 512                           # output feature dimension
  dropout: 0                              # dropout ratio
  pool: mean                              # readout pooling (i.e., mean/max/add)

dataset:
  num_workers: 12                         # number of dataloader workers
  valid_size: 0.05                        # ratio of validation data
  data_path: data/pubchem-10m-clean.txt   # path of the pre-training data

loss:
  temperature: 0.1                        # temperature of the (weighted) NT-Xent loss
  use_cosine_similarity: True             # whether to use cosine similarity in the (weighted) NT-Xent loss (True/False)
  lambda_1: 0.5                           # $\lambda_1$, controls faulty negative mitigation
  lambda_2: 0.5                           # $\lambda_2$, controls fragment contrast
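
For reference, the plain NT-Xent objective that `temperature` and `use_cosine_similarity` parameterize looks roughly like the sketch below. This is not the repository's implementation: iMolCLR's weighted variant additionally down-weights faulty negatives (controlled by `lambda_1`) and adds a decomposed fragment-contrast term (controlled by `lambda_2`), as described in the paper.

```
import torch
import torch.nn.functional as F

def nt_xent(z_i, z_j, temperature=0.1):
    # z_i, z_j: (N, D) embeddings of the two augmented views
    n = z_i.size(0)
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # unit vectors -> dot product = cosine similarity
    sim = z @ z.t() / temperature                         # (2N, 2N) similarity logits
    sim.fill_diagonal_(float('-inf'))                     # exclude self-similarity

    # the positive for anchor i is its augmented view at index i +/- N
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)
```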

@@ -0,0 +1,23 @@
batch_size: 32                # batch size
epochs: 100                   # total number of epochs
eval_every_n_epochs: 1        # validation frequency
fine_tune_from: pretrained    # directory of the pre-trained model
log_every_n_steps: 50         # training log print frequency
gpu: cuda:0                   # training GPU
task_name: BBBP               # name of the fine-tuning benchmark, including
                              # classifications: BBBP/BACE/ClinTox/Tox21/HIV/SIDER/MUV
                              # regressions: FreeSolv/ESOL/Lipo/qm7/qm8

optim:
  lr: 0.0005                  # initial learning rate for the prediction head
  weight_decay: 0.000001      # weight decay of Adam
  base_ratio: 0.4             # ratio of the learning rate for the base GNN encoder

model:                        # other 'model' variables are inherited from the pre-trained model's config
  drop_ratio: 0.3             # dropout ratio
  pool: mean                  # readout pooling (i.e., mean/max/add)

dataset:
  num_workers: 4              # number of dataloader workers
  valid_size: 0.1             # ratio of validation data
  test_size: 0.1              # ratio of test data
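
The `base_ratio` option implies discriminative learning rates: the GNN encoder trains at `base_ratio * lr` while the prediction head uses the full `lr`. A hedged sketch of the corresponding PyTorch parameter groups (the attribute names `gnn` and `pred_head` are hypothetical, for illustration only):

```
import torch

def build_optimizer(model, lr=0.0005, base_ratio=0.4, weight_decay=1e-6):
    # 'gnn' (encoder) and 'pred_head' (head) are assumed attribute names
    return torch.optim.Adam(
        [
            {'params': model.gnn.parameters(), 'lr': lr * base_ratio},  # encoder at base_ratio * lr
            {'params': model.pred_head.parameters(), 'lr': lr},         # head at full lr
        ],
        weight_decay=weight_decay,
    )
```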

@@ -0,0 +1,174 @@
import os
import csv
import math
import time
import random
import networkx as nx
import numpy as np
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms

from torch_scatter import scatter
from torch_geometric.data import Data, Batch

import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem


ATOM_LIST = list(range(1, 119))
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    BT.SINGLE,
    BT.DOUBLE,
    BT.TRIPLE,
    BT.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]


def read_smiles(data_path):
    # read SMILES strings from a CSV-like file, one record per line,
    # with the SMILES in the last column
    smiles_data = []
    with open(data_path) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        for i, row in enumerate(csv_reader):
            smiles = row[-1]
            smiles_data.append(smiles)
            # mol = Chem.MolFromSmiles(smiles)
            # if mol != None:
            #     smiles_data.append(smiles)
    return smiles_data


def removeSubgraph(Graph, center, percent=0.2):
    # remove a connected subgraph covering `percent` of the nodes,
    # grown breadth-first from `center`
    assert percent <= 1
    G = Graph.copy()
    num = int(np.floor(len(G.nodes) * percent))
    removed = []
    temp = [center]

    while len(removed) < num:
        neighbors = []
        for n in temp:
            neighbors.extend([i for i in G.neighbors(n) if i not in temp])
        for n in temp:
            if len(removed) < num:
                G.remove_node(n)
                removed.append(n)
            else:
                break
        temp = list(set(neighbors))
        if not temp:  # the component around `center` is exhausted; avoid an infinite loop
            break
    return G, removed
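
# Illustrative example: on a 5-node path 0-1-2-3-4, removeSubgraph(G, center=2,
# percent=0.4) removes floor(5 * 0.4) = 2 nodes, starting at the center and
# growing breadth-first, e.g. removed == [2, 1], and returns the pruned graph.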


class MoleculeDataset(Dataset):
    def __init__(self, smiles_data):
        super().__init__()
        self.smiles_data = smiles_data

    def __getitem__(self, index):
        mol = Chem.MolFromSmiles(self.smiles_data[index])
        mol = Chem.AddHs(mol)

        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()

        type_idx = []
        chirality_idx = []
        atomic_number = []
        atoms = mol.GetAtoms()
        bonds = mol.GetBonds()
        # Sample 2 different centers to start the removals for views i and j
        start_i, start_j = random.sample(list(range(N)), 2)

        # Construct the original molecular graph from edges (bonds)
        edges = []
        for bond in bonds:
            edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
        molGraph = nx.Graph(edges)

        # Get the graphs for i and j after removing subgraphs
        percent_i, percent_j = 0.25, 0.25
        G_i, removed_i = removeSubgraph(molGraph, start_i, percent_i)
        G_j, removed_j = removeSubgraph(molGraph, start_j, percent_j)

        for atom in atoms:
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())

        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1, 1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1, 1)
        x = torch.cat([x1, x2], dim=-1)
        # x shape: (N, 2), columns = [atom type, chirality]

        # Mask the atoms in the removed lists
        x_i = deepcopy(x)
        for atom_idx in removed_i:
            # change atom type to the mask index 118, and chirality to 0
            x_i[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])
        x_j = deepcopy(x)
        for atom_idx in removed_j:
            # change atom type to the mask index 118, and chirality to 0
            x_j[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])

        # Only keep bonds that still exist after removing the subgraphs;
        # has_edge() is order-insensitive, unlike membership in list(G.edges)
        row_i, col_i, row_j, col_j = [], [], [], []
        edge_feat_i, edge_feat_j = [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feature = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ]
            if G_i.has_edge(start, end):
                # add both directions so the graph stays undirected
                row_i += [start, end]
                col_i += [end, start]
                edge_feat_i.append(feature)
                edge_feat_i.append(feature)
            if G_j.has_edge(start, end):
                row_j += [start, end]
                col_j += [end, start]
                edge_feat_j.append(feature)
                edge_feat_j.append(feature)

        edge_index_i = torch.tensor([row_i, col_i], dtype=torch.long)
        edge_attr_i = torch.tensor(np.array(edge_feat_i), dtype=torch.long)
        edge_index_j = torch.tensor([row_j, col_j], dtype=torch.long)
        edge_attr_j = torch.tensor(np.array(edge_feat_j), dtype=torch.long)

        data_i = Data(x=x_i, edge_index=edge_index_i, edge_attr=edge_attr_i)
        data_j = Data(x=x_j, edge_index=edge_index_j, edge_attr=edge_attr_j)

        return data_i, data_j, mol

    def __len__(self):
        return len(self.smiles_data)


def collate_fn(batch):
    # batch is a list of (data_i, data_j, mol) triples from MoleculeDataset
    gis, gjs, mols = zip(*batch)

    # merge each set of graph views into a single torch_geometric batch
    gis = Batch.from_data_list(gis)
    gjs = Batch.from_data_list(gjs)

    return gis, gjs, mols
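
For context, here is how `MoleculeDataset` and `collate_fn` plug into a standard PyTorch `DataLoader` (a minimal sketch, not the repository's training loop; the batch size and worker count echo the defaults in `config.yaml`):

```
from torch.utils.data import DataLoader

smiles_data = read_smiles('data/pubchem-10m-clean.txt')
dataset = MoleculeDataset(smiles_data)
loader = DataLoader(dataset, batch_size=512, num_workers=12,
                    shuffle=True, drop_last=True, collate_fn=collate_fn)

for g_i, g_j, mols in loader:
    # g_i and g_j are torch_geometric Batch objects holding the two augmented views
    break
```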