Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): allow init without X and correct shape inferred #1841

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/1941.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow initialization of {class}`anndata.AnnData` objects without `X` (since they could be constructed previously by deleting `X`) {user}`ilan-gold`
42 changes: 40 additions & 2 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import h5py
import numpy as np
Expand Down Expand Up @@ -85,6 +85,40 @@ def _check_2d_shape(X):
raise ValueError(msg)


def _infer_shape_for_axis(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 things about these functions.

  1. Do they belong in utils? I see we have a bunch at the top of the file here so not sure
  2. Should we do some sort of validation that all the shapes match? I think this happens anyway elsewhere in the codebase e.g., https://github.com/scverse/anndata/blob/main/src/anndata/_core/aligned_mapping.py#L70-L100

xxx: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
xxxm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
layers: Mapping[str, np.ndarray | sparse.spmatrix] | None = None,
xxxp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
) -> int | None:
for elem in [xxx, xxxm, xxxp]:
if elem is not None and hasattr(elem, "shape"):
return elem.shape[0]
for elem in [layers, xxxm, xxxp]:
if elem is not None:
elem = cast(Mapping, elem)
for sub_elem in elem.values():
if hasattr(sub_elem, "shape"):
size = cast(int, sub_elem.shape[0])
return size
return None


def _infer_shape(
obs: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
var: pd.DataFrame | Mapping[str, Iterable[Any]] | None = None,
obsm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
varm: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
layers: Mapping[str, np.ndarray | sparse.spmatrix] | None = None,
obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
):
return (
_infer_shape_for_axis(obs, obsm, layers, obsp),
_infer_shape_for_axis(var, varm, layers, varp),
)


class AnnData(metaclass=utils.DeprecationMixinMeta):
"""\
An annotated data matrix.
Expand Down Expand Up @@ -431,7 +465,11 @@ def _init_as_actual(
source = "X"
else:
self._X = None
n_obs, n_vars = (None, None) if shape is None else shape
n_obs, n_vars = (
shape
if shape is not None
else _infer_shape(obs, var, obsm, varm, layers, obsp, varp)
)
source = "shape"

# annotations
Expand Down
21 changes: 21 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import warnings
from itertools import product
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
Expand All @@ -11,11 +12,16 @@
from scipy import sparse as sp
from scipy.sparse import csr_matrix, issparse

import anndata as ad
from anndata import AnnData, ImplicitModificationWarning
from anndata._settings import settings
from anndata.compat import CAN_USE_SPARSE_ARRAY
from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df

if TYPE_CHECKING:
from pathlib import Path
from typing import Literal

# some test objects that we use below
adata_dense = AnnData(np.array([[1, 2], [3, 4]]))
adata_dense.layers["test"] = adata_dense.X
Expand Down Expand Up @@ -735,3 +741,18 @@ def test_to_memory_no_copy():
assert mem.obsp[key] is adata.obsp[key]
for key in adata.varp:
assert mem.varp[key] is adata.varp[key]


@pytest.mark.parametrize("axis", ["obs", "var"])
@pytest.mark.parametrize("elem_type", ["p", "m"])
def test_create_adata_from_single_axis_elem(
axis: Literal["obs", "var"], elem_type: Literal["m", "p"], tmp_path: Path
):
d = dict(
a=np.zeros((10, 10)),
)
in_memory = AnnData(**{f"{axis}{elem_type}": d})
assert in_memory.shape == (10, 0) if axis == "obs" else (0, 10)
in_memory.write_h5ad(tmp_path / "adata.h5ad")
from_disk = ad.read_h5ad(tmp_path / "adata.h5ad")
ad.tests.helpers.assert_equal(from_disk, in_memory)
32 changes: 19 additions & 13 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,28 @@
from anndata.io import read_loom
from anndata.tests.helpers import gen_typed_df_t2_size

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
X_ = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
L = np.array([[10, 11, 12], [13, 14, 15], [16, 17, 18]])


def test_creation():
@pytest.fixture(params=[X_, None])
def X(request):
return request.param


def test_creation(X: np.ndarray | None):
adata = AnnData(X=X, layers=dict(L=L.copy()))

assert list(adata.layers.keys()) == ["L"]
assert "L" in adata.layers
assert "X" not in adata.layers
assert "some_other_thing" not in adata.layers
assert (adata.layers["L"] == L).all()
assert adata.shape == L.shape


def test_views():
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata = AnnData(X=X_, layers=dict(L=L.copy()))
adata_view = adata[1:, 1:]

assert adata_view.layers.is_view
Expand All @@ -36,13 +42,13 @@ def test_views():
assert adata_view.layers.keys() == adata.layers.keys()
assert (adata_view.layers["L"] == adata.layers["L"][1:, 1:]).all()

adata.layers["S"] = X
adata.layers["S"] = X_

assert adata_view.layers.keys() == adata.layers.keys()
assert (adata_view.layers["S"] == adata.layers["S"][1:, 1:]).all()

with pytest.warns(ImplicitModificationWarning):
adata_view.layers["T"] = X[1:, 1:]
adata_view.layers["T"] = X_[1:, 1:]

assert not adata_view.layers.is_view
assert not adata_view.is_view
Expand All @@ -51,12 +57,12 @@ def test_views():
@pytest.mark.parametrize(
("df", "homogenous", "dtype"),
[
(lambda: gen_typed_df_t2_size(*X.shape), True, np.object_),
(lambda: pd.DataFrame(X**2), False, np.int_),
(lambda: gen_typed_df_t2_size(*X_.shape), True, np.object_),
(lambda: pd.DataFrame(X_**2), False, np.int_),
],
)
def test_set_dataframe(homogenous, df, dtype):
adata = AnnData(X)
adata = AnnData(X_)
if homogenous:
with pytest.warns(UserWarning, match=r"Layer 'df'.*dtype object"):
adata.layers["df"] = df()
Expand All @@ -68,7 +74,7 @@ def test_set_dataframe(homogenous, df, dtype):
assert np.issubdtype(adata.layers["df"].dtype, dtype)


def test_readwrite(backing_h5ad):
def test_readwrite(X: np.ndarray | None, backing_h5ad):
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata.write(backing_h5ad)
adata_read = read_h5ad(backing_h5ad)
Expand All @@ -80,7 +86,7 @@ def test_readwrite(backing_h5ad):
@pytest.mark.skipif(find_spec("loompy") is None, reason="loompy not installed")
def test_readwrite_loom(tmp_path):
loom_path = tmp_path / "test.loom"
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata = AnnData(X=X_, layers=dict(L=L.copy()))

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
Expand All @@ -104,7 +110,7 @@ def test_backed():


def test_copy():
adata = AnnData(X=X, layers=dict(L=L.copy()))
adata = AnnData(X=X_, layers=dict(L=L.copy()))
bdata = adata.copy()
# check that we don’t create too many references
assert bdata._layers is bdata.layers._data
Expand All @@ -114,7 +120,7 @@ def test_copy():


def test_shape_error():
adata = AnnData(X=X)
adata = AnnData(X=X_)
with pytest.raises(
ValueError,
match=(
Expand All @@ -123,4 +129,4 @@ def test_shape_error():
r"Value had shape \(4, 3\) while it should have had \(3, 3\)\."
),
):
adata.layers["L"] = np.zeros((X.shape[0] + 1, X.shape[1]))
adata.layers["L"] = np.zeros((X_.shape[0] + 1, X_.shape[1]))