
Using the trained model #2

Open · fteufel opened this issue Feb 3, 2023 · 0 comments

fteufel commented Feb 3, 2023

Hi @zatchwu, I want to use your trained model to sample/score signal peptides.

The following is what I came up with by going through the provided notebooks and trying to get a more straightforward sequence -> model -> prediction workflow, independent of the datasets you were using. It would be great to get some feedback on whether what I'm doing here is correct.

1. Encoding amino acid data for the transformer

```python
SPGEN_AA_TO_ID = {
 ' ': 0,
 '$': 1,
 '.': 2,
 'A': 3,
 'C': 4,
 'D': 5,
 'E': 6,
 'F': 7,
 'G': 8,
 'H': 9,
 'I': 10,
 'K': 11,
 'L': 12,
 'M': 13,
 'N': 14,
 'P': 15,
 'Q': 16,
 'R': 17,
 'S': 18,
 'T': 19,
 'U': 20,
 'V': 21,
 'W': 22,
 'X': 23,
 'Y': 24,
 'Z': 25,
}

# sp and prot are plain amino acid strings (signal peptide and mature protein);
# '$' is prepended as start-of-sequence and '.' appended as end-of-sequence
sp = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[x] for x in sp] + [SPGEN_AA_TO_ID['.']]
prot = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[x] for x in prot] + [SPGEN_AA_TO_ID['.']]
```
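For batching, this is roughly the padding helper I wrote around that encoding; `encode_spgen_batch`, its `max_len`, and the 0/1 padding-mask convention (see step 3) are my own guesses rather than anything from your code:

```python
import torch

def encode_spgen_batch(seqs, max_len):
    # pad with ' ' (ID 0) and build position masks that are 0 at true
    # token positions and 1 at padded positions (the convention I use
    # in step 3 below)
    ids = torch.zeros(len(seqs), max_len, dtype=torch.long)
    mask = torch.ones(len(seqs), max_len, dtype=torch.long)
    for i, s in enumerate(seqs):
        enc = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[c] for c in s] + [SPGEN_AA_TO_ID['.']]
        ids[i, :len(enc)] = torch.tensor(enc)
        mask[i, :len(enc)] = 0
    return ids, mask
```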
2. Loading the model

```python
import torch
from transformer import Models  # assuming the transformer package from the repo is on the path


def load_spgen_model():
    # the weights were extracted from the .chkpt file with the same name;
    # map_location lets this load on CPU-only machines too
    state_dict = torch.load(
        '../../SPGen/remote_generation/signal_peptide/outputs/SIM99_550_12500_64_6_5_0.1_64_100_0.0001_-0.03_99_weightsonly.pt',
        map_location='cpu')
    # positional args are n_src_vocab, n_tgt_vocab and n_max_seq,
    # judging by the Transformer signature
    model = Models.Transformer(
        27,
        27,
        107,
        proj_share_weight=True,
        embs_share_weight=True,
        d_k=64,
        d_v=64,
        d_model=550,
        d_word_vec=550,
        d_inner_hid=1100,
        n_layers=6,
        n_head=5,
        dropout=0.1)

    model.load_state_dict(state_dict)
    model.eval()

    return model
```
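I then load it like this (the device handling is my own addition, not from the notebooks):

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_spgen_model().to(device)
```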
3. Making predictions (logits) and scoring the perplexity. I encode the data as shown in step 1 and build prot_positions / sp_positions masks that are 0 at true positions and 1 at masked positions.

```python
def get_perplexity_batch(transformer, src_seq, src_positions, tgt_seq, tgt_positions):
    '''Adapted from Translator()._epoch().'''
    ppls = []

    # ignore the padding ID (' ' -> 0) so padded positions don't skew the
    # loss; I'm not sure whether the original Translator handles this
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=SPGEN_AA_TO_ID[' '])

    pred = transformer((src_seq, src_positions), (tgt_seq, tgt_positions))

    # score each sequence in the batch separately: perplexity is
    # exp(mean cross-entropy) over the target tokens, shifted by one
    # since the model predicts the next token
    for idx in range(len(src_seq)):
        loss = loss_fn(pred[idx].view(-1, 27), tgt_seq[idx, 1:].view(-1))
        ppls.append(torch.exp(loss).item())

    return ppls
```
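As a quick sanity check on the scoring: a uniform predictor over the 27-token vocabulary should come out at a perplexity of about 27, since its cross-entropy is ln(27):

```python
import torch

logits = torch.zeros(10, 27)            # uniform logits for 10 positions
targets = torch.randint(0, 27, (10,))   # arbitrary targets
ce = torch.nn.functional.cross_entropy(logits, targets)
print(torch.exp(ce))  # tensor(27.)
```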


```python
import numpy as np
from tqdm import tqdm


def predict_spgen(model, loader):
    with torch.no_grad():
        ppl = []
        for batch in tqdm(loader, total=len(loader)):
            proteins, prot_positions, sps, sp_positions = batch
            proteins, prot_positions, sps, sp_positions = (
                proteins.to(device), prot_positions.to(device),
                sps.to(device), sp_positions.to(device))

            # get_perplexity_batch runs the forward pass itself, so no
            # separate model(...) call is needed here
            ppls = get_perplexity_batch(model, proteins, prot_positions, sps, sp_positions)
            ppl.extend(ppls)

    return np.array(ppl)
```
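And the driver code I use to put it all together; `protein_strings` / `sp_strings` are placeholder lists of amino acid strings, and the TensorDataset/DataLoader wiring is mine, not from the repo:

```python
from torch.utils.data import DataLoader, TensorDataset

# encode with the hypothetical helper from step 1; max_len=107 to match
# the n_max_seq guess in load_spgen_model
proteins, prot_positions = encode_spgen_batch(protein_strings, max_len=107)
sps, sp_positions = encode_spgen_batch(sp_strings, max_len=107)

loader = DataLoader(TensorDataset(proteins, prot_positions, sps, sp_positions),
                    batch_size=64)
ppl = predict_spgen(model, loader)
```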

My code runs, but it is hard to tell whether everything is in place or there is an error somewhere. It would be great to get some feedback; I'm also open to any other way of making the model run on new data.

Thanks!
