1 change: 1 addition & 0 deletions .gitignore
@@ -43,3 +43,4 @@ docs/notebooks/*.pt
/.hatch/

src/perturbo/simulation/.ipynb_checkpoints/
lightning_logs/
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -19,11 +19,11 @@ authors = [{ name = "Logan Blaine" }]
maintainers = [{ name = "Logan Blaine", email = "[email protected]" }]
dependencies = [
"anndata>=0.10.9",
"scvi-tools>=1.1.2",
"scvi-tools>=1.4.1",
"torch>=1.11.0",
"pyro-ppl>=1.8.6",
"mudata>=0.3.1",
"pandas>=1.5.3",
"pandas>=2.0.3",
"scipy>=1.11.4",
"matplotlib",
"numpy",
@@ -52,7 +52,7 @@ optional-dependencies.doc = [
"sphinxext-opengraph",
]

optional-dependencies.test = ["pytest", "coverage"]
optional-dependencies.test = ["pytest", "coverage>=7.6.1"]

# https://docs.pypi.org/project_metadata/#project-urls
urls.Documentation = "https://PerTurbo.readthedocs.io/"
2 changes: 1 addition & 1 deletion src/perturbo/__init__.py
@@ -1,6 +1,6 @@
from importlib.metadata import version

from . import models, simulation
from . import models, simulation, utils
from .models import PERTURBO
from .simulation import Learn_Data, Simulate_Data

485 changes: 427 additions & 58 deletions src/perturbo/models/_model.py

Large diffs are not rendered by default.

99 changes: 69 additions & 30 deletions src/perturbo/models/_module.py
@@ -1,3 +1,4 @@
import functools
import warnings
from collections.abc import Mapping
from typing import Literal
@@ -6,7 +7,7 @@
import pyro.distributions as dist
import torch
from pyro import poutine
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_median, init_to_value
from scvi.module.base import PyroBaseModuleClass

from ._constants import REGISTRY_KEYS
@@ -31,10 +32,11 @@ def __init__(
n_batches: int | None = 1,
log_gene_mean_init: torch.Tensor | None = None,
log_gene_dispersion_init: torch.Tensor | None = None,
lfc_init: torch.Tensor | None = None,
guide_by_element: torch.Tensor | None = None,
gene_by_element: torch.Tensor | None = None,
likelihood: Literal["nb", "lnnb"] = "nb",
effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "laplace",
effect_prior_dist: Literal["cauchy", "normal_mixture", "normal", "laplace"] = "cauchy",
n_factors: int | None = None,
n_pert_factors: int | None = None,
efficiency_mode: Literal["mixture", "scaled", "mixture_high_moi"] | None = "scaled",
@@ -150,35 +152,18 @@ def __init__(
self.n_batches = n_batches

# Sites to approximate with Delta distribution instead of default Normal distribution.
self.delta_sites = []
# self.delta_sites = []
self.delta_sites = ["log_gene_mean", "log_gene_dispersion", "multiplicative_noise"]
# self.delta_sites = ["cell_factors"]
# self.delta_sites = ["cell_factors", "cell_loadings", "pert_factors", "pert_loadings"]

if log_gene_mean_init is None:
log_gene_mean_init = torch.zeros(self.n_genes)
self.log_gene_mean_init = log_gene_mean_init

if log_gene_dispersion_init is None:
log_gene_dispersion_init = torch.ones(self.n_genes)

# if control_pcs is not None and n_factors is not None:
# init_values["cell_loadings"] = control_pcs

self._guide = AutoGuideList(self.model, create_plates=self.create_plates)

self._guide.append(
AutoNormal(
poutine.block(self.model, hide=self.delta_sites + self.discrete_sites),
init_loc_fn=lambda x: init_to_median(x, num_samples=100),
),
)

if self.delta_sites:
self._guide.append(
AutoDelta(
poutine.block(self.model, expose=self.delta_sites),
init_loc_fn=lambda x: init_to_median(x, num_samples=100),
)
)
self.log_gene_dispersion_init = log_gene_dispersion_init

## register hyperparameters as buffers so they get automatically moved to GPU by scvi-tools

@@ -208,11 +193,14 @@ def __init__(
self.register_buffer("one", torch.tensor(1.0))

# per-gene hyperparams
self.register_buffer("gene_mean_prior_loc", log_gene_mean_init)
self.register_buffer("gene_disp_prior_loc", log_gene_dispersion_init)
self.register_buffer("gene_mean_prior_loc", torch.tensor(0.0))
self.register_buffer("gene_disp_prior_loc", torch.tensor(1.0))

self.register_buffer("gene_mean_prior_scale", torch.tensor(3.0))
self.register_buffer("gene_disp_prior_scale", torch.tensor(1.0))

self.register_buffer("gene_mean_prior_scale", torch.tensor(0.2))
self.register_buffer("gene_disp_prior_scale", torch.tensor(0.2))
self.register_buffer("noise_prior_loc", torch.tensor(-1.0))
self.register_buffer("noise_prior_scale", torch.tensor(0.5))

# batch/covariate hyperparams
self.register_buffer("batch_effect_prior_scale", torch.tensor(0.2))
@@ -243,8 +231,27 @@ def __init__(
self.register_buffer("pert_factor_prior_scale", torch.tensor(0.1))
self.register_buffer("pert_loading_prior_scale", torch.tensor(1.0))

# for LogNormalNegativeBinomial likelihood hyperparams
self.register_buffer("noise_prior_rate", torch.tensor(2.0))
# create guide with initial values
if lfc_init is None:
lfc_init = torch.zeros((self.n_elements, self.n_genes))

if self.local_effects and self.sparse_tensors:
if lfc_init.shape == (self.n_elements, self.n_genes):
lfc_init = lfc_init[self.element_by_gene_idx[0], self.element_by_gene_idx[1]]
assert lfc_init.shape == (self.n_element_effects,), (
f"lfc_init shape: {lfc_init.shape}, expected ({self.n_element_effects},)"
)
self.lfc_init = lfc_init

# self._guide = self._guide_factory(self.model)
# self._guide = self._guide_factory(
# self.model,
# init_values={
# "log_gene_mean": self.log_gene_mean_init,
# "log_gene_dispersion": self.log_gene_dispersion_init,
# "element_effects": self.lfc_init,
# },
# )

# override with user-provided values from prior_param_dict
if prior_param_dict is not None:
@@ -254,6 +261,34 @@
assert v.shape == self.get_buffer(k).shape
self.register_buffer(k, v)

def _guide_factory(self, model, init_values=None, init_scale=0.05):
guide = AutoGuideList(model, create_plates=self.create_plates)
if init_values is None:
init_values = {}
# init_values = {
# "log_gene_mean": self.log_gene_mean_init,
# "log_gene_dispersion": self.log_gene_dispersion_init,
# "element_effects": self.lfc_init,
# }
init_loc_fn = functools.partial(init_to_value, values=init_values, fallback=init_to_median(num_samples=100))

guide.append(
AutoNormal(
poutine.block(model, hide=self.delta_sites + self.discrete_sites),
init_loc_fn=init_loc_fn,
init_scale=init_scale,
),
)

if self.delta_sites:
guide.append(
AutoDelta(
poutine.block(model, expose=self.delta_sites),
init_loc_fn=init_loc_fn,
)
)
return guide

@staticmethod
def _get_fn_args_from_batch(tensor_dict: dict) -> tuple[tuple[torch.Tensor], dict]:
fit_size_factor_covariate = False
@@ -547,7 +582,10 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None:

if self.likelihood == "lnnb":
# additional noise for LogNormalNegativeBinomial likelihood
multiplicative_noise = pyro.sample("multiplicative_noise", dist.Exponential(self.noise_prior_rate))
# multiplicative_noise = pyro.sample("multiplicative_noise", dist.LogNormal(self.noise_prior_rate))
multiplicative_noise = pyro.sample(
"multiplicative_noise", dist.LogNormal(self.noise_prior_loc, self.noise_prior_scale)
)
# multiplicative_noise = 1 / self.noise_prior_rate

with batch_plate:
@@ -618,6 +656,7 @@ def model(self, idx: torch.Tensor, **tensor_dict) -> None:
"obs",
LogNormalNegativeBinomial(
logits=nb_log_mean - nb_log_dispersion - multiplicative_noise**2 / 2,
# logits=nb_log_mean - nb_log_dispersion,
total_count=nb_log_dispersion.exp(),
multiplicative_noise_scale=multiplicative_noise,
num_quad_points=self.lnnb_quad_points,
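Note on the guide refactor above: _guide_factory now builds the AutoGuideList from a per-site initialization function, where init_to_value seeds any sites supplied in init_values and every other site falls back to init_to_median(num_samples=100). The following is a minimal, self-contained sketch of that Pyro pattern; the toy model, site shapes, and starting values are invented for illustration and are not part of this PR.

# Minimal sketch of the init_to_value + fallback pattern used by _guide_factory.
# The toy model and starting values below are illustrative only.
import functools

import pyro
import pyro.distributions as dist
import torch
from pyro.infer.autoguide import AutoNormal, init_to_median, init_to_value


def toy_model():
    # One site receives an explicit initial value; the other falls back to a
    # median estimate computed from 100 prior samples.
    pyro.sample("log_gene_mean", dist.Normal(torch.zeros(5), 3.0).to_event(1))
    pyro.sample("log_gene_dispersion", dist.Normal(torch.ones(5), 1.0).to_event(1))


init_loc_fn = functools.partial(
    init_to_value,
    values={"log_gene_mean": torch.full((5,), -2.0)},
    fallback=init_to_median(num_samples=100),
)
guide = AutoNormal(toy_model, init_loc_fn=init_loc_fn, init_scale=0.05)
guide()  # first call builds the guide; "log_gene_mean" starts at the supplied -2.0 values

On the likelihood change in model(): the Exponential prior on multiplicative_noise is replaced by LogNormal(noise_prior_loc, noise_prior_scale), and multiplicative_noise**2 / 2 is subtracted from the NB logits. Since a log-normal factor exp(eps) with eps ~ Normal(0, sigma**2) has mean exp(sigma**2 / 2), this term presumably re-centers the marginal mean of the counts so the extra noise does not inflate it.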
1 change: 1 addition & 0 deletions src/perturbo/utils/__init__.py
@@ -0,0 +1 @@
from .utils import empirical_pvals_from_null, compute_empirical_pvals
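The new perturbo.utils module exports empirical_pvals_from_null and compute_empirical_pvals; their signatures and bodies are not shown in this diff. As a purely illustrative sketch of the general technique (a hypothetical helper, not the package's implementation), a two-sided empirical p-value against a sample from the null is commonly computed with a +1 correction so it can never be exactly zero:

# Hypothetical sketch of an empirical p-value computation; not the actual
# implementation in perturbo.utils.
import numpy as np


def empirical_pvals(observed: np.ndarray, null: np.ndarray) -> np.ndarray:
    """Two-sided empirical p-values of `observed` statistics against a null sample."""
    obs = np.abs(np.asarray(observed))[:, None]   # shape (n_tests, 1)
    nul = np.abs(np.asarray(null))[None, :]       # shape (1, n_null)
    n_as_extreme = (nul >= obs).sum(axis=1)       # null draws at least as extreme as each observation
    return (n_as_extreme + 1) / (nul.size + 1)    # +1 correction keeps p-values strictly positive


rng = np.random.default_rng(0)
pvals = empirical_pvals(rng.normal(0.0, 2.0, size=8), rng.normal(size=10_000))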