Skip to content

Commit

Permalink
Update generate_synthetic_df.R
Browse files Browse the repository at this point in the history
  • Loading branch information
elena-buscaroli committed Jan 9, 2025
1 parent d45543d commit 91aea10
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions R/generate_synthetic_df.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ generate_synthetic_df = function(N_values,
filename="",
check_present=T,
default_lm=FALSE,
run=T) {
run=T, py_pkg=NULL) {

if (!endsWith(path, "/"))
path = paste0(path, "/")
Expand All @@ -22,7 +22,7 @@ generate_synthetic_df = function(N_values,
dir.create(path)

torch = reticulate::import("torch")
py_pkg = reticulate::import("pylineaGT")
if (is.null(py_pkg)) py_pkg = reticulate::import("pylineaGT")

mean_scale_inp = mean_scale
alpha_inp = alpha
Expand All @@ -38,16 +38,16 @@ generate_synthetic_df = function(N_values,
for (tt in T_values) {
for (kk in K_values) {

if (tt == 1 && kk > 6) {
mean_scale = max(10100, mean_scale_inp)
alpha = max(0.45, alpha_inp)
} else if ((tt == 2 && kk > 6) || (tt == 1 && kk == 6)) {
mean_scale = max(8000, mean_scale_inp)
alpha = max(0.35, alpha_inp)
} else {
mean_scale = mean_scale_inp
alpha = alpha_inp
}
# if (tt == 1 && kk > 6) {
# mean_scale = max(10100, mean_scale_inp)
# # alpha = max(0.45, alpha_inp)
# } else if ((tt == 2 && kk > 6) || (tt == 1 && kk == 6)) {
# mean_scale = max(8000, mean_scale_inp)
# # alpha = max(0.35, alpha_inp)
# } else {
# mean_scale = mean_scale_inp
# # alpha = alpha_inp
# }

cat(paste0("K=", kk, ", T=", tt, "\nmean_scale=", mean_scale, ", alpha=", alpha, "\n"))

Expand Down Expand Up @@ -91,6 +91,8 @@ generate_synthetic_df = function(N_values,

k_interval = get_sim_k_interval(x, cov.df)

start_time = Sys.time()

x_fit = fit(cov.df=cov.df,
k_interval=k_interval,
infer_growth=F,
Expand All @@ -106,10 +108,14 @@ generate_synthetic_df = function(N_values,
# init_seed=5,
sample_id=x$sim_id)

end_time = Sys.time()

x_fit$cov.dataframe = tibble::as_tibble(x$dataset) %>%
dplyr::mutate(coverage=as.integer(coverage)) %>%
dplyr::inner_join(x_fit$cov.dataframe, by=c("IS","timepoints","lineage","coverage"))

x_fit$time = end_time - start_time

print(aricode::ARI(x_fit$cov.dataframe$labels, x_fit$cov.dataframe$labels_true))

saveRDS(x_fit, paste0(subpath, filename, ".fit.Rds"))
Expand Down

0 comments on commit 91aea10

Please sign in to comment.