Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #143

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

Dev #143

Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
change concat_feature_name param in anndata integration
pass extended_feature_name to nf process for integration

add tests for integrate_anndata.py
dannda committed Sep 10, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit fba5d0b27d96f6cd0fc4f52a338a67a0f6e9572c
2 changes: 1 addition & 1 deletion .github/workflows/tests-python.yml
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ jobs:
pip install -r ./envs/dev/requirements.txt
cd ./ome-zarr-metadata && pre-commit install && pip install -e . && cd ../
- name: Run tests
run: python -m pytest --cov=bin tests/test_class.py
run: python -m pytest --cov=bin tests/
env:
PYTHONPATH: ./bin
- name: Upload coverage to Codecov
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ wheels/
.installed.cfg
*.egg
MANIFEST
docs/

# PyInstaller
# Usually these files are written by a python script from a template
34 changes: 16 additions & 18 deletions bin/integrate_anndata.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ def reindex_anndata(
adata = data
else:
adata = read_anndata(data)
out_filename = out_filename or "concat-{}".format(
out_filename = out_filename or "reindexed-{}".format(
os.path.splitext(os.path.basename(data))[0]
)

@@ -106,7 +106,7 @@ def concat_matrix_from_obs(
data: Union[ad.AnnData, str],
obs: str = "celltype",
feature_name: str = "gene",
obs_feature_name: str = None,
concat_feature_name: str = None,
):
if isinstance(data, ad.AnnData):
adata = data
@@ -115,22 +115,22 @@ def concat_matrix_from_obs(

ext_matrix = pd.get_dummies(adata.obs[obs], dtype="float32")

return concat_matrices(adata, ext_matrix, obs, feature_name, obs_feature_name)
return concat_matrices(adata, ext_matrix, feature_name, concat_feature_name or obs)


def concat_matrix_from_obsm(
data: Union[ad.AnnData, str],
obsm: str = "celltype",
feature_name: str = "gene",
obsm_feature_name: str = None,
concat_feature_name: str = None,
):
if isinstance(data, ad.AnnData):
adata = data
else:
adata = read_anndata(data)

return concat_matrices(
adata, adata.obsm[obsm], "celltype", feature_name, obsm_feature_name
adata, adata.obsm[obsm], feature_name, concat_feature_name or obsm
)


@@ -140,7 +140,7 @@ def concat_matrix_from_cell2location(
q: str = "q05_cell_abundance_w_sf",
sample: tuple[str, str] = None,
feature_name: str = "gene",
obs_feature_name: str = None,
concat_feature_name: str = "celltype",
sort: bool = True,
sort_index: str = None,
**kwargs,
@@ -202,39 +202,37 @@ def concat_matrix_from_cell2location(
dtype="float32",
)

return concat_matrices(
adata, c2l_df, "celltype", feature_name, obs_feature_name, **kwargs
)
return concat_matrices(adata, c2l_df, feature_name, concat_feature_name, **kwargs)


def concat_matrices(
adata: ad.AnnData,
ext_df: pd.DataFrame,
obs: str = "celltype",
feature_name: str = "gene",
obs_feature_name: str = None,
concat_feature_name: str = "celltype",
):
assert adata.shape[0] == ext_df.shape[0]

obs_feature_name = obs_feature_name or obs
prev_features_bool = "is_{}".format(feature_name)
new_features_bool = "is_{}".format(obs_feature_name)
new_features_bool = "is_{}".format(concat_feature_name)

if isinstance(adata.X, spmatrix):
adata_concat = ad.AnnData(
hstack(
(
adata.X,
csr_matrix(ext_df.values)
if isinstance(adata.X, csr_matrix)
else csc_matrix(ext_df.values),
(
csr_matrix(ext_df.values)
if isinstance(adata.X, csr_matrix)
else csc_matrix(ext_df.values)
),
)
),
obs=adata.obs,
var=pd.concat(
[
adata.var.assign(**{prev_features_bool: True}),
ext_df.columns.to_frame(obs_feature_name)
ext_df.columns.to_frame(concat_feature_name)
.drop(columns=0)
.assign(**{new_features_bool: True}),
]
@@ -249,7 +247,7 @@ def concat_matrices(
var=pd.concat(
[
adata.var.assign(**{prev_features_bool: True}),
ext_df.columns.to_frame(obs_feature_name)
ext_df.columns.to_frame(concat_feature_name)
.drop(columns=0)
.assign(**{new_features_bool: True}),
]
8 changes: 7 additions & 1 deletion multimodal.nf
Original file line number Diff line number Diff line change
@@ -97,6 +97,9 @@ process process_anndata {
features_str = features
? "--features ${features_file.name != 'NO_FT' ? features_file : features}"
: ""
feature_name_str = config_map.extend_feature_name
? "--concat_feature_name ${config_map.extend_feature_name}"
: ""
args_str = features_file.name != 'NO_FT' && features_args
? "--args '" + new JsonBuilder(features_args).toString() + "'" : ""
"""
@@ -141,12 +144,15 @@ process Build_multimodal_config {

script:
url_str = config_map.url?.trim() ? "--url \"${config_map.url.trim()}\"" : ""
extended_features_str = config_map.extend_feature_name
? "--extended_features ${config_map.extend_feature_name}"
: ""
datasets_str = new JsonBuilder(datasets).toString()
"""
build_config_multimodal.py \
--project "${project}" \
--datasets '${datasets_str}' \
--extended_features "${config_map.extend_feature_name}" \
${extended_features_str} \
${url_str} \
--title "${config_map.title}" \
--description "${config_map.description}"
17 changes: 8 additions & 9 deletions tests/test_class.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import pytest

import os
import csv
import json
import os
from xml.etree.ElementTree import ElementTree

import xmlschema
import zarr
import anndata as ad
import numpy as np
import pandas as pd
import anndata as ad
import pytest
import xmlschema
import zarr
from scipy.sparse import csc_matrix, csr_matrix

from bin.consolidate_md import consolidate
from bin.generate_image import create_img
from bin.ome_zarr_metadata import get_metadata
from bin.process_h5ad import h5ad_to_zarr
from bin.process_molecules import tsv_to_json
from bin.consolidate_md import consolidate
from bin.router import process
from bin.ome_zarr_metadata import get_metadata
from bin.generate_image import create_img

# from bin.build_config_multimodal import write_json as write_json_multimodal
# from bin.build_config import write_json
337 changes: 186 additions & 151 deletions tests/test_class_multimodal.py
Original file line number Diff line number Diff line change
@@ -1,167 +1,202 @@
import operator
import os
from functools import reduce

from bin.build_config_multimodal import write_json as write_json_multimodal
import anndata as ad
import numpy as np
import pandas as pd
import pytest

from bin.integrate_anndata import (
concat_features,
concat_matrix_from_obs,
concat_matrix_from_obsm,
)

def iss_dataset(name="iss_dataset"):
dataset = {
f"{name}": {
"file_paths": [f"test_project-{name}-anndata.zarr"],
"images": {
"raw": [
{
"path": f"/path/to/iss/{name}-raw-image.zarr",
"md": {
"dimOrder": "XYZT",
"channel_names": ["Channel_1"],
"X": 10,
"Y": 10,
"Z": 1,
"C": 1,
"T": 0,
},
}
],
"label": [
{
"path": f"/path/to/iss/{name}-label-image.zarr",
"md": {
"dimOrder": "XYZT",
"channel_names": ["Channel_1"],
"X": 10,
"Y": 10,
"Z": 1,
"C": 1,
"T": 0,
},
}
],
},
"options": {
"matrix": "X",
"factors": ["obs/sample"],
"mappings": {"obsm/X_umap": [0, 1]},
"sets": ["obs/cluster", "obs/celltype"],
"spatial": {"xy": "obsm/spatial"},
},
"obs_type": "cell",
"is_spatial": True,
}
}
return dataset

class TestClass:
@pytest.fixture(scope="class")
def anndata_with_celltype_obs(self, tmp_path_factory):
adata = ad.AnnData(
np.array([[100.0] * 4] * 3),
obs=pd.DataFrame(
index=["obs1", "obs2", "obs3"],
data={"obs1": ["celltype1", "celltype2", "celltype3"]},
),
var=pd.DataFrame(index=["var1", "var2", "var3", "var4"]),
dtype="float32",
)
fn = tmp_path_factory.mktemp("data") / "anndata.h5ad"
adata.write_h5ad(fn)
return fn

def visium_dataset(name="visium_dataset"):
dataset = {
f"{name}": {
"file_paths": [f"test_project-{name}-anndata.zarr"],
"images": {
"raw": [
{
"path": f"/path/to/visium/{name}-raw-image.zarr",
"md": {
"dimOrder": "XYZT",
"channel_names": ["Channel_1"],
"X": 10,
"Y": 10,
"Z": 1,
"C": 1,
"T": 0,
},
}
],
"label": [
{
"path": f"/path/to/visium/{name}-label-image.zarr",
"md": {
"dimOrder": "XYZT",
"channel_names": ["Channel_1"],
"X": 10,
"Y": 10,
"Z": 1,
"C": 1,
"T": 0,
},
}
],
},
"options": {
"matrix": "X",
"factors": ["obs/sample"],
"mappings": {"obsm/X_umap": [0, 1]},
"spatial": {"xy": "obsm/spatial"},
@pytest.fixture(scope="class")
def anndata_with_celltype_obsm(self, tmp_path_factory):
adata = ad.AnnData(
np.array([[100.0] * 4] * 3),
obs=pd.DataFrame(
index=["obs1", "obs2", "obs3"],
),
var=pd.DataFrame(index=["var1", "var2", "var3", "var4"]),
dtype="float32",
)
adata.obsm["obsm1"] = pd.DataFrame(
index=["obs1", "obs2", "obs3"],
data={
"celltype1": [0.1, 0.2, 0.3],
"celltype2": [0.4, 0.5, 0.6],
"celltype3": [0.7, 0.8, 0.9],
},
"obs_type": "spot",
"is_spatial": True,
}
}
return dataset

dtype="float32",
)
fn = tmp_path_factory.mktemp("data") / "anndata.h5ad"
adata.write_h5ad(fn)
return fn

def scrnaseq_dataset(name="scrnaseq_dataset"):
dataset = {
f"{name}": {
"file_paths": [f"test_project-{name}-anndata.zarr"],
"options": {
"matrix": "X",
"factors": ["obs/sample"],
"sets": ["obs/celltype"],
"mappings": {"obsm/X_umap": [0, 1]},
"spatial": {"xy": "obsm/spatial"},
},
"obs_type": "cell",
"is_spatial": False,
}
}
return dataset
def test_concat_matrix_from_obs(self, monkeypatch, anndata_with_celltype_obs):
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obs))
adata = ad.read(anndata_with_celltype_obs)
adata_concat = concat_matrix_from_obs(adata, "obs1")
assert adata_concat.X.shape == (3, 7)
assert np.array_equal(
adata_concat.X,
np.hstack(
(
np.array([[100.0] * 4] * 3),
np.array([[1.0, 0, 0], [0, 1.0, 0], [0, 0, 1.0]]),
),
),
)
assert all([x in adata_concat.var.columns for x in ["is_gene", "is_obs1"]])
assert adata_concat.var["is_gene"].tolist() == [True] * 4 + [False] * 3
assert adata_concat.var["is_obs1"].tolist() == [False] * 4 + [True] * 3
assert all(adata_concat.var["is_gene"] + adata_concat.var["is_obs1"] == 1)

def test_concat_features_from_obs(self, monkeypatch, anndata_with_celltype_obs):
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obs))
adata = ad.read(anndata_with_celltype_obs)
adata_concat_1 = concat_matrix_from_obs(adata, "obs1")
adata_concat_2 = concat_features(adata, "obs/obs1")
pd.testing.assert_frame_equal(adata_concat_1.obs, adata_concat_2.obs)
pd.testing.assert_frame_equal(adata_concat_1.var, adata_concat_2.var)
assert np.array_equal(adata_concat_1.X, adata_concat_2.X)

class TestClass:
def test_build_config_multimodal(
self,
def test_concat_matrix_from_obs_with_concat_name(
self, monkeypatch, anndata_with_celltype_obs
):
tests = [
(
"test-iss_visium_sc",
[iss_dataset(), visium_dataset(), scrnaseq_dataset()],
),
(
"test-visium_visium",
[
visium_dataset("visium_dataset_1"),
visium_dataset("visium_dataset_2"),
],
),
(
"test-sc_sc",
[scrnaseq_dataset("sc_dataset_1"), scrnaseq_dataset("sc_dataset_2")],
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obs))
FEATURE_NAME = "celltype"
adata = ad.read(anndata_with_celltype_obs)
adata_concat = concat_matrix_from_obs(
adata, "obs1", concat_feature_name=FEATURE_NAME
)
assert adata_concat.X.shape == (3, 7)
assert np.array_equal(
adata_concat.X,
np.hstack(
(
np.array([[100.0] * 4] * 3),
np.array([[1.0, 0, 0], [0, 1.0, 0], [0, 0, 1.0]]),
),
),
(
"test-iss_iss_visium",
[
iss_dataset("iss_dataset_1"),
iss_dataset("iss_dataset_2"),
visium_dataset(),
],
)
assert all(
[x in adata_concat.var.columns for x in ["is_gene", f"is_{FEATURE_NAME}"]]
)
assert adata_concat.var["is_gene"].tolist() == [True] * 4 + [False] * 3
assert (
adata_concat.var[f"is_{FEATURE_NAME}"].tolist() == [False] * 4 + [True] * 3
)
assert all(
adata_concat.var["is_gene"] + adata_concat.var[f"is_{FEATURE_NAME}"] == 1
)

def test_concat_features_from_obs_with_concat_name(
self, monkeypatch, anndata_with_celltype_obs
):
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obs))
FEATURE_NAME = "celltype"
adata = ad.read(anndata_with_celltype_obs)
adata_concat_1 = concat_matrix_from_obs(
adata, "obs1", concat_feature_name=FEATURE_NAME
)
adata_concat_2 = concat_features(
adata, "obs/obs1", concat_feature_name=FEATURE_NAME
)
pd.testing.assert_frame_equal(adata_concat_1.obs, adata_concat_2.obs)
pd.testing.assert_frame_equal(adata_concat_1.var, adata_concat_2.var)
assert np.array_equal(adata_concat_1.X, adata_concat_2.X)

def test_concat_matrix_from_obsm(self, monkeypatch, anndata_with_celltype_obsm):
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obsm))
adata = ad.read(anndata_with_celltype_obsm)
adata_concat = concat_matrix_from_obsm(adata, "obsm1")
assert adata_concat.X.shape == (3, 7)
assert np.array_equal(
adata_concat.X,
np.hstack(
(
np.array([[100.0] * 4] * 3),
np.array([[0.1, 0.4, 0.7], [0.2, 0.5, 0.8], [0.3, 0.6, 0.9]]),
),
dtype="float32",
),
("test-sc", [scrnaseq_dataset()]),
("test-iss", [iss_dataset()]),
("test-visium", [visium_dataset()]),
]
)
assert all([x in adata_concat.var.columns for x in ["is_gene", "is_obsm1"]])
assert adata_concat.var["is_gene"].tolist() == [True] * 4 + [False] * 3
assert adata_concat.var["is_obsm1"].tolist() == [False] * 4 + [True] * 3
assert all(adata_concat.var["is_gene"] + adata_concat.var["is_obsm1"] == 1)

for test in tests:
input = {
"project": test[0],
"extended_features": "celltype",
"url": "http://localhost/",
"config_filename_suffix": "config.json",
"datasets": reduce(operator.ior, test[1], {}),
}
def test_concat_features_from_obsm(self, monkeypatch, anndata_with_celltype_obsm):
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obsm))
adata = ad.read(anndata_with_celltype_obsm)
adata_concat_1 = concat_matrix_from_obsm(adata, "obsm1")
adata_concat_2 = concat_features(adata, "obsm/obsm1")
pd.testing.assert_frame_equal(adata_concat_1.obs, adata_concat_2.obs)
pd.testing.assert_frame_equal(adata_concat_1.var, adata_concat_2.var)
assert np.array_equal(adata_concat_1.X, adata_concat_2.X)

write_json_multimodal(**input)
def test_concat_matrix_from_obsm_with_concat_name(
self, monkeypatch, anndata_with_celltype_obsm
):
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obsm))
FEATURE_NAME = "celltype"
adata = ad.read(anndata_with_celltype_obsm)
adata_concat = concat_matrix_from_obsm(
adata, "obsm1", concat_feature_name=FEATURE_NAME
)
assert adata_concat.X.shape == (3, 7)
assert np.array_equal(
adata_concat.X,
np.hstack(
(
np.array([[100.0] * 4] * 3),
np.array([[0.1, 0.4, 0.7], [0.2, 0.5, 0.8], [0.3, 0.6, 0.9]]),
),
dtype="float32",
),
)
assert all(
[x in adata_concat.var.columns for x in ["is_gene", f"is_{FEATURE_NAME}"]]
)
assert adata_concat.var["is_gene"].tolist() == [True] * 4 + [False] * 3
assert (
adata_concat.var[f"is_{FEATURE_NAME}"].tolist() == [False] * 4 + [True] * 3
)
assert all(
adata_concat.var["is_gene"] + adata_concat.var[f"is_{FEATURE_NAME}"] == 1
)

assert os.path.exists(
f"{input['project']}-multimodal-{input['config_filename_suffix']}"
)
def test_concat_features_from_obsm_with_concat_name(
self, monkeypatch, anndata_with_celltype_obsm
):
monkeypatch.chdir(os.path.dirname(anndata_with_celltype_obsm))
FEATURE_NAME = "celltype"
adata = ad.read(anndata_with_celltype_obsm)
adata_concat_1 = concat_matrix_from_obsm(
adata, "obsm1", concat_feature_name=FEATURE_NAME
)
adata_concat_2 = concat_features(
adata, "obsm/obsm1", concat_feature_name=FEATURE_NAME
)
pd.testing.assert_frame_equal(adata_concat_1.obs, adata_concat_2.obs)
pd.testing.assert_frame_equal(adata_concat_1.var, adata_concat_2.var)
assert np.array_equal(adata_concat_1.X, adata_concat_2.X)
167 changes: 167 additions & 0 deletions tests/test_class_multimodal_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import operator
import os
from functools import reduce

from bin.build_config_multimodal import write_json as write_json_multimodal


def iss_dataset(name="iss_dataset"):
    """Build a mock ISS dataset entry for the multimodal config tests.

    Args:
        name: Dataset identifier. It becomes the single key of the returned
            mapping and is embedded in the AnnData and image paths.

    Returns:
        A one-item dict mapping ``name`` to a dataset description holding
        ``file_paths``, ``images`` (one ``raw`` and one ``label`` image),
        ``options``, ``obs_type`` and ``is_spatial``.
    """
    # The raw and label entries differ only by the image kind in the path,
    # so both are generated from one template. The metadata literal lives
    # inside the comprehension so each image gets its own dict and list.
    images = {
        kind: [
            {
                "path": f"/path/to/iss/{name}-{kind}-image.zarr",
                "md": {
                    "dimOrder": "XYZT",
                    "channel_names": ["Channel_1"],
                    "X": 10,
                    "Y": 10,
                    "Z": 1,
                    "C": 1,
                    "T": 0,
                },
            }
        ]
        for kind in ("raw", "label")
    }

    return {
        # The f-string keeps the key a str even if a non-str name is passed.
        f"{name}": {
            "file_paths": [f"test_project-{name}-anndata.zarr"],
            "images": images,
            "options": {
                "matrix": "X",
                "factors": ["obs/sample"],
                "mappings": {"obsm/X_umap": [0, 1]},
                "sets": ["obs/cluster", "obs/celltype"],
                "spatial": {"xy": "obsm/spatial"},
            },
            "obs_type": "cell",
            "is_spatial": True,
        }
    }


def visium_dataset(name="visium_dataset"):
    """Build a mock Visium dataset entry for the multimodal config tests.

    Args:
        name: Dataset identifier. It becomes the single key of the returned
            mapping and is embedded in the AnnData and image paths.

    Returns:
        A one-item dict mapping ``name`` to a dataset description holding
        ``file_paths``, ``images`` (one ``raw`` and one ``label`` image),
        ``options``, ``obs_type`` and ``is_spatial``.
    """
    # The raw and label entries differ only by the image kind in the path,
    # so both are generated from one template. The metadata literal lives
    # inside the comprehension so each image gets its own dict and list.
    images = {
        kind: [
            {
                "path": f"/path/to/visium/{name}-{kind}-image.zarr",
                "md": {
                    "dimOrder": "XYZT",
                    "channel_names": ["Channel_1"],
                    "X": 10,
                    "Y": 10,
                    "Z": 1,
                    "C": 1,
                    "T": 0,
                },
            }
        ]
        for kind in ("raw", "label")
    }

    return {
        # The f-string keeps the key a str even if a non-str name is passed.
        f"{name}": {
            "file_paths": [f"test_project-{name}-anndata.zarr"],
            "images": images,
            # Unlike the ISS fixture, no "sets" option is declared here.
            "options": {
                "matrix": "X",
                "factors": ["obs/sample"],
                "mappings": {"obsm/X_umap": [0, 1]},
                "spatial": {"xy": "obsm/spatial"},
            },
            "obs_type": "spot",
            "is_spatial": True,
        }
    }


def scrnaseq_dataset(name="scrnaseq_dataset"):
    """Build a mock scRNA-seq dataset entry for the multimodal config tests.

    Args:
        name: Dataset identifier. It becomes the single key of the returned
            mapping and is embedded in the AnnData path.

    Returns:
        A one-item dict mapping ``name`` to a dataset description holding
        ``file_paths``, ``options``, ``obs_type`` and ``is_spatial``. No
        ``images`` entry is included.
    """
    return {
        # The f-string keeps the key a str even if a non-str name is passed.
        f"{name}": {
            "file_paths": [f"test_project-{name}-anndata.zarr"],
            "options": {
                "matrix": "X",
                "factors": ["obs/sample"],
                "sets": ["obs/celltype"],
                "mappings": {"obsm/X_umap": [0, 1]},
                # NOTE(review): a spatial mapping is supplied although
                # is_spatial is False below - presumably deliberate test
                # input; confirm against build_config_multimodal.
                "spatial": {"xy": "obsm/spatial"},
            },
            "obs_type": "cell",
            "is_spatial": False,
        }
    }


class TestClass:
    """Smoke tests for ``write_json_multimodal``."""

    def test_build_config_multimodal(
        self,
    ):
        """Build a config for several dataset combinations.

        For every ``(project, datasets)`` case the per-dataset mappings are
        merged into one dict, passed to ``write_json_multimodal``, and the
        test then asserts that a file named
        ``<project>-multimodal-<config_filename_suffix>`` exists relative to
        the current working directory.
        """
        # NOTE(review): output is looked up in the current working
        # directory; consider running inside a temporary directory so the
        # generated configs do not accumulate - confirm that
        # write_json_multimodal has no other cwd-relative dependency first.
        cases = [
            (
                "test-iss_visium_sc",
                [iss_dataset(), visium_dataset(), scrnaseq_dataset()],
            ),
            (
                "test-visium_visium",
                [
                    visium_dataset("visium_dataset_1"),
                    visium_dataset("visium_dataset_2"),
                ],
            ),
            (
                "test-sc_sc",
                [scrnaseq_dataset("sc_dataset_1"), scrnaseq_dataset("sc_dataset_2")],
            ),
            (
                "test-iss_iss_visium",
                [
                    iss_dataset("iss_dataset_1"),
                    iss_dataset("iss_dataset_2"),
                    visium_dataset(),
                ],
            ),
            ("test-sc", [scrnaseq_dataset()]),
            ("test-iss", [iss_dataset()]),
            ("test-visium", [visium_dataset()]),
        ]

        for project, datasets in cases:
            # Named config_kwargs rather than `input` to avoid shadowing the
            # builtin.
            config_kwargs = {
                "project": project,
                "extended_features": "celltype",
                "url": "http://localhost/",
                "config_filename_suffix": "config.json",
                # Fold the one-item dataset dicts into a single mapping,
                # starting from a fresh empty dict for every case.
                "datasets": reduce(operator.ior, datasets, {}),
            }

            write_json_multimodal(**config_kwargs)

            expected_path = (
                f"{config_kwargs['project']}-multimodal-"
                f"{config_kwargs['config_filename_suffix']}"
            )
            assert os.path.exists(
                expected_path
            ), f"expected config file was not written: {expected_path}"