
Moltransf #85

Draft: wants to merge 22 commits into base: master

Changes from all commits (22 commits)

388734d
Extend JTNNVAE to support WAE, users have the choice of VAE or WAE in …
samadejacobs Apr 2, 2019
464001a
Extend JTNNVAE to support WAE, users have the choice of VAE or WAE in …
samadejacobs Apr 2, 2019
4f4d31c
Extend CharRNN to load LBANN trained models
samadejacobs Oct 16, 2019
6b881f7
Fix typos
samadejacobs Oct 16, 2019
9e7bdc6
fix typos, import numpy etc
samadejacobs Oct 16, 2019
b7c1a68
add vocab and config for char RNN
samadejacobs Oct 17, 2019
79db018
More clean up and strict check
samadejacobs Oct 18, 2019
161e892
adds preprocessing script to featurize arbitrary smiles datasets. mak…
wderekjones Nov 22, 2019
1527153
adding support for TF 2.0 to moses/metrics/utils_fcd.py
wderekjones Nov 25, 2019
42bc7a7
adding utilities to preprocess arbitrary datasets, both python and sl…
wderekjones Feb 29, 2020
9a4f560
adding scripts to prepare datasets
wderekjones Mar 21, 2020
28932ce
addressing issues related to structure of example slurm data processi…
wderekjones Mar 25, 2020
ac132b0
Merge pull request #1 from samadejacobs/integrate_lbann_moses
samadejacobs Mar 30, 2020
c59e4f4
Update model in response to changes in LBANN naming of dumping weights
samadejacobs Jun 26, 2020
6e42081
Make epoch count optional especially if loading 'final' model, one le…
samadejacobs Jul 14, 2020
9fbfefb
Support for loading pretrained LBANN VAE model
samadejacobs Aug 20, 2020
129d9e3
Code clean up
samadejacobs Aug 20, 2020
8a01982
Add small version of Zinc train and test data
samadejacobs Aug 20, 2020
0035ff7
Add ckpt for zinc10K
samadejacobs Aug 20, 2020
1fda9a0
Support renaming of weights in LBANN, support evaluation of both reco…
samadejacobs Sep 24, 2020
1871098
Draft implementation of molecular transformer (GPT2) model
samadejacobs Sep 24, 2020
648766d
Add gpt to train and sample drivers
samadejacobs Sep 27, 2020
10,001 changes: 10,001 additions & 0 deletions data/zinc/test.csv


10,001 changes: 10,001 additions & 0 deletions data/zinc/test_scaffolds.csv


Binary file added data/zinc/test_scaffolds_stats.npz
Binary file added data/zinc/test_stats.npz
10,000 changes: 10,000 additions & 0 deletions data/zinc/train10k.csv


2 changes: 1 addition & 1 deletion moses/char_rnn/config.py
@@ -7,7 +7,7 @@ def get_parser(parser=None):

# Model
    model_arg = parser.add_argument_group('Model')
-   model_arg.add_argument("--num_layers", type=int, default=3,
+   model_arg.add_argument("--num_layers", type=int, default=1,
                           help="Number of LSTM layers")
    model_arg.add_argument("--hidden", type=int, default=768,
                           help="Hidden size")
31 changes: 29 additions & 2 deletions moses/char_rnn/model.py
@@ -2,7 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

import numpy as np
import glob

class CharRNN(nn.Module):

@@ -16,7 +17,7 @@ def __init__(self, vocabulary, config):
        self.vocab_size = self.input_size = self.output_size = len(vocabulary)

        self.embedding_layer = nn.Embedding(self.vocab_size, self.vocab_size, padding_idx=vocabulary.pad)
-       self.lstm_layer = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, dropout=self.dropout,
+       self.lstm_layer = nn.GRU(self.input_size, self.hidden_size, self.num_layers, dropout=self.dropout,
                                  batch_first=True)
        self.linear_layer = nn.Linear(self.hidden_size, self.output_size)

@@ -46,6 +47,32 @@ def tensor2string(self, tensor):

return string

    def load_lbann_weights(self, weights_dir, epoch_count=None):
        # An unspecified epoch matches any epoch in the dumped file names
        if epoch_count is None:
            epoch_count = '*'

        with torch.no_grad():
            # Load embedding weights
            emb_weights = np.loadtxt(glob.glob(weights_dir + "*.epoch." + str(epoch_count) + "-emb_matrix-Weights.txt")[0])
            self.embedding_layer.weight.data.copy_(torch.from_numpy(np.transpose(emb_weights)))

            # Load GRU weights/biases, one set per layer
            param_idx = ['_ih_matrix', '_hh_matrix', '_ih_bias', '_hh_bias']
            for l in range(self.num_layers):
                for idx, val in enumerate(param_idx):
                    param_tensor = np.loadtxt(glob.glob(weights_dir + "*.epoch." + str(epoch_count) + "*-gru" + str(l + 1) + val + "-Weights.txt")[0])
                    self.lstm_layer.all_weights[l][idx].copy_(torch.from_numpy(param_tensor))

            # Load linear (output) layer weights/biases
            linear_layer_weights = np.loadtxt(glob.glob(weights_dir + "*.epoch." + str(epoch_count) + "*-fcmodule" + str(2 * self.num_layers + 1) + "_matrix-Weights.txt")[0])
            self.linear_layer.weight.data.copy_(torch.from_numpy(linear_layer_weights))
            linear_layer_bias = np.loadtxt(glob.glob(weights_dir + "*.epoch." + str(epoch_count) + "*-fcmodule" + str(2 * self.num_layers + 1) + "_bias-Weights.txt")[0])
            self.linear_layer.bias.data.copy_(torch.from_numpy(linear_layer_bias))

        print("DONE loading LBANN weights")
        return

    def sample(self, n_batch, max_length=100):
        with torch.no_grad():
            starts = [torch.tensor([self.vocabulary.bos], dtype=torch.long, device=self.device)
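For context, a minimal sketch of how the loader above might be driven. The dump directory is a placeholder, the vocabulary is built from a throwaway SMILES list via the MOSES CharVocab helper, and the LBANN weight files are assumed to follow the *-emb_matrix-Weights.txt naming that the glob patterns expect:

from moses.char_rnn.config import get_parser
from moses.char_rnn.model import CharRNN
from moses.utils import CharVocab

# Throwaway SMILES just to build a vocabulary for the sketch
smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
vocab = CharVocab.from_data(smiles)
config = get_parser().parse_args([])   # defaults from the config above: num_layers=1, hidden=768

model = CharRNN(vocab, config)
# Placeholder path to a directory of LBANN weight dumps
model.load_lbann_weights('/path/to/lbann/dumps/', epoch_count=30)

model.eval()
print(model.sample(n_batch=4, max_length=100))
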
4 changes: 4 additions & 0 deletions moses/junction_tree/config.py
@@ -17,6 +17,10 @@ def get_parser(parser=None):
                           help='Epoch to init KL weight (start from 0)')
    model_arg.add_argument('--kl_w', type=float, default=5e-3,
                           help='KL weight value')
    model_arg.add_argument('--latent_model', type=str, default='vae', choices=['vae', 'wae'],
                           help='Latent space generative model (vae or wae)')
    model_arg.add_argument('--discriminator_layers', nargs='+', type=int, default=[640, 256, 1],
                           help='Numbers of features for linear layers in WAE discriminator')

    # Train
    train_arg = parser.add_argument_group('Training')
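A small sketch of how the two new options parse; the command-line values are illustrative, and this assumes the remaining junction-tree options all carry defaults:

from moses.junction_tree.config import get_parser

config = get_parser().parse_args(
    ['--latent_model', 'wae', '--discriminator_layers', '640', '256', '1'])
print(config.latent_model)          # 'wae' (default is 'vae')
print(config.discriminator_layers)  # [640, 256, 1]
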
64 changes: 53 additions & 11 deletions moses/junction_tree/jtnn/jtnn_vae.py
@@ -12,7 +12,6 @@
from .mol_tree import MolTree
from .mpn import MPN, mol2graph


def set_batch_node_id(mol_batch, vocab):
tot = 0
for mol_tree in mol_batch:
@@ -21,6 +20,28 @@ def set_batch_node_id(mol_batch, vocab):
node.wid = vocab.get_index(node.smiles)
tot += 1

class Discriminator(nn.Module):
    def __init__(self, input_size, layers):
        super(Discriminator, self).__init__()

        in_features = [input_size] + layers
        out_features = layers + [1]

        self.layers_seq = nn.Sequential()
        for k, (i, o) in enumerate(zip(in_features, out_features)):
            self.layers_seq.add_module('linear_{}'.format(k), nn.Linear(i, o))
            if k != len(layers):
                self.layers_seq.add_module('activation_{}'.format(k), nn.ELU(inplace=True))

    def forward(self, x):
        return self.layers_seq(x)


def compute_wae_loss(d_prior, d_sample, device):
    # Discriminator loss: prior samples labelled 1, encoder samples labelled 0
    D_loss_prior = nn.BCEWithLogitsLoss()(d_prior, torch.ones(d_prior.size()[0], 1, device=device))
    D_loss_sample = nn.BCEWithLogitsLoss()(d_sample, torch.zeros(d_sample.size()[0], 1, device=device))
    D_loss = D_loss_prior + D_loss_sample
    # Adversarial (generator) loss: encoder samples should fool the discriminator
    G_adv = nn.BCEWithLogitsLoss()(d_sample, torch.ones(d_sample.size()[0], 1, device=device))
    return D_loss, G_adv

class JTNNVAE(nn.Module):

@@ -29,6 +50,7 @@ def __init__(self, vocab, config):
        self.vocab = vocab
        self.hidden_size = config.hidden
        self.latent_size = config.latent
        self.latent_model = config.latent_model  # 'vae' or 'wae'
        self.depth = config.depth

        self.embedding = nn.Embedding(vocab.size(), self.hidden_size)
@@ -42,6 +64,9 @@ def __init__(self, vocab, config):
        self.T_var = nn.Linear(self.hidden_size, self.latent_size // 2)
        self.G_mean = nn.Linear(self.hidden_size, self.latent_size // 2)
        self.G_var = nn.Linear(self.hidden_size, self.latent_size // 2)
        # TODO: fix to be cat([T_mean, G_mean])
        disc_inp_sz = 2 * self.hidden_size + config.latent
        self.discriminator = Discriminator(disc_inp_sz, config.discriminator_layers)

self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
@@ -67,22 +92,35 @@ def encode(self, mol_batch):
smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
mol_vec = self.mpn(mol2graph(smiles_batch, device=device))
return tree_mess, tree_vec, mol_vec


    def discriminator_forward(self, *args, **kwargs):
        return self.discriminator(*args, **kwargs)

    def forward(self, mol_batch, beta=0):
        device = self.device

        # Initialize all losses to 0
        metric = loss = kl_loss = d_loss = adv_loss = torch.tensor(0.0, device=device)
        batch_size = len(mol_batch)

        tree_mess, tree_vec, mol_vec = self.encode(mol_batch)

        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        z_mean = torch.cat([tree_mean, mol_mean], dim=1)

        if self.latent_model == "vae":
            z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
            kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

        if self.latent_model == "wae":
            # WAE latent space: discriminate encoder samples from prior samples
            d_sample = torch.cat([tree_vec, mol_vec, z_mean], dim=1)
            disc_out1 = self.discriminator_forward(d_sample)
            epsilon = torch.randn(batch_size, z_mean.size()[1], device=device)
            d_prior = torch.cat([tree_vec, mol_vec, epsilon], dim=1)
            disc_out2 = self.discriminator_forward(d_prior)
            d_loss, adv_loss = compute_wae_loss(disc_out2, disc_out1, device=device)

        epsilon = torch.randn(batch_size, self.latent_size // 2, device=device)
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
@@ -92,10 +130,14 @@ def forward(self, mol_batch, beta=0):
        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

        if self.latent_model == "vae":
            loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
            metric = loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc
        if self.latent_model == "wae":
            loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * adv_loss + d_loss
            metric = loss, adv_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc

        return metric

def assm(self, mol_batch, mol_vec, tree_mess):
device = self.device
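To make the adversarial bookkeeping in forward() concrete, here is a self-contained sketch of the WAE regularizer on random tensors. The hidden and latent sizes are placeholders, and the import path assumes the Discriminator and compute_wae_loss land in jtnn_vae.py as shown above:

import torch

from moses.junction_tree.jtnn.jtnn_vae import Discriminator, compute_wae_loss

device = torch.device('cpu')
batch_size, hidden_size, latent_size = 8, 450, 56   # placeholder sizes

tree_vec = torch.randn(batch_size, hidden_size)
mol_vec = torch.randn(batch_size, hidden_size)
z_mean = torch.randn(batch_size, latent_size)

disc = Discriminator(2 * hidden_size + latent_size, [640, 256, 1])

# Encoder-side sample vs. a draw from the standard-normal prior
d_sample = torch.cat([tree_vec, mol_vec, z_mean], dim=1)
d_prior = torch.cat([tree_vec, mol_vec, torch.randn_like(z_mean)], dim=1)

d_loss, adv_loss = compute_wae_loss(disc(d_prior), disc(d_sample), device=device)
print(d_loss.item(), adv_loss.item())
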
11 changes: 4 additions & 7 deletions moses/junction_tree/trainer.py
@@ -2,6 +2,7 @@
import torch.optim as optim

from tqdm import tqdm
import time

from moses.utils import mapper, Logger
from moses.interfaces import MosesTrainer
@@ -17,17 +18,15 @@ def _train_epoch(self, model, tqdm_data, epoch, optimizer=None):
model.eval()
else:
model.train()

        postfix = { 'word_acc' : 0,
                    'topo_acc' : 0,
                    'assm_acc' : 0,
                    'steo_acc' : 0,
-                   'kl' : 0,}
+                   'latent_loss': 0,}
        kl_w = 0 if epoch < self.config.kl_start else self.config.kl_w

        for i, batch in enumerate(tqdm_data):
-           loss, kl_div, wacc, tacc, sacc, dacc = model(batch, kl_w)
+           loss, latent_loss, wacc, tacc, sacc, dacc = model(batch, kl_w)

if optimizer is not None:
optimizer.zero_grad()
@@ -38,7 +37,7 @@
            postfix['topo_acc'] += (tacc * 100 - postfix['topo_acc']) / (i + 1)
            postfix['assm_acc'] += (sacc * 100 - postfix['assm_acc']) / (i + 1)
            postfix['steo_acc'] += (dacc * 100 - postfix['steo_acc']) / (i + 1)
-           postfix['kl'] += (kl_div - postfix['kl']) / (i + 1)
+           postfix['latent_loss'] += (latent_loss - postfix['latent_loss']) / (i + 1)

tqdm_data.set_postfix(postfix)

@@ -102,10 +101,8 @@ def collate(smiles):

    def fit(self, model, train_data, val_data=None):
        logger = Logger() if self.config.log_file is not None else None
        train_loader = self.get_dataloader(model, train_data, shuffle=True)
        val_loader = None if val_data is None else self.get_dataloader(model, val_data, shuffle=False)
        self._train(model, train_loader, val_loader, logger)

        return model
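The postfix updates in _train_epoch above are an incremental mean over batches; a quick numeric check of the same recurrence:

# mean_{i+1} = mean_i + (x_{i+1} - mean_i) / (i + 1)
values = [2.0, 4.0, 9.0]
running = 0.0
for i, x in enumerate(values):
    running += (x - running) / (i + 1)
print(running)                    # 5.0
print(sum(values) / len(values))  # 5.0, same result
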
12 changes: 8 additions & 4 deletions moses/metrics/utils_fcd.py
@@ -174,12 +174,16 @@ def get_predictions(smiles, gpu=-1, batch_size=128):
        device = "/gpu:{}".format(gpu)
    else:
        device = "/cpu"
    # config = tf.ConfigProto(allow_soft_placement=True)
    config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.device(device):
        # sess = tf.Session(config=config)
        sess = tf.compat.v1.Session(config=config)
        # set_session(sess)
        tf.compat.v1.keras.backend.set_session(sess)
        # K.clear_session()
        tf.keras.backend.clear_session()
        model = load_ref_model(model_path)
        smiles_act = model.predict_generator(
            myGenerator_predict(smiles, batch_size=batch_size),
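The pattern above keeps the TF1-style session workflow alive under TensorFlow 2 through the compat.v1 API; a standalone sketch of just that setup, with model loading and prediction omitted:

import tensorflow as tf

# Drop any previous Keras state, then create and register a TF1-style session
tf.keras.backend.clear_session()

config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)
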
16 changes: 16 additions & 0 deletions moses/script_utils.py
@@ -80,12 +80,28 @@ def add_sample_args(parser):
    common_arg.add_argument('--gen_save',
                            type=str, required=True,
                            help='Where to save the gen molecules')
    common_arg.add_argument('--pred_save',
                            type=str, required=False,
                            help='Where to save the reconstructed molecules')
    common_arg.add_argument("--n_batch",
                            type=int, default=32,
                            help="Size of batch")
    common_arg.add_argument("--max_len",
                            type=int, default=100,
                            help="Max length of SMILES")
    common_arg.add_argument('--lbann_weights_dir', type=str, default='',
                            help='Directory for LBANN weights for inference')
    common_arg.add_argument('--lbann_epoch_counts', type=int, default=30,
                            help='LBANN epoch count at which to load trained model')
    common_arg.add_argument('--test_path',
                            type=str, required=True,
                            help='Input data in csv format for reconstruction')
    common_arg.add_argument('--save_reconstruction', action='store_true', required=False,
                            help='Optional flag to save reconstructed test data')
    '''
    common_arg.add_argument("--save_reconstruction",
                            type=bool, default=False, required=False,
                            help="Optional flag to save reconstructed test data")
    '''

    return parser

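On the commented-out alternative above: argparse passes the raw command-line string through the type callable, and bool('False') is True, which is why the action='store_true' form was kept. A throwaway parser illustrates the difference:

import argparse

p = argparse.ArgumentParser()
p.add_argument('--bool_flag', type=bool, default=False)   # the commented-out style
p.add_argument('--true_flag', action='store_true')        # the style adopted above

args = p.parse_args(['--bool_flag', 'False'])
print(args.bool_flag)   # True, because bool('False') is truthy, which is surprising
print(args.true_flag)   # False, only True when the flag is present on the command line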