Skip to content

GP - maybe this time #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cellij/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .models import MOFA
from .synthetic import DataGenerator
from .utils import EarlyStopper, KNNImputer, logger, set_all_seeds
# from ._priors import PriorDist, InverseGammaPrior, NormalPrior, GaussianProcessPrior, LaplacePrior, HorseshoePrior, SpikeAndSlabPrior
from ._gp import DenseGP
9 changes: 5 additions & 4 deletions cellij/core/_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import torch


class PseudotimeGP(gpytorch.models.ApproximateGP):
class DenseGP(gpytorch.models.ApproximateGP):
def __init__(
self,
inducing_points: torch.Tensor,
covariates: torch.Tensor,
n_factors: int,
init_lengthscale=5.0,
) -> None:
n_inducing = len(inducing_points)
print(covariates.dtype)
n_inducing = len(covariates)

variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
num_inducing_points=n_inducing,
Expand All @@ -18,7 +19,7 @@ def __init__(

variational_strategy = gpytorch.variational.VariationalStrategy(
model=self,
inducing_points=inducing_points,
inducing_points=covariates,
variational_distribution=variational_distribution,
learn_inducing_locations=False,
)
Expand Down
24 changes: 14 additions & 10 deletions cellij/core/_pyro_guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,12 @@ def _get_q_dists(self, priors: Dict[str, PriorDist]) -> Dict[str, QDist]:
for group, prior in priors.items():
# Replace strings with actual Q distributions
_q_dists[group] = {
"InverseGammaP": InverseGammaQ,
"NormalP": NormalQ,
"GaussianProcessQ": GaussianProcessQ,
"LaplaceP": LaplaceQ,
"NonnegativeP": NonnegativeQ,
"HorseshoeP": HorseshoeQ,
"SpikeAndSlabP": SpikeAndSlabQ,
"InverseGammaPrior": InverseGammaQ,
"NormalPrior": NormalQ,
"GaussianProcessPrior": GaussianProcessQ,
"LaplacePrior": LaplaceQ,
"HorseshoePrior": HorseshoeQ,
"SpikeAndSlabPrior": SpikeAndSlabQ,
}[prior._pyro_name](prior=prior)

return _q_dists
Expand All @@ -411,9 +410,14 @@ def forward(
with plates["factor"]:
factor_q_dist.sample_inter()
with plates[f"obs_{obs_group}"]:
self.sample_dict[
self.model.factor_priors[obs_group].site_name
] = factor_q_dist()
if covariate is not None:
self.sample_dict[
self.model.factor_priors[obs_group].site_name
] = factor_q_dist(covariate)
else:
self.sample_dict[
self.model.factor_priors[obs_group].site_name
] = factor_q_dist()

for feature_group, weight_q_dist in self.weight_q_dists.items():
weight_q_dist.sample_global()
Expand Down
16 changes: 14 additions & 2 deletions cellij/core/_pyro_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from pyro.nn import PyroModule
from torch.types import _device
import pandas as pd

from cellij.core._pyro_priors import PRIOR_MAP, PriorDist

Expand All @@ -22,6 +23,7 @@ def __init__(
factor_priors: dict[str, str],
weight_priors: dict[str, str],
device: _device,
covariates: Optional[pd.DataFrame] = None,
):
"""Instantiate a generative model for the multi-group and multi-view FA.

Expand Down Expand Up @@ -51,6 +53,7 @@ def __init__(
self.n_feature_groups = len(feature_dict)
self.n_obs_groups = len(obs_dict)
self.likelihoods = likelihoods
self.covariates = covariates

self.device = device
self.to(self.device)
Expand Down Expand Up @@ -86,6 +89,12 @@ def _get_prior(
)

prior = PRIOR_MAP[prior_config["name"]]

# if the prior is a GaussianProcess, add n_factors and the covariates
if prior_config["name"] == "GaussianProcess":
prior_config["n_factors"] = self.n_factors
prior_config["covariates"] = self.covariates

prior_config.pop("name")
return prior(site_name=site_name, device=self.device, **prior_config)

Expand Down Expand Up @@ -125,7 +134,7 @@ def forward(
with plates["factor"]:
factor_prior.sample_inter()
with plates[f"obs_{obs_group}"]:
self.sample_dict[f"z_{obs_group}"] = factor_prior(covariate)
self.sample_dict[f"z_{obs_group}"] = factor_prior(self.covariates)

for feature_group, weight_prior in self.weight_priors.items():
weight_prior.sample_global()
Expand All @@ -151,9 +160,12 @@ def forward(
if data is not None:
obs = data[obs_group][feature_group].view(obs_shape)

z = self.sample_dict[f"z_{obs_group}"].view(z_shape)
z = self.sample_dict[f"z_{obs_group}"] # .view(z_shape)
w = self.sample_dict[f"w_{feature_group}"].view(w_shape)

print("z", z.shape)
print("w", w.shape)

loc = torch.einsum("...ikj,...ikj->...ij", z, w).view(obs_shape)

scale = torch.sqrt(self.sample_dict[f"sigma_{feature_group}"]).view(
Expand Down
16 changes: 9 additions & 7 deletions cellij/core/_pyro_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch
from pyro.nn import PyroModule
from torch.types import _device, _size
from cellij.core._gp import DenseGP


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -174,7 +176,7 @@ def __init__(self, site_name: str, device: _device, **kwargs: dict[str, Any]):
device : _device
Torch device
"""
super().__init__("InverseGammaP", site_name, device)
super().__init__("InverseGammaPrior", site_name, device)

def forward(self, *args: Any, **kwargs: dict[str, Any]) -> Optional[torch.Tensor]:
return self._sample(
Expand All @@ -195,7 +197,7 @@ def __init__(self, site_name: str, device: _device, **kwargs: dict[str, Any]):
device : _device
Torch device
"""
super().__init__("NormalP", site_name, device)
super().__init__("NormalPrior", site_name, device)

def forward(self, *args: Any, **kwargs: dict[str, Any]) -> Optional[torch.Tensor]:
return self._sample(
Expand All @@ -216,8 +218,8 @@ def __init__(self, site_name: str, device: _device, **kwargs: dict[str, Any]):
device : _device
Torch device
"""
super().__init__("GaussianProcessP", site_name, device)
self.gp = PseudotimeGPrior(**kwargs)
super().__init__("GaussianProcessPrior", site_name, device)
self.gp = DenseGP(**kwargs)

def forward(self, *args: Any, **kwargs: dict[str, Any]) -> Optional[torch.Tensor]:
covariate = args[0]
Expand Down Expand Up @@ -248,7 +250,7 @@ def __init__(
Scale for the Laplace distribution, smaller leads to sparser solutions,
by default 0.1
"""
super().__init__("LaplaceP", site_name, device)
super().__init__("LaplacePrior", site_name, device)
self.scale = self._const(scale)

def forward(self, *args: Any, **kwargs: dict[str, Any]) -> Optional[torch.Tensor]:
Expand Down Expand Up @@ -334,7 +336,7 @@ def __init__(
ValueError
If both `tau_scale` and `tau_delta` are specified
"""
super().__init__("HorseshoeP", site_name, device)
super().__init__("HorseshoePrior", site_name, device)

self.tau_site_name = self.site_name + "_tau"
self.thetas_site_name = self.site_name + "_thetas"
Expand Down Expand Up @@ -432,7 +434,7 @@ def __init__(
Whether to sparsify whole components (factors),
by default True
"""
super().__init__("SpikeAndSlabP", site_name, device)
super().__init__("SpikeAndSlabPrior", site_name, device)

self.thetas_site_name = self.site_name + "_thetas"
self.alphas_site_name = self.site_name + "_alphas"
Expand Down
81 changes: 79 additions & 2 deletions cellij/core/factormodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pyro.infer import SVI
from pyro.nn import PyroModule
from tqdm import trange
import pandas as pd

import cellij
from cellij.core._data import DataContainer
Expand Down Expand Up @@ -106,6 +107,55 @@ def __init__(
# Save kwargs for later
self._kwargs = kwargs

def __repr__(self):
    """Return a tree-formatted, human-readable summary of the model.

    The summary lists the feature groups (with likelihood/weight priors),
    the observation groups (with factor priors), the covariates if present,
    and the data/training configuration options.

    NOTE(review): calling ``_init_from_config`` here is a side effect —
    ``__repr__`` should be read-only. Presumably it refreshes derived state
    so the printout reflects the latest options; confirm and consider moving
    this refresh into the option setters instead.
    """
    self._init_from_config(
        data_options=self._data_options,
        model_options=self._model_options,
        training_options=self._training_options,
    )

    output = []
    output.append(f"FactorModel(n_factors={self.n_factors})")

    output.append("├─ data")
    if len(self._data._names) == 0:
        output.append("│ └─ no data added yet")
    else:
        # Hoist the last key out of the loop instead of rebuilding the
        # key list on every iteration just to pick a branch character.
        last_feature = list(self._data.feature_groups.keys())[-1]
        for name, adata in self._data.feature_groups.items():
            branch_char = "├" if name != last_feature else "└"
            output.append(f"│ {branch_char}─ {name}: {adata.n_obs} observations × {adata.n_vars} features")
            output.append(f"│ ├ likelihood: {self._model_options['likelihoods'][name]}")
            output.append(f"│ └ weight_prior: {self._model_options['weight_priors'][name]}")

    output.append("├─ groups")
    if len(self._data._names) == 0:
        output.append("│ └─ no data added yet")
    else:
        last_group = list(self.obs_groups.keys())[-1]
        for name, obs in self.obs_groups.items():
            branch_char = "├" if name != last_group else "└"
            output.append(f"│ {branch_char}─ {name}: {len(obs)} observations")
            output.append(f"│ └ factor_prior: {self._model_options['factor_priors'][name]}")

    if self._model_options["covariates"] is not None:
        output.append("├─ covariates")
        output.append(f"│ └─ {self._model_options['covariates'].shape[1]}D covariate with {self._model_options['covariates'].shape[0]} observations")

    output.append("└─ config")

    # These headers contain no placeholders, so the f-prefix was dropped.
    output.append(" ├─ data options")
    last_data_opt = list(self._data_options.keys())[-1]
    for key, value in self._data_options.items():
        branch_char = "├" if key != last_data_opt else "└"
        output.append(f" │ {branch_char}─ {key}: {value}")

    output.append(" └─ training options")
    last_train_opt = list(self._training_options.keys())[-1]
    for key, value in self._training_options.items():
        branch_char = "├" if key != last_train_opt else "└"
        output.append(f" {branch_char}─ {key}: {value}")

    return "\n".join(output)

@property
def n_factors(self):
return self._n_factors
Expand Down Expand Up @@ -290,6 +340,8 @@ def set_model_options(
factor_priors: Optional[Union[str, dict[str, str]]] = None,
weight_priors: Optional[Union[str, dict[str, str]]] = None,
groups: Optional[dict[str, Iterable]] = None,
regress_out: Optional[pd.DataFrame] = None,
covariates: Optional[pd.DataFrame] = None,
preview: bool = False,
) -> Optional[dict[str, Any]]:
if isinstance(likelihoods, str):
Expand Down Expand Up @@ -382,10 +434,29 @@ def set_model_options(
f"Could not find valid prior for '{prior}'."
) from e

groups = self.obs_groups if groups is None else groups

#TODO: Implement logic for `regress_out`

if covariates is not None:
if not isinstance(covariates, (pd.DataFrame, pd.Series)):
raise TypeError(
f"Parameter 'covariates' must be pd.DataFrame, got '{type(covariates)}'."
)

if (covariates.reset_index().duplicated()).any():
raise ValueError(
f"Parameter 'covariates' contains duplicate columns."
)

for group in factor_priors.keys():
factor_priors[group] = "GaussianProcess"

options = {
"likelihoods": likelihoods,
"factor_priors": factor_priors,
"weight_priors": weight_priors,
"covariates": covariates,
"groups": groups,
}

Expand Down Expand Up @@ -858,13 +929,19 @@ def fit(
f"{name}": len(var_names) for name, var_names in self.feature_groups.items()
}

if self._model_options["covariates"] is not None:
covariates = torch.tensor(self._model_options["covariates"].values, dtype=self._dtype)
else:
covariates = None

model = Generative(
n_factors=self.n_factors,
obs_dict=obs_dict,
feature_dict=feature_dict,
likelihoods=self._model_options["likelihoods"],
factor_priors=self._model_options["factor_priors"],
weight_priors=self._model_options["weight_priors"],
covariates=covariates,
device=self.device,
)
guide = Guide(model)
Expand Down Expand Up @@ -952,11 +1029,11 @@ def fit(
) as pbar:
pbar.set_description("Training")
for i in pbar:
loss = svi.step(data=data)
loss = svi.step(data=data, covariate=covariates)
self.train_loss_elbo.append(loss)

if self._training_options["early_stopping"] and earlystopper.step(loss):
logger.warning(
logger.info(
f"Early stopping of training due to convergence at step {i}"
)
break
Expand Down