Skip to content

Commit

Permalink
adding mcmc training for log/exp
Browse files Browse the repository at this point in the history
  • Loading branch information
elena-buscaroli committed Jan 31, 2025
1 parent cc2a057 commit 9d4d29a
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 55 deletions.
22 changes: 19 additions & 3 deletions R/growth_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

fit_growth_rates = function(x,
steps=500,
warmup_steps=500,
highlight=c(),
timepoints_to_int=c(),
growth_model="exp.log",
Expand Down Expand Up @@ -74,6 +75,7 @@ fit_growth_rates = function(x,
parents=parents,
growth_model=growth_model,
steps=steps,
warmup_steps=warmup_steps,
timepoints_to_int=timepoints_to_int,
py_pkg=py_pkg)

Expand All @@ -95,6 +97,7 @@ fit_growth_utils = function(rates.df,
parents,
growth_model,
steps,
warmup_steps,
timepoints_to_int,
py_pkg) {

Expand All @@ -105,6 +108,7 @@ fit_growth_utils = function(rates.df,
parents=parents,
growth_model=growth_model,
steps=steps,
warmup_steps=warmup_steps,
clonal=TRUE,
timepoints_to_int=timepoints_to_int,
py_pkg=py_pkg)
Expand All @@ -119,6 +123,7 @@ fit_growth_utils = function(rates.df,
parents=parents,
growth_model=growth_model,
steps=steps,
warmup_steps=warmup_steps,
clonal=FALSE,
timepoints_to_int=timepoints_to_int,
py_pkg=py_pkg)
Expand All @@ -133,6 +138,7 @@ fit_growth_clones = function(rates.df,
parents,
growth_model,
steps,
warmup_steps,
clonal,
timepoints_to_int,
py_pkg=NULL) {
Expand All @@ -152,6 +158,7 @@ fit_growth_clones = function(rates.df,
clonal=clonal,
growth_model=growth_model,
steps=steps,
warmup_steps=warmup_steps,
py_pkg=py_pkg)

return(rates.df)
Expand All @@ -163,6 +170,7 @@ run_py_growth = function(rates.df,
cluster,
timepoints_to_int,
steps=500,
warmup_steps=500,
# p.rates=list("exp"=NULL, "log"=NULL),
clonal=FALSE,
growth_model="exp.log",
Expand Down Expand Up @@ -204,15 +212,21 @@ run_py_growth = function(rates.df,
if (grepl("exp", growth_model)) { # exp training
if (!is.null(p.rates[["exp"]])) p.rate.exp = torch$tensor(p.rates[["exp"]])$float()

losses.exp = x.reg$train(regr="exp", p_rate=p.rate.exp, steps=as.integer(steps), random_state=as.integer(random_state))
# losses.exp = x.reg$train(regr="exp", p_rate=p.rate.exp, steps=as.integer(steps), random_state=as.integer(random_state))
posterior_samples.exp = x.reg$train_mcmc(regr="exp", p_rate=p.rate.log, num_samples=as.integer(steps),
warmup_steps=as.integer(warmup_steps), num_chains=as.integer(1),
random_state=as.integer(random_state))
p.exp = x.reg$get_learned_params()
ll.exp = x.reg$compute_log_likelihood() %>% setNames(nm=lineages)
}

if (grepl("log", growth_model)) { # log training
if (!is.null(p.rates[["log"]])) p.rate.log = torch$tensor(p.rates[["log"]])$float()

losses.log = x.reg$train(regr="log", p_rate=p.rate.log, steps=as.integer(steps), random_state=as.integer(random_state))
# losses.log = x.reg$train(regr="log", p_rate=p.rate.log, steps=as.integer(steps), random_state=as.integer(random_state))
posterior_samples.log = x.reg$train_mcmc(regr="log", p_rate=p.rate.log, num_samples=as.integer(steps),
warmup_steps=as.integer(warmup_steps), num_chains=as.integer(1),
random_state=as.integer(random_state))
p.log = x.reg$get_learned_params()
ll.log = x.reg$compute_log_likelihood() %>% setNames(nm=lineages)
}
Expand All @@ -221,7 +235,9 @@ run_py_growth = function(rates.df,
rates.exp=p.exp,
rates.log=p.log,
lineages=pop_df.cl$Lineage %>% unique(),
cluster=cluster)
cluster=cluster,
posterior_samples.exp=posterior_samples.exp,
posterior_samples.log=posterior_samples.log)

best = c()
for (ll in lineages)
Expand Down
99 changes: 80 additions & 19 deletions R/plots_growth.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ plot_growth_regression = function(x,
pl = pop_df %>%
ggplot() +

geom_point(aes(x=Generation, y=Population), alpha=.5, size=.7) +
# geom_point(aes(x=Generation, y=Population), alpha=.5, size=.7) +
geom_point(aes(x=Generation, y=Frequency), alpha=.5, size=.7) +

geom_line(data=filter(regr.df, type==best_model), aes(x=x, y=y, color=type), size=.7, alpha=.9) +

Expand Down Expand Up @@ -92,36 +93,96 @@ get_regression_df = function(x, pop_df, highlight) {
n_lins = rates$Lineage %>% unique() %>% length()
n_cls = length(highlight)

compute_credint = function(post_samples, p_rate, x_values, type) {
post_fitness = post_samples$post_fitness[[1]]
post_rate = dplyr::case_when(
is.null(p_rate) | is.na(p_rate) ~ post_fitness,
.default = p_rate * (1+post_fitness)
)
post_init_t = dplyr::case_when(
is.na(post_samples$post_init_time[[1]]) ~ 0,
.default = post_samples$post_init_time[[1]]
)

if (type == "exp") {
args = post_rate * (x_values - post_init_t)
return( exp( args ) )
}

post_carr_capacity = post_samples$post_carr_capacity[[1]]
args = - post_rate * (x_values - post_init_t)
return( post_carr_capacity / ( 1 + (post_carr_capacity - 1) * exp(args) ) )
}

regr_df = rates %>%
dplyr::mutate(x=list(tmin:tmax)) %>%
tidyr::unnest(x) %>%
dplyr::mutate(sigma=unlist(sigma)[as.character(x)],
args=as.numeric(NA),
y=as.numeric(NA),
y.min=as.numeric(NA),
y.max=as.numeric(NA)) %>%
dplyr::arrange(x, Identity) %>%

dplyr::rowwise() %>%
dplyr::mutate(
args=dplyr::case_when(
type=="log" ~ -rate*(x-init_t),
type=="exp" ~ rate*(x-init_t)
),
y=dplyr::case_when(
type=="log" ~ K / ( 1 + (K-1)*exp(args) ),
type=="exp" ~ exp(args)
)
) %>%
dplyr::mutate(y_credint=list(compute_credint(posterior_samples, p_rate, x, type))) %>%
dplyr::mutate(y.min=quantile(y_credint, 0.05),
y.max=quantile(y_credint, 0.95)) %>%

dplyr::mutate(args=replace(args, type=="log", -rate*(x-init_t)),
args=replace(args, type=="exp", rate*(x-init_t)),

y=replace(y, type=="log", K / ( 1 + (K-1)*exp(args) ) ),
y=replace(y, type=="exp", exp(args)),

y.min=replace(y.min, type=="log", max(0, K / ( 1 + (K-1)*exp(args) ) - sigma)),
y.min=replace(y.min, type=="exp", exp(args - sigma)),

y.max=replace(y.max, type=="log", K / ( 1 + (K-1)*exp(args) ) + sigma),
y.max=replace(y.max, type=="exp", exp(args + sigma)) ) %>%
dplyr::ungroup() %>%

dplyr::select(Lineage, Identity, type, best_model, init_t, x, y, y.min, y.max) %>%
dplyr::mutate(type=ifelse(type=="log", "Logistic", "Exponential")) %>%
dplyr::mutate(best_model=ifelse(best_model=="log", "Logistic", "Exponential")) %>%
dplyr::mutate(Identity=factor(Identity, levels=highlight))



# regr_df = rates %>%
# dplyr::mutate(x=list(tmin:tmax)) %>%
# tidyr::unnest(x) %>%
# dplyr::mutate(sigma=unlist(sigma)[as.character(x)],
# args=as.numeric(NA),
# y=as.numeric(NA),
# y.min=as.numeric(NA),
# y.max=as.numeric(NA)) %>%
# dplyr::arrange(x, Identity) %>%
#
# dplyr::rowwise() %>%
#
# dplyr::mutate(args=replace(args, type=="log", -rate*(x-init_t)),
# args=replace(args, type=="exp", rate*(x-init_t)),
#
# y=replace(y, type=="log", K / ( 1 + (K-1)*exp(args) ) ),
# y=replace(y, type=="exp", exp(args)),
#
# y.min=replace(y.min, type=="log", max(0, K / ( 1 + (K-1)*exp(args) ) - sigma)),
# y.min=replace(y.min, type=="exp", exp(args - sigma)),
#
# y.max=replace(y.max, type=="log", K / ( 1 + (K-1)*exp(args) ) + sigma),
# y.max=replace(y.max, type=="exp", exp(args + sigma)) ) %>%
#
# # dplyr::mutate(args=replace(args, type=="log", -rate*(x-init_t)),
# # args=replace(args, type=="exp", rate*(x-init_t)),
# #
# # y=replace(y, type=="log", K / ( 1 + (K-1)*exp(args) ) ),
# # y=replace(y, type=="exp", exp(args)),
# #
# # y.min=replace(y.min, type=="log", max(0, K / ( 1 + (K-1)*exp(args) ) - sigma)),
# # y.min=replace(y.min, type=="exp", exp(args - sigma)),
# #
# # y.max=replace(y.max, type=="log", K / ( 1 + (K-1)*exp(args) ) + sigma),
# # y.max=replace(y.max, type=="exp", exp(args + sigma)) ) %>%
# dplyr::ungroup() %>%
#
# dplyr::select(Lineage, Identity, type, best_model, init_t, x, y, y.min, y.max) %>%
# dplyr::mutate(type=ifelse(type=="log", "Logistic", "Exponential")) %>%
# dplyr::mutate(best_model=ifelse(best_model=="log", "Logistic", "Exponential")) %>%
# dplyr::mutate(Identity=factor(Identity, levels=highlight))

return(regr_df)
}

Expand Down
126 changes: 93 additions & 33 deletions R/utils_growth_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ get_growth_params = function(timepoints_to_int,
lineages,
cluster,
rates.exp=NULL,
rates.log=NULL) {
rates.log=NULL,
posterior_samples.exp=NULL,
posterior_samples.log=NULL) {

if (!is.null(rates.exp)) params.exp = get_growth_rates_exp(rates.exp, lineages, cluster, timepoints_to_int)
if (!is.null(rates.log)) params.log = get_growth_rates_log(rates.log, lineages, cluster, timepoints_to_int)
if (!is.null(rates.exp)) params.exp = get_growth_rates_exp(rates.exp, lineages, cluster, timepoints_to_int, posterior_samples.exp)
if (!is.null(rates.log)) params.log = get_growth_rates_log(rates.log, lineages, cluster, timepoints_to_int, posterior_samples.log)

if (is.null(rates.exp)) return(params.log)
if (is.null(rates.log)) return(params.exp)
Expand All @@ -72,22 +74,55 @@ get_growth_params = function(timepoints_to_int,
}


get_growth_rates_exp = function(rates.df, lineages, cluster, timepoints_to_int) {
pars = data.frame() %>%
posterior_samples_to_df = function(numpy_array, lineages) {
# numpy array of dimension N_iters x L
# returns a dataframe with nrow = L and one column with a list of samples

# initialize the columns
tibble::add_column("Lineage"=as.character(NA),
"fitness.exp"=NA,
"init_t.exp"=NA,
"p_rate.exp"=NA,
"sigma.exp"=NA,
"rate.exp"=as.numeric(NA)) %>%
if (is.null(numpy_array))
return(data.frame(Lineage=lineages) %>% dplyr::mutate(param_samples=NA))

# add the values
tibble::add_row(Lineage=lineages,
p_rate.exp=rates.df$parent_rate,
fitness.exp=rates.df$fitness %>% as.numeric(),
init_t.exp=rates.df$init_time %>% as.integer())
numpy_array = numpy_array %>% as.matrix() %>% t() %>% as.data.frame()
rownames(numpy_array) = lineages
result = numpy_array %>% tibble::rownames_to_column(var="Lineage") %>%
tidyr::pivot_longer(cols=-"Lineage") %>%
dplyr::group_by(Lineage) %>%
dplyr::summarise(param_samples=list(value %>% setNames(name)))
}


get_growth_rates_exp = function(rates.df, lineages, cluster, timepoints_to_int,
posterior_samples) {

posterior_df = data.frame()
if (!is.null(posterior_samples)) {
posterior_df = posterior_samples$fitness %>% posterior_samples_to_df(lineages) %>%
dplyr::rename(post_fitness=param_samples) %>%
dplyr::left_join(
posterior_samples$init_time %>% posterior_samples_to_df(lineages) %>%
dplyr::rename(post_init_time=param_samples),
by="Lineage") %>%
dplyr::group_by(Lineage) %>%
tidyr::nest(posterior_samples.exp=c(post_fitness, post_init_time))
}

sigma = rates.df$sigma %>% as.matrix() %>% t() %>% as.data.frame()
colnames(sigma) = c(0, timepoints_to_int)
rownames(sigma) = lineages

sigma = sigma %>% tibble::rownames_to_column(var="Lineage") %>%
tidyr::pivot_longer(cols=-"Lineage") %>%
dplyr::group_by(Lineage) %>%
dplyr::summarise(sigma.exp=list(value %>% setNames(name)))

pars = data.frame(
Lineage=lineages,
# p_rate.exp=rates.df$parent_rate,
fitness.exp=rates.df$fitness %>% as.numeric(),
init_t.exp=rates.df$init_time %>% as.integer()
) %>%
dplyr::mutate(rate.exp=NA,
p_rate.exp=ifelse(is.null(rates.df$parent_rate), NA, rates.df$parent_rate)) %>%
dplyr::inner_join(sigma, by="Lineage")

try(expr = {
pars = pars %>% dplyr::mutate(p_rate.exp=as.numeric(p_rate.exp))
Expand All @@ -99,29 +134,53 @@ get_growth_rates_exp = function(rates.df, lineages, cluster, timepoints_to_int)
# compute the rates for subclones
dplyr::mutate(rate.exp=replace( rate.exp, !is.na(p_rate.exp), p_rate.exp * (1+fitness.exp) ),
rate.exp=replace( rate.exp, is.na(p_rate.exp), fitness.exp ),
sigma.exp=list( setNames(object=rates.df$sigma, nm=c(0, unlist(timepoints_to_int))) ),
# sigma.exp=list( setNames(object=rates.df$sigma, nm=c(0, unlist(timepoints_to_int))) ),
Identity=cluster) %>%

dplyr::left_join(posterior_df) %>%

tibble::as_tibble()
)
}


get_growth_rates_log = function(rates.df, lineages, cluster, timepoints_to_int) {
pars = data.frame() %>%
tibble::add_column("Lineage"=as.character(NA),
"fitness.log"=NA,
"K.log"=NA,
"init_t.log"=NA,
"p_rate.log"=NA,
"sigma.log"=NA,
"rate.log"=as.numeric(NA)) %>%
get_growth_rates_log = function(rates.df, lineages, cluster, timepoints_to_int,
posterior_samples) {

posterior_df = data.frame()
if (!is.null(posterior_samples)) {
posterior_df = posterior_samples$fitness %>% posterior_samples_to_df(lineages) %>%
dplyr::rename(post_fitness=param_samples) %>%
dplyr::left_join(
posterior_samples$init_time %>% posterior_samples_to_df(lineages) %>%
dplyr::rename(post_init_time=param_samples),
by="Lineage") %>%
dplyr::left_join(
posterior_samples$carr_capac %>% posterior_samples_to_df(lineages) %>%
dplyr::rename(post_carr_capacity=param_samples),
by="Lineage") %>%
dplyr::group_by(Lineage) %>%
tidyr::nest(posterior_samples.log=c(post_fitness, post_init_time, post_carr_capacity))
}

tibble::add_row(fitness.log=rates.df$fitness %>% as.numeric(),
K.log=rates.df$carr_capac %>% as.integer(),
init_t.log=rates.df$init_time %>% as.integer(),
Lineage=lineages,
p_rate.log=rates.df$parent_rate)
sigma = rates.df$sigma %>% as.matrix() %>% t() %>% as.data.frame()
colnames(sigma) = c(0, timepoints_to_int)
rownames(sigma) = lineages

sigma = sigma %>% tibble::rownames_to_column(var="Lineage") %>%
tidyr::pivot_longer(cols=-"Lineage") %>%
dplyr::group_by(Lineage) %>%
dplyr::summarise(sigma.log=list(value %>% setNames(name)))

pars = data.frame(
Lineage=lineages,
fitness.log=rates.df$fitness %>% as.numeric(),
K.log=rates.df$carr_capac %>% as.integer(),
init_t.log=rates.df$init_time %>% as.integer()
) %>%
dplyr::mutate(rate.log=NA,
p_rate.log=ifelse(is.null(rates.df$parent_rate), NA, rates.df$parent_rate)) %>%
dplyr::inner_join(sigma, by="Lineage")

try(expr = {
pars = pars %>% dplyr::mutate(p_rate.log=as.numeric(p_rate.log))
Expand All @@ -132,10 +191,11 @@ get_growth_rates_log = function(rates.df, lineages, cluster, timepoints_to_int)

dplyr::mutate(rate.log=replace( rate.log, !is.na(p_rate.log), p_rate.log * (1+fitness.log) ),
rate.log=replace( rate.log, is.na(p_rate.log), fitness.log ),
sigma.log=list( setNames(object=rates.df$sigma, nm=c(0, unlist(timepoints_to_int))) ),
Identity=cluster
) %>%

dplyr::left_join(posterior_df) %>%

tibble::as_tibble()
)
}
Expand Down

0 comments on commit 9d4d29a

Please sign in to comment.