Release 0.1.3 #17

Merged
merged 6 commits into from
Feb 12, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
hooks:
- id: prettier
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
rev: v0.9.6
hooks:
- id: ruff
types_or: [python, pyi, jupyter]
Expand Down
12 changes: 9 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@

-

## [0.1.3] - 2025-02-12

- Introduce mean activation to make non-negative latents possible (docs will come later)
- Better communication when Merlin is not installed
- Raise an error when interpretability is called on a model with continuous covariates

## [0.1.2] - 2024-11-11

- No change in DRVI code
- Fix github workflow, tests, docs, and pypi publishing pipelines
- No change in DRVI code
- Fix github workflow, tests, docs, and pypi publishing pipelines

## [0.1.0] - 2024-08-21

- Moved all files from repo to scverse cookiecutter project template
- Moved all files from repo to scverse cookiecutter project template
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ Unsupervised Deep Disentangled Representation of Single-Cell Omics

Please refer to the [documentation][link-docs]. In particular, the

- [Tutorials][link-tutorials], specially
- [A demo](https://drvi.readthedocs.io/latest/notebooks/general_pipeline.html) of how to train DRVI and interpret the latent dimensions.
- [API documentation][link-api], specially
- [DRVI Model](https://drvi.readthedocs.io/latest/api/generated/drvi.model.DRVI.html)
- [DRVI utility functions (tools)](https://drvi.readthedocs.io/latest/api/tools.html)
- [DRVI plotting functions](https://drvi.readthedocs.io/latest/api/plotting.html)
- [Tutorials][link-tutorials], specially
- [A demo](https://drvi.readthedocs.io/latest/notebooks/general_pipeline.html) of how to train DRVI and interpret the latent dimensions.
- [API documentation][link-api], specially
- [DRVI Model](https://drvi.readthedocs.io/latest/api/generated/drvi.model.DRVI.html)
- [DRVI utility functions (tools)](https://drvi.readthedocs.io/latest/api/tools.html)
- [DRVI plotting functions](https://drvi.readthedocs.io/latest/api/plotting.html)

## System requirements

Expand Down
18 changes: 9 additions & 9 deletions docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ Specify `vX.X.X` as a tag name and create a release. For more information, see [

Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:

- the [myst][] extension allows writing documentation in Markdown/Markedly Structured Text
- [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
- the [myst][] extension allows writing documentation in Markdown/Markedly Structured Text
- [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension).
- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)

See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
on how to write documentation.
Expand All @@ -121,10 +121,10 @@ repository.

#### Hints

- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
if you do so can sphinx automatically create a link to the external documentation.
- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
the `nitpick_ignore` list in `docs/conf.py`
- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
if you do so can sphinx automatically create a link to the external documentation.
- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
the `nitpick_ignore` list in `docs/conf.py`

#### Building the docs locally

Expand Down
2 changes: 1 addition & 1 deletion docs/extensions/typed_returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
for line in lines:
if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
yield f'-{m["param"]} (:class:`~{m["type"]}`)'
yield f"-{m['param']} (:class:`~{m['type']}`)"
else:
yield line

Expand Down
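The quote change in `_process_return` above is cosmetic: the yielded string is identical. A minimal, self-contained sketch of what the helper does to a numpydoc-style "Returns" entry follows. The function body is copied from the new side of the diff; the sample lines are made up for illustration.

```python
import re
from collections.abc import Generator, Iterable


def _process_return(lines: Iterable[str]) -> Generator[str, None, None]:
    for line in lines:
        # A line of the form "name : dotted.Type" becomes a Sphinx class reference.
        if m := re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line):
            yield f"-{m['param']} (:class:`~{m['type']}`)"
        else:
            # Anything else (e.g. the indented description) passes through unchanged.
            yield line


# Hypothetical docstring fragment, not taken from the DRVI sources.
sample = ["adata : anndata.AnnData", "    The annotated data matrix."]
processed = list(_process_return(sample))
print(processed)
```

Only the first sample line matches the pattern; the indented description is yielded as-is because `re.fullmatch` requires the whole line to fit `name : type`.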
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = ["hatchling"]

[project]
name = "drvi-py"
version = "0.1.2"
version = "0.1.3"
description = "Disentangled Generative Representation of Single Cell Omics"
readme = "README.md"
requires-python = ">=3.10,<3.13"
Expand Down Expand Up @@ -69,7 +69,7 @@ dev = [
]
doc = [
# Disable for now as nvidia servers return 404
# "merlin-dataloader==23.8.0",
"merlin-dataloader==23.8.0",
"docutils>=0.8,!=0.18.*,!=0.19.*",
"sphinx>=4",
"sphinx-book-theme>=1.0.0",
Expand Down
6 changes: 2 additions & 4 deletions src/drvi/nn_modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def __repr__(self):
if self._freeze_hook is None:
return f"Emb({self.num_embeddings}, {self.embedding_dim})"
else:
return (
f"Emb({self.num_embeddings}, {self.embedding_dim} | " f"freeze: {self.n_freeze_x}, {self.n_freeze_y})"
)
return f"Emb({self.num_embeddings}, {self.embedding_dim} | freeze: {self.n_freeze_x}, {self.n_freeze_y})"


class MultiEmbedding(nn.Module):
Expand Down Expand Up @@ -103,7 +101,7 @@ def from_pretrained(cls, feature_embedding_instance):
def load_weights_from_trained_module(self, other, freeze_old=False):
assert len(self.emb_list) >= len(other.emb_list)
if len(self.emb_list) > len(other.emb_list):
logging.warning(f"Extending feature embedding {other} to {self} " f"with more feature categories.")
logging.warning(f"Extending feature embedding {other} to {self} with more feature categories.")
else:
logging.info(f"Extending feature embedding {other} to {self}")
for self_emb, other_emb in zip(self.emb_list, other.emb_list, strict=False):
Expand Down
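The `embedding.py` edits above merge pairs of adjacent f-string literals into single literals. Python concatenates adjacent string literals at compile time, so this should not change behavior. A small check, using made-up values in place of the module's attributes:

```python
# Made-up stand-ins for self.num_embeddings, self.embedding_dim, etc.
num_embeddings, embedding_dim, n_freeze_x, n_freeze_y = 10, 4, 2, 1

# Old style: two adjacent f-string literals, joined implicitly.
old = f"Emb({num_embeddings}, {embedding_dim} | " f"freeze: {n_freeze_x}, {n_freeze_y})"

# New style: one f-string literal.
new = f"Emb({num_embeddings}, {embedding_dim} | freeze: {n_freeze_x}, {n_freeze_y})"

print(old == new)
```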
4 changes: 1 addition & 3 deletions src/drvi/nn_modules/feature_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def __init__(self, feature_info_str_list: list[str], axis="var", total_dim=None,
self.axis = axis
if any(fi.dim is None for fi in self.feature_info_list):
if total_dim is None and default_dim is None:
raise ValueError(
f"missing dim in {feature_info_str_list}\n" f"Please provide `total_dim` or `default_dim`"
)
raise ValueError(f"missing dim in {feature_info_str_list}\nPlease provide `total_dim` or `default_dim`")
if total_dim is not None:
self._fill_with_total_dim(total_dim)
if default_dim is not None:
Expand Down
31 changes: 23 additions & 8 deletions src/drvi/scvi_tools_based/merlin_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
import importlib
import logging

logger = logging.getLogger(__name__)

def get_placeholder(name, error_message, allow_init=False):
    error_message = error_message + f" Cannot use '{name}'."

    class ClassLevelGetAttrMeta(type):
        def __getattr__(cls, name):
            raise ImportError(error_message)

    class LazyNonExistingModulePlaceholder(metaclass=ClassLevelGetAttrMeta):
        def __init__(self):
            if not allow_init:
                raise ImportError(error_message)
            super().__init__()

    return LazyNonExistingModulePlaceholder


if importlib.util.find_spec("merlin"):
from . import fields
Expand All @@ -10,12 +24,13 @@
from ._data_manager import MerlinDataManager
from ._data_splitter import MerlinDataSplitter
else:
fields = None
MerlinData = None
MerlinTransformedDataLoader = None
MerlinDataManager = None
MerlinDataSplitter = None
logger.warning("Merlin is not installed. To use the Merlin dataloader, please install it.")
error_msg = "Merlin is not installed. To use the Merlin dataloader, please install it."
fields = get_placeholder("fields", error_msg)
MerlinData = get_placeholder("MerlinData", error_msg)
MerlinTransformedDataLoader = get_placeholder("MerlinTransformedDataLoader", error_msg)
MerlinDataManager = get_placeholder("MerlinDataManager", error_msg)
MerlinDataSplitter = get_placeholder("MerlinDataSplitter", error_msg)


__all__ = [
"MerlinData",
Expand Down
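The placeholder approach in `merlin_data/__init__.py` replaces `None` stand-ins with classes that raise `ImportError` when they are actually used, so the failure points at the missing dependency instead of surfacing later as an opaque `TypeError` or `AttributeError`. The sketch below reproduces the helper from the diff and exercises it. The package name `hypothetical_missing_pkg` and the name `Loader` are assumptions for illustration; they are not part of DRVI or Merlin.

```python
import importlib.util


def get_placeholder(name, error_message, allow_init=False):
    error_message = error_message + f" Cannot use '{name}'."

    class ClassLevelGetAttrMeta(type):
        # Invoked only when normal class-attribute lookup fails.
        def __getattr__(cls, attr):
            raise ImportError(error_message)

    class LazyNonExistingModulePlaceholder(metaclass=ClassLevelGetAttrMeta):
        def __init__(self):
            if not allow_init:
                raise ImportError(error_message)
            super().__init__()

    return LazyNonExistingModulePlaceholder


# "hypothetical_missing_pkg" is a made-up name assumed not to be installed.
if importlib.util.find_spec("hypothetical_missing_pkg") is None:
    Loader = get_placeholder("Loader", "hypothetical_missing_pkg is not installed.")

# Importing the name succeeds; using it fails with a descriptive message.
for action in (lambda: Loader(), lambda: Loader.from_config):
    try:
        action()
    except ImportError as err:
        print(err)
```

The sketch uses `import importlib.util` explicitly, since `import importlib` alone does not guarantee that the `util` submodule has been loaded. Because the metaclass raises `ImportError` rather than `AttributeError`, `hasattr(Loader, "x")` will also raise instead of returning `False`.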
2 changes: 1 addition & 1 deletion src/drvi/scvi_tools_based/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def load_query_data(
raise ValueError("It appears you are loading a model from a different class.")

if _SETUP_ARGS_KEY not in registry:
raise ValueError("Saved model does not contain original setup inputs. " "Cannot load the original setup.")
raise ValueError("Saved model does not contain original setup inputs. Cannot load the original setup.")

cls.setup_anndata(
adata,
Expand Down
11 changes: 8 additions & 3 deletions src/drvi/scvi_tools_based/module/_drvi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable, Sequence
from typing import Literal

import numpy as np
Expand Down Expand Up @@ -83,7 +83,10 @@ class DRVIModule(BaseModuleClass):
prior_init_dataloader
Dataloader constructed to initialize the prior (or maintain in vamp).
var_activation
Callable used to ensure positivity of the variational distributions' variance.
The activation function to ensure positivity of the variational distribution's variance. Defaults to "exp".
mean_activation
The activation function at the end of mean encoder. Defaults to "identity".
Possible values are "identity", "relu", "leaky_relu", "leaky_relu_{slope}", "elu", "elu_{min_value}".
encoder_layer_factory
A layer Factory instance for build encoder layers
decoder_layer_factory
Expand Down Expand Up @@ -145,7 +148,8 @@ def __init__(
] = "pnb_softmax",
prior: Literal["normal", "gmm_x", "vamp_x"] = "normal",
prior_init_dataloader: DataLoader | None = None,
var_activation: Callable | Literal["exp", "pow2"] = "exp",
var_activation: Literal["exp", "pow2"] = "exp",
mean_activation: str = "identity",
encoder_layer_factory: LayerFactory = None,
decoder_layer_factory: LayerFactory = None,
extra_encoder_kwargs: dict | None = None,
Expand Down Expand Up @@ -201,6 +205,7 @@ def __init__(
affine_batch_norm=affine_batch_norm_encoder,
use_layer_norm=use_layer_norm_encoder,
var_activation=var_activation,
mean_activation=mean_activation,
layer_factory=encoder_layer_factory,
covariate_modeling_strategy=covariate_modeling_strategy,
categorical_covariate_dims=categorical_covariate_dims if self.encode_covariates else [],
Expand Down
33 changes: 26 additions & 7 deletions src/drvi/scvi_tools_based/nn/_base_components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import math
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable, Sequence
from typing import Literal

import torch
Expand Down Expand Up @@ -390,8 +390,10 @@ class Encoder(nn.Module):
Minimum value for the variance;
used for numerical stability
var_activation
Callable used to ensure positivity of the variance.
Defaults to :meth:`torch.exp`.
The activation function to ensure positivity of the variance. Defaults to "exp".
mean_activation
The activation function at the end of mean encoder. Defaults to "identity".
Possible values are "identity", "relu", "leaky_relu", "leaky_relu_{slope}", "elu", "elu_{min_value}".
layer_factory
A layer Factory instance for building layers
layers_location
Expand Down Expand Up @@ -419,7 +421,8 @@ def __init__(
dropout_rate: float = 0.1,
distribution: str = "normal",
var_eps: float = 1e-4,
var_activation: Callable | Literal["exp", "pow2"] = "exp",
var_activation: Literal["exp", "pow2"] = "exp",
mean_activation: str = "identity",
layer_factory: LayerFactory = None,
covariate_modeling_strategy: Literal[
"one_hot",
Expand Down Expand Up @@ -497,8 +500,24 @@ def __init__(
elif var_activation == "pow2":
self.var_activation = lambda x: torch.pow(x, 2)
else:
assert callable(var_activation)
self.var_activation = var_activation
raise NotImplementedError()

if mean_activation == "identity":
    self.mean_activation = nn.Identity()
elif mean_activation == "relu":
    self.mean_activation = nn.ReLU()
elif mean_activation.startswith("leaky_relu"):
    if mean_activation == "leaky_relu":
        mean_activation = "leaky_relu_0.01"
    slope = float(mean_activation.split("leaky_relu_")[1])
    self.mean_activation = nn.LeakyReLU(negative_slope=slope)
elif mean_activation.startswith("elu"):
    if mean_activation == "elu":
        mean_activation = "elu_1.0"
    alpha = float(mean_activation.split("elu_")[1])
    self.mean_activation = nn.ELU(alpha=alpha)
else:
    raise NotImplementedError()

def forward(self, x: torch.Tensor, cat_full_tensor: torch.Tensor, cont_full_tensor: torch.Tensor = None):
r"""The forward computation for a single sample.
Expand All @@ -524,7 +543,7 @@ def forward(self, x: torch.Tensor, cat_full_tensor: torch.Tensor, cont_full_tens
x = torch.cat((x, cont_full_tensor), dim=-1)
# Parameters for latent distribution
q = self.encoder(self.input_dropout(x), cat_full_tensor) if self.encoder is not None else x
q_m = self.mean_encoder(q, cat_full_tensor)
q_m = self.mean_activation(self.mean_encoder(q, cat_full_tensor))
q_v = self.var_activation(self.var_encoder(q, cat_full_tensor)) + self.var_eps
dist = Normal(q_m, q_v.sqrt())
latent = self.z_transformation(dist.rsample())
Expand Down
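To make the `mean_activation` string format concrete, here is a dependency-free sketch that parses the same names and applies the activation to plain floats. It is a scalar analogue for illustration, not the PyTorch code in the diff: `make_mean_activation` is a hypothetical helper, it uses `str.removeprefix` instead of `split`, and it raises `ValueError` for unknown names where the diff raises `NotImplementedError`.

```python
import math
from collections.abc import Callable


def make_mean_activation(name: str) -> Callable[[float], float]:
    """Hypothetical scalar counterpart of the string-to-activation mapping."""
    if name == "identity":
        return lambda x: x
    if name == "relu":
        return lambda x: max(x, 0.0)
    if name.startswith("leaky_relu"):
        # A bare "leaky_relu" falls back to slope 0.01, as in the diff.
        slope = 0.01 if name == "leaky_relu" else float(name.removeprefix("leaky_relu_"))
        return lambda x: x if x >= 0 else slope * x
    if name.startswith("elu"):
        # A bare "elu" falls back to alpha 1.0; the output is bounded below by -alpha.
        alpha = 1.0 if name == "elu" else float(name.removeprefix("elu_"))
        return lambda x: x if x > 0 else alpha * math.expm1(x)
    raise ValueError(f"Unknown mean activation: {name!r}")


for spec in ("identity", "relu", "leaky_relu_0.4", "elu_0.4"):
    act = make_mean_activation(spec)
    print(spec, act(-2.0), act(3.0))
```

With `"relu"` the latent mean can never be negative, which matches the changelog entry about non-negative latents. With `"elu_0.4"` the mean is bounded below by `-0.4`, which appears to be what the docstring's `elu_{min_value}` suffix refers to.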
3 changes: 3 additions & 0 deletions src/drvi/utils/tools/interpretability/_latent_traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def make_traverse_adata(
else:
cat_vector = None

if model.adata_manager.get_state_registry(scvi.REGISTRY_KEYS.CONT_COVS_KEY):
raise NotImplementedError("Interpretability of models with continuous covariates is not implemented yet.")

# lib size
lib_vector = np.ones(n_samples) * 1e4
lib_vector = lib_vector[span_adata.obs["sample_id"]]
Expand Down
7 changes: 7 additions & 0 deletions tests/drvi_model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def test_simple_integration_latent_splitting(self):
adata, n_latent=32, n_split_latent=8, split_method="split", split_aggregation="max"
)

def test_simple_integration_mean_activation(self):
adata = self.make_test_adata()
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="identity")
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="relu")
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="leaky_relu_0.4")
self._general_integration_test(adata, n_latent=32, n_split_latent=32, mean_activation="elu_0.4")

def test_decoder_reusing(self):
adata = self.make_test_adata()
for reuse_strategy in ["nowhere"]:
Expand Down