2
2
import anndata
3
3
import pandas as pd
4
4
import numpy as np
5
- from scipy .sparse import issparse
5
+ from scipy .sparse import issparse , spmatrix
6
6
from mudata import MuData
7
7
from pathlib import Path
8
8
from pandas .testing import assert_frame_equal
9
9
from typing import Literal
10
10
from .typing import AnnotationObjectOrPathLike
11
+ from functools import singledispatch
11
12
12
13
13
14
def _read_if_needed (anndata_mudata_path_or_obj ):
14
15
if isinstance (anndata_mudata_path_or_obj , (str , Path )):
15
- return mudata .read (anndata_mudata_path_or_obj )
16
+ return mudata .read (str ( anndata_mudata_path_or_obj )) # TODO: remove when mudata fixes PAth bug
16
17
if isinstance (anndata_mudata_path_or_obj , (mudata .MuData , anndata .AnnData )):
17
18
return anndata_mudata_path_or_obj .copy ()
18
19
raise AssertionError ("Expected 'Path', 'str' to MuData/AnnData "
@@ -64,6 +65,12 @@ def assert_var_names_equal(left: AnnotationObjectOrPathLike, right: AnnotationOb
64
65
assert_var_names_equal (modality , right [mod_name ])
65
66
66
67
68
+ def _assert_frame_equal (left , right , sort = False , * args , ** kwargs ):
69
+ if sort :
70
+ left , right = left .sort_index (inplace = False ), right .sort_index (inplace = False )
71
+ left , right = left .sort_index (axis = 1 , inplace = False ), right .sort_index (axis = 1 , inplace = False )
72
+ assert_frame_equal (left , right , * args , ** kwargs )
73
+
67
74
def assert_annotation_frame_equal (annotation_attr : Literal ["obs" , "var" ],
68
75
left : AnnotationObjectOrPathLike , right : AnnotationObjectOrPathLike ,
69
76
sort = False , * args , ** kwargs ):
@@ -72,9 +79,7 @@ def assert_annotation_frame_equal(annotation_attr: Literal["obs", "var"],
72
79
left , right = _read_if_needed (left ), _read_if_needed (right )
73
80
_assert_same_annotation_object_class (left , right )
74
81
left_frame , right_frame = getattr (left , annotation_attr ), getattr (right , annotation_attr )
75
- if sort :
76
- left_frame , right_frame = left_frame .sort_index (inplace = False ), right_frame .sort_index (inplace = False )
77
- assert_frame_equal (left_frame , right_frame , * args , ** kwargs )
82
+ _assert_frame_equal (left_frame , right_frame , sort = sort , * args , ** kwargs )
78
83
if isinstance (left , MuData ):
79
84
assert_mudata_modality_keys_equal (left , right )
80
85
for mod_name , modality in left .mod .items ():
@@ -123,13 +128,49 @@ def assert_layers_equal(left: AnnotationObjectOrPathLike,
123
128
assert_layers_equal (modality , right [mod_name ])
124
129
125
130
131
+
132
+ def assert_multidimensional_annotation_equal (annotation_attr : Literal ["obsm" , "varm" ],
133
+ left , right , sort = False ):
134
+ if not annotation_attr in ("obsm" , "varm" ):
135
+ raise ValueError ("annotation_attr should be 'obsm', or 'varm'" )
136
+ left , right = _read_if_needed (left ), _read_if_needed (right )
137
+ _assert_same_annotation_object_class (left , right )
138
+
139
+ @singledispatch
140
+ def _assert_multidimensional_value_equal (left , right , ** kwargs ):
141
+ raise NotImplementedError ("Unregistered type found while asserting" )
142
+
143
+ @_assert_multidimensional_value_equal .register
144
+ def _ (left : pd .DataFrame , right , ** kwargs ):
145
+ _assert_frame_equal (left , right , ** kwargs )
146
+
147
+ @_assert_multidimensional_value_equal .register (np .ndarray )
148
+ @_assert_multidimensional_value_equal .register (spmatrix )
149
+ def _ (left , right , ** kwargs ):
150
+ # Cannot sort sparse and dense matrices so ignore sort param
151
+ _assert_layer_equal (left , right )
152
+
153
+ left_dict , right_dict = getattr (left , annotation_attr ), getattr (right , annotation_attr )
154
+ left_keys , right_keys = left_dict .keys (), right_dict .keys ()
155
+ assert left_keys == right_keys , f"Keys of { annotation_attr } differ:\n [left]:{ left_keys } \n [right]:{ right_keys } "
156
+ for left_key , left_value in left_dict .items ():
157
+ _assert_multidimensional_value_equal (left_value , right_dict [left_key ], sort = sort )
158
+ if isinstance (left , MuData ):
159
+ assert_mudata_modality_keys_equal (left , right )
160
+ for mod_name , modality in left .mod .items ():
161
+ assert_multidimensional_annotation_equal (annotation_attr ,modality , right [mod_name ], sort = sort )
162
+
163
+
126
164
def assert_annotation_objects_equal (left : AnnotationObjectOrPathLike ,
127
165
right : AnnotationObjectOrPathLike ,
128
- check_data = True ):
166
+ check_data = True ,
167
+ sort = True ):
129
168
left , right = _read_if_needed (left ), _read_if_needed (right )
130
169
_assert_same_annotation_object_class (left , right )
131
170
assert_shape_equal (left , right )
132
- assert_annotation_frame_equal ("obs" , left , right )
133
- assert_annotation_frame_equal ("var" , left , right )
171
+ assert_annotation_frame_equal ("obs" , left , right , sort = sort )
172
+ assert_annotation_frame_equal ("var" , left , right , sort = sort )
173
+ assert_multidimensional_annotation_equal ("varm" , left , right , sort = sort )
174
+ assert_multidimensional_annotation_equal ("obsm" , left , right , sort = sort )
134
175
if check_data :
135
- assert_layers_equal (left , right )
176
+ assert_layers_equal (left , right )
0 commit comments