an example implementation of a cnn for predicting bait performance. #9

Open
yfarjoun wants to merge 5 commits into master from yf_adding_cnn_for_bait_prediction

@yfarjoun

Contributor

It doesn't work very well on the ICE sample I tried it with, probably because the baits are not equimolar, and predicting the molarity is quite difficult. Still, it might be a useful template for someone.

@yfarjoun yfarjoun requested a review from lucidtronix August 18, 2017 18:21
@lucidtronix lucidtronix self-assigned this Aug 18, 2017
@mbabadi

mbabadi commented Aug 18, 2017

Do we know for a fact that the baits are not equimolar, or is it just a hunch? If it is a fact, then predicting bait efficiency is obviously impossible, particularly if the variance in molarity exceeds the context-dependent variance. The relative spatial capture efficiency of each bait, however, is not affected by this confounding factor and may be predictable.

@lucidtronix lucidtronix left a comment

Contributor

Cool work! Some nitpicks and suggestions...


import os
import math
# import h5py

Contributor

remove unused

my_metrics = [metrics.mean_squared_error, rmse_log]

model.compile(loss='mean_squared_error', optimizer=sgd, metrics=my_metrics)
print('model summary:\n', model.summary())

Contributor

My bug, but this should just be
model.summary()
That function already prints the summary itself and returns None, so wrapping it in print() adds a stray "None" to the output.

# the Input layer and three Dense layers
model = Model(input=[input_baits, input_annotations], output=predictions)
model.compile(loss=gme, optimizer=sgd, metrics=my_metrics)
print('model summary:\n', model.summary())

Contributor

model.summary()

activation="relu",
init='normal'))

model.add(MaxPooling1D(pool_length=3, stride=3))

Contributor

I try to avoid maxpooling this early in the model for genetic data. For images we have a strong smoothness prior which we don't have on DNA sequences.
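
To illustrate the concern, here is a minimal pure-Python sketch (not the Keras layer) of max pooling with pool size 3 and stride 3 applied directly to one-hot encoded bases. The helper names `one_hot` and `max_pool_1d` are hypothetical, and in the actual model the pooling follows a convolution, so this is only an analogy. Within each window the pooled output records which bases occurred but not their order, so different 3-mers can become indistinguishable.

```python
BASES = "ACGT"

def one_hot(seq):
    """Encode a DNA string as a list of 4-element 0/1 rows (A, C, G, T)."""
    return [[1 if base == b else 0 for b in BASES] for base in seq]

def max_pool_1d(rows, pool_size=3, stride=3):
    """Channel-wise max over windows of `pool_size` rows (no padding)."""
    pooled = []
    for start in range(0, len(rows) - pool_size + 1, stride):
        window = rows[start:start + pool_size]
        pooled.append([max(col) for col in zip(*window)])
    return pooled

# Two sequences that differ in the order of their first three bases.
pooled_acg = max_pool_1d(one_hot("ACGTTT"))
pooled_gca = max_pool_1d(one_hot("GCATTT"))
print(pooled_acg == pooled_gca)
```

For images, neighboring pixels are usually similar, so this kind of local summary loses little; for DNA, swapping adjacent bases gives a different sequence.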


x = Dropout(0.2)(x)

x = MaxPooling1D(strides=3, pool_size=3)(x)

Contributor

Max pooling could be risky this early; see above.

bait_shape = (args.window_size, len(args.inputs),)
annotation_shape = (len(args.annotations),)

print(bait_shape)

Contributor

add label or remove


predictions = Dense(units=1, init=RandomNormal(mean=1.0, stddev=0.5, seed=None), activation=None)(xy)

sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=0.5)

Contributor

As we discussed, maybe try the Adam optimizer.

count = 0
while count < args.samples:
    contig_key, start, end, coverage, gc = sample_from_bed(baits_and_coverages)
    mid = (start + end) / 2

Contributor

change to // if you import future division
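
A small sketch of the difference, assuming Python 3 semantics (which is also what `from __future__ import division` gives you in Python 2). The coordinates are made up for illustration. True division yields a float midpoint, which cannot be used directly as a sequence index, while floor division keeps it an integer.

```python
# Hypothetical bait coordinates, chosen so the midpoint is not a whole number.
start, end = 100, 221

mid_true = (start + end) / 2    # true division: float result
mid_floor = (start + end) // 2  # floor division: integer result

print(mid_true, mid_floor)
```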

grep -v '^@' whole_exome_illumina_coding_v1.Homo_sapiens_assembly19.targets.interval_list | awk 'BEGIN{FS="\t";OFS="\t"}{print $1, $2-1,$3}' > whole_exome_illumina_coding_v1.Homo_sapiens_assembly19.targets.bed
grep -v '^@' whole_exome_illumina_coding_v1.Homo_sapiens_assembly19.baits.interval_list | awk 'BEGIN{FS="\t";OFS="\t"}{print $1, $2-1,$3}' > whole_exome_illumina_coding_v1.Homo_sapiens_assembly19.baits.bed

bedtools slop -i whole_exome_illumina_coding_v1.Homo_sapiens_assembly19.baits.bed -b 250 -g hg19.genome > sloppy.baits.bed

Contributor

love this tool
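
For readers unfamiliar with it, the following is a rough Python sketch of what `bedtools slop -b 250 -g hg19.genome` does to each interval: extend both ends by a fixed number of bases and clamp the result to the chromosome bounds. The `slop` function and the toy chromosome size are illustrative assumptions, not part of this PR.

```python
def slop(chrom, start, end, chrom_sizes, both=250):
    """Extend a 0-based, half-open BED interval by `both` bases per side,
    clamped to [0, chromosome length]."""
    new_start = max(0, start - both)
    new_end = min(chrom_sizes[chrom], end + both)
    return chrom, new_start, new_end

chrom_sizes = {"1": 1000}  # toy value standing in for the hg19.genome file

print(slop("1", 100, 220, chrom_sizes))  # start is clamped at 0
print(slop("1", 900, 980, chrom_sizes))  # end is clamped at the contig length
```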

return c_idx, p_idx


# TODO: make more random (this gives too much power to the small contigs)

Contributor

yes!

- removed unused imports
- changed dropout from .2 to .3
- removed first MaxPooling layer
- added informative labels to print commands
- made division floating point
…ailed miserably... anyone care to see what I did wrong?
@yfarjoun

Contributor Author

pushed a few more commits...care to comment?

@lucidtronix lucidtronix left a comment

Contributor

Some small changes and comment cleanup. But the bigger question is: why isn't the model learning more? Judging by scatter_bait_performance_cnn_model.jpg, it looks like underfitting, not overfitting: training and test performance are similar, and training error is far from 0. Is it because non-sequence (and non-bait-position) factors are responsible for bait coverage? Is the equimolarity assumption wrong? Are the bait positions normalized between 0 and 1? They probably should be. I would be curious how a pure sequence model performs. Another debugging idea is to cast this as classification rather than regression; see the model.compile comment on line 210.
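
One possible way to normalize positions, assuming the position feature is a 0-based coordinate along a contig of known length. The names `normalize_position` and `contig_lengths` are hypothetical, and the lengths are toy values rather than real hg19 sizes.

```python
def normalize_position(pos, contig_length):
    """Scale a 0-based coordinate into [0, 1] along its contig."""
    if contig_length <= 1:
        return 0.0
    return pos / (contig_length - 1)

contig_lengths = {"1": 1001, "2": 501}  # toy lengths for illustration only
samples = [("1", 0), ("1", 500), ("2", 500)]

normalized = [normalize_position(pos, contig_lengths[c]) for c, pos in samples]
print(normalized)
```

Scaling per contig keeps the feature in the same range as the one-hot sequence inputs, which tends to make optimization easier.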

contig_sizes = {key: len(bed_dict[key]) for key in bed_dict.keys()}
total_size = sum(contig_sizes.values())

contig_key = np.random.choice(bed_dict.keys(), 1, p=[x / total_size for x in contig_sizes.values()])[0]

Contributor

Nice!
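
A standalone sketch of the same idea using only the standard library, with `random.choices` swapped in for `np.random.choice`. The toy `bed_dict` layout (contig -> (starts, ends)) is assumed from the interval lookup elsewhere in this PR. If `bed_dict[key]` really is such a pair, `len(bed_dict[key])` would be 2 for every contig, so this sketch counts `len(bed_dict[key][0])` instead. Under Python 3 the NumPy version would also likely need `list(bed_dict.keys())`, since `keys()` returns a view rather than a list.

```python
import random
from collections import Counter

# Toy stand-in: contig -> (interval starts, interval ends)
bed_dict = {
    "1": ([100, 500, 900], [220, 620, 1020]),
    "2": ([100], [220]),
}

def sample_contig(bed_dict, rng=random):
    """Pick a contig with probability proportional to its interval count."""
    contigs = list(bed_dict.keys())
    weights = [len(bed_dict[c][0]) for c in contigs]
    return rng.choices(contigs, weights=weights, k=1)[0]

rng = random.Random(0)  # seeded so the run is repeatable
counts = Counter(sample_contig(bed_dict, rng) for _ in range(10000))
print(counts)
```

Contig "1" holds three of the four intervals, so it should be drawn roughly three times as often as contig "2".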

lows = bed_dict[contig][0]
ups = bed_dict[contig][1]

return np.any((lows <= pos) & (pos <= ups))

@lucidtronix lucidtronix Aug 25, 2017

Contributor

Again, my bug, which got fixed in vqsr but was never cherry-picked back here. I believe this should be:
return np.any((lows <= pos) & (pos < ups))
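
A pure-Python sketch of why the upper bound should be exclusive, assuming the intervals follow the BED convention (0-based start, end not included). The `in_any_interval` name and the coordinates are made up for illustration.

```python
def in_any_interval(pos, lows, ups):
    """True if pos falls in any half-open interval [low, up)."""
    return any(low <= pos < up for low, up in zip(lows, ups))

lows = [100, 500]
ups = [220, 620]

print(in_any_interval(219, lows, ups))  # last base of the first interval
print(in_any_interval(220, lows, ups))  # one past the end, so excluded
```

With the inclusive comparison, position 220 would be counted as inside the first interval even though BED places it just outside.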


def gme(y_true, y_pred):
    """calculates the root (geometric) mean squared error of the values."""
    return K.exp(K.mean(K.log(K.abs(np.divide(y_true + .001, y_pred + .001) - 1))))

Contributor

You could use K.epsilon() instead of the hardcoded .001; this value is initialized from your ~/.keras/keras.json file and is settable via K.set_epsilon(1e-05). Also, I think the division operator / is appropriately overloaded to handle this without np.divide().
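
For reference, a plain-Python sketch of what the `gme` expression computes on ordinary lists, with the offset exposed as a parameter in the spirit of `K.epsilon()`. As written, it is the geometric mean of the absolute relative error |y_true / y_pred - 1| rather than a root mean squared error. The `gme_reference` name is an assumption for illustration, and this is not the Keras implementation.

```python
import math

def gme_reference(y_true, y_pred, eps=1e-7):
    """Geometric mean of |(t + eps) / (p + eps) - 1| over paired values.

    The default eps mirrors Keras' default epsilon.
    """
    logs = [
        math.log(abs((t + eps) / (p + eps) - 1.0))
        for t, p in zip(y_true, y_pred)
    ]
    return math.exp(sum(logs) / len(logs))

# Relative errors of about 1.0 and 0.25 give a geometric mean near 0.5.
print(gme_reference([2.0, 5.0], [1.0, 4.0]))
```

Note that a perfect prediction makes the inner term zero, so the logarithm is undefined there (`math.log` raises `ValueError` in this sketch), which may be worth guarding against in the real loss.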

# from math import sqrt
#
#
# def put_kernels_on_grid(kernel, pad=1):

Contributor

Can this be uncommented? Seems like a helpful function...

Contributor Author

It might be, but I failed to get it to work... I'm adding it here as a comment so that folks have something to work with...

model = Model(inputs=[input_baits, input_annotations], outputs=predictions)

# # add some TensorBoard annotations
# conv1d_1 = filter(lambda y: y.name == "conv1d_1", model.layers)[0]

Contributor

Why is this commented out?

Contributor Author

Because I couldn't get it to work, but I wanted to show it to you to see if you could help!

Contributor

Ok, are you trying to visualize the weights in tensorboard or something more complicated?

# filters=put_kernels_on_grid(reshaped, 2)
#
# merged = tf.summary.merge_all()
# train_writer = tf.summary.FileWriter("./log/" + '/train')

Contributor

Does this conflict with the TensorBoard callback?

# train_writer = tf.summary.FileWriter("./log/" + '/train')
#
adamo = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, clipnorm=1.)
model.compile(loss=metrics.mean_squared_error, optimizer=adamo, metrics=my_metrics)

Contributor

I'm wondering if we should try this as a classification problem by binning the coverage first and then trying to predict (high, med, low) or (too high, high, medium, low, too low) or something. Then we could use categorical crossentropy as our loss, which tends to have better-behaved learning dynamics. We could also use categorical accuracy as a metric to get a quick idea of how the model compares to chance...
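
A minimal sketch of the binning step, assuming coverage has already been normalized and that the bin edges are chosen by hand. The edges, labels, and helper names below are hypothetical. Each coverage value is mapped to a class index and then one-hot encoded, which is the target format categorical crossentropy expects.

```python
from bisect import bisect_right

LABELS = ["too low", "low", "medium", "high", "too high"]
EDGES = [0.25, 0.75, 1.25, 2.0]  # hypothetical cut points on normalized coverage

def coverage_to_class(coverage, edges=EDGES):
    """Return the index of the bin that `coverage` falls into."""
    return bisect_right(edges, coverage)

def one_hot(index, num_classes=len(LABELS)):
    """Encode a class index as a 0/1 list of length num_classes."""
    return [1 if i == index else 0 for i in range(num_classes)]

for cov in [0.1, 1.0, 3.5]:
    idx = coverage_to_class(cov)
    print(cov, LABELS[idx], one_hot(idx))
```

If the edges were instead picked so that the bins are roughly balanced, chance accuracy would be about one in five, which gives a quick baseline for the categorical accuracy metric.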

Contributor Author

And then what? Modify the loss so that it tries to minimize xentropy + MSE?

@lucidtronix lucidtronix Aug 28, 2017

Contributor

I was thinking of just trying categorical crossentropy, though we could do it as a multi-task problem where the model tries to do both. It is easy to do with the functional API. Something like:

regression = Dense(units=1, kernel_initializer=RandomNormal(mean=1.0, stddev=0.5, seed=None), activation="relu")(xy)
classification = Dense(units=5, activation='softmax')(xy)
model = Model(inputs=[input_baits, input_annotations], outputs=[regression, classification])
# losses are matched to outputs in order: regression first, then classification
model.compile(loss=[metrics.mean_squared_error, 'categorical_crossentropy'], optimizer=adamo, metrics=my_metrics)

Assuming 5 bins for the quantized coverage...

@yfarjoun

Contributor Author

I've chatted with Tim about this, and he's pretty sure that the baits are NOT equimolar... so unless I can get the molarity mixture, I'm surprised that this worked at all...

@lucidtronix

Contributor

I'm not so surprised that there is some predictive value in the sequence alone. We've now seen that on a few different tasks: variant filtering, indel modeling, bait performance. Anyway, it seems like it may be difficult to track down the molarity mixture. Should we leave this PR open while you search, or do you want to merge?

@yfarjoun

Contributor Author

I'll fix it up and merge...who knows how long it will take to get the molarity or spike-in list...
