-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
351 lines (286 loc) · 13.3 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# Initial setup for stratified splitting of AnnData objects and verification of the splits. Needs to be run before training the models.
import numpy as np
from pathlib import Path
import anndata
import logging
from typing import List
import pandas as pd
from Code.config import config
# Configure logging at import time so that INFO-level (and more severe) records are emitted
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; shared by all code in this file
logger = logging.getLogger(__name__)
class InitialSetup:
    """
    Prepare the train/evaluation split of an AnnData dataset and verify it.

    The class reads its parameters from the external ``config`` object, carves a
    stratified evaluation subset out of ``fly_original.h5ad``, writes the result
    as ``fly_train.h5ad`` and ``fly_eval.h5ad`` next to the original file, and
    then checks that the two files do not overlap and that their class
    proportions stay close to those of the original dataset.
    """

    def __init__(self):
        """
        Bind the external configuration to the instance.

        ``main`` reads the following attributes from ``config.Setup``:
            - strata: str, the column to use for stratification.
            - tissue: str, the tissue type (e.g., 'head', 'body').
            - seed: int, random seed for reproducibility.
            - use_batch_corrected_data: bool, whether to use batch-corrected data.
            - split_size: int, number of samples to split for evaluation.
        """
        # Use external configuration
        self.config_instance = config
        logger.info("InitialSetup initialized with external configuration.")

    def _resolve_data_path(self, tissue: str, batch: bool) -> Path:
        """
        Resolve the path to the data directory based on tissue type and batch correction.

        The directory is built relative to this file's location:
        ``<parent of this file's directory>/TimeFlies/Data/h5ad/<tissue>/<variant>``
        where ``<variant>`` is ``batch_corrected`` or ``uncorrected``.

        Args:
            tissue (str): The tissue type (e.g., 'head', 'body').
            batch (bool): Whether to use batch-corrected data.

        Returns:
            Path: The resolved data directory path (its existence is not checked).
        """
        # Directory that contains this script
        code_dir = Path(__file__).resolve().parent
        variant = "batch_corrected" if batch else "uncorrected"
        data_dir = code_dir.parent / "TimeFlies" / "Data" / "h5ad" / tissue / variant
        logger.debug(f"Resolved data directory: {data_dir}")
        return data_dir

    def _load_anndata(self, file_path: Path) -> anndata.AnnData:
        """
        Load an AnnData object from a given file path.

        Args:
            file_path (Path): Path to the .h5ad file.

        Returns:
            anndata.AnnData: The loaded AnnData object.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not file_path.exists():
            logger.error(f"File not found: {file_path}")
            raise FileNotFoundError(f"File not found: {file_path}")
        # read_h5ad (non-backed) already returns a fresh in-memory object that
        # nobody else references, so an extra .copy() would only double the
        # peak memory use without protecting anything.
        adata = anndata.read_h5ad(file_path)
        logger.info(f"Loaded data from {file_path}")
        return adata

    def _select_samples(
        self, strata: pd.Series, split_size: int, random_seed: int
    ) -> List[str]:
        """
        Select samples for the evaluation set based on stratification.

        Each class receives ``round(class_proportion * split_size)`` samples,
        drawn without replacement. Because of the per-class rounding (and
        classes that are too small), the total may differ from ``split_size``;
        a warning is logged in that case.

        Args:
            strata (pd.Series): The stratification column, indexed by sample name.
            split_size (int): Number of samples to split for evaluation.
            random_seed (int): Seed for reproducibility.

        Returns:
            List[str]: List of selected sample indices.

        Raises:
            ValueError: If ``split_size`` is negative.
        """
        if split_size < 0:
            raise ValueError(f"split_size must be non-negative, got {split_size}.")

        # A local legacy generator yields exactly the same draws as
        # np.random.seed(random_seed) followed by np.random.choice, but does
        # not reseed NumPy's global random state as a side effect.
        rng = np.random.RandomState(random_seed)

        # Calculate the proportion of each class
        proportions = strata.value_counts(normalize=True)
        logger.debug(f"Stratification proportions:\n{proportions}")

        # Determine the number of samples to take from each class
        n_samples = (proportions * split_size).round().astype(int)
        logger.debug(f"Number of samples per class:\n{n_samples}")

        selected_samples: List[str] = []
        for category, n in n_samples.items():
            samples = strata.index[strata == category]
            if len(samples) < n:
                logger.warning(
                    f"Not enough samples for category '{category}'. "
                    f"Requested: {n}, available: {len(samples)}. Selecting all available samples."
                )
                n = len(samples)  # Adjust to available samples
            # The draw is performed even when n == 0 so that the random stream
            # (and therefore the selection for later classes) stays the same.
            selected = rng.choice(samples, size=n, replace=False)
            # tolist() converts NumPy scalars into plain Python objects
            selected_samples.extend(selected.tolist())
            logger.debug(f"Selected {n} samples for category '{category}'.")

        if len(selected_samples) != split_size:
            logger.warning(
                f"Selected {len(selected_samples)} samples although split_size={split_size} "
                "was requested (per-class rounding or limited class sizes)."
            )
        logger.info(f"Total selected samples for evaluation: {len(selected_samples)}")
        return selected_samples

    def stratified_split_and_save(
        self,
        strata_column: str,
        tissue: str,
        random_seed: int,
        batch: bool,
        split_size: int,
    ) -> None:
        """
        Perform a stratified split on an AnnData object and save the resulting training and evaluation datasets.

        Reads ``fly_original.h5ad`` from the resolved data directory and writes
        ``fly_train.h5ad`` and ``fly_eval.h5ad`` into the same directory.

        Args:
            strata_column (str): The column to use for stratification.
            tissue (str): Tissue type to define the directory for saving.
            random_seed (int): Seed for random number generator.
            batch (bool): Whether to use batch-corrected data.
            split_size (int): Number of samples to split for evaluation.

        Returns:
            None

        Raises:
            FileNotFoundError: If the original .h5ad file is missing.
            KeyError: If ``strata_column`` is not a column of ``adata.obs``.
        """
        logger.info(f"Starting stratified split for tissue '{tissue}', batch={batch}.")

        data_dir = self._resolve_data_path(tissue, batch)
        h5ad_file_path = data_dir / "fly_original.h5ad"
        adata = self._load_anndata(h5ad_file_path)

        # Fail with an explanatory message instead of a bare KeyError
        if strata_column not in adata.obs.columns:
            raise KeyError(
                f"Stratification column '{strata_column}' not found in the observations of {h5ad_file_path}."
            )
        strata = adata.obs[strata_column]
        logger.debug(
            f"Stratification column '{strata_column}' loaded with {strata.nunique()} unique classes."
        )

        # Select samples for evaluation set based on stratification
        selected_samples = self._select_samples(strata, split_size, random_seed)

        # Evaluation set = selected samples; training set = everything else
        adata_eval = adata[selected_samples, :].copy()
        adata_train = adata[~adata.obs_names.isin(selected_samples), :].copy()
        logger.info(f"Created evaluation dataset with {adata_eval.n_obs} samples.")
        logger.info(f"Created training dataset with {adata_train.n_obs} samples.")
        if adata_train.n_obs == 0:
            logger.warning(
                "Training dataset is empty after the split; check split_size against the dataset size."
            )

        # Save the split datasets next to the original file
        fly_train_path = data_dir / "fly_train.h5ad"
        fly_eval_path = data_dir / "fly_eval.h5ad"
        adata_train.write_h5ad(fly_train_path)
        adata_eval.write_h5ad(fly_eval_path)
        logger.info(f"Saved training data to {fly_train_path}")
        logger.info(f"Saved evaluation data to {fly_eval_path}")

    @staticmethod
    def _check_proportions(
        original: pd.Series, split: pd.Series, split_name: str, tolerance: float
    ) -> bool:
        """Return True if every stratum's proportion in ``split`` is within ``tolerance`` of ``original``."""
        within_tolerance = True
        for stratum, orig_prop in original.items():
            # A stratum that is absent from the split counts as proportion 0
            split_prop = split.get(stratum, 0)
            difference = abs(orig_prop - split_prop)
            if difference > tolerance:
                # Four decimals so that e.g. 0.051 is not displayed as "0.05"
                logger.error(
                    f"Stratification proportion for '{stratum}' in {split_name} dataset differs by {difference:.4f}, which exceeds the tolerance of {tolerance}."
                )
                # Keep going so that every offending stratum is reported
                within_tolerance = False
        return within_tolerance

    def verify(
        self,
        tissue: str,
        batch: bool,
        strata_column: str,
        tolerance: float = 0.05,
    ) -> bool:
        """
        Verify that the training and evaluation datasets independently maintain the same stratification proportions as the original dataset.

        The check fails if any of the three files is missing, if the original
        dataset is empty, if evaluation and training share sample names, or if
        any stratum's proportion deviates from the original by more than
        ``tolerance`` in either split.

        Args:
            tissue (str): Tissue type.
            batch (bool): Whether to use batch-corrected data.
            strata_column (str): Column used for stratification.
            tolerance (float): Maximum allowed absolute difference between a
                stratum's proportion in a split and in the original dataset.
                Defaults to 0.05 (5 percentage points).

        Returns:
            bool: True if verification passes for both datasets, False otherwise.
        """
        logger.info(f"Starting verification for tissue '{tissue}', batch={batch}.")

        data_dir = self._resolve_data_path(tissue, batch)
        fly_eval_path = data_dir / "fly_eval.h5ad"
        fly_train_path = data_dir / "fly_train.h5ad"
        original_path = data_dir / "fly_original.h5ad"

        # Load the AnnData objects
        try:
            eval_adata = self._load_anndata(fly_eval_path)
            train_adata = self._load_anndata(fly_train_path)
            original_adata = self._load_anndata(original_path)
        except FileNotFoundError:
            logger.error("One or more required files are missing for verification.")
            return False

        # Verify no overlap between evaluation and training datasets
        overlap = set(eval_adata.obs_names).intersection(train_adata.obs_names)
        if overlap:
            logger.error(
                f"Overlap detected between eval and train datasets for tissue '{tissue}', batch={batch}. "
                f"Overlapping samples: {overlap}"
            )
            return False
        logger.info(
            f"No overlap between eval and train datasets for tissue '{tissue}', batch={batch}."
        )

        # Get counts for each stratum in the original, training, and evaluation datasets
        original_counts = original_adata.obs[strata_column].value_counts().sort_index()
        train_counts = train_adata.obs[strata_column].value_counts().sort_index()
        eval_counts = eval_adata.obs[strata_column].value_counts().sort_index()
        logger.debug(f"Original strata counts:\n{original_counts}")
        logger.debug(f"Training strata counts:\n{train_counts}")
        logger.debug(f"Evaluation strata counts:\n{eval_counts}")

        total_original = original_counts.sum()
        total_train = train_counts.sum()
        total_eval = eval_counts.sum()

        # With an empty original every proportion would be NaN, and NaN never
        # exceeds the tolerance, so the check would pass vacuously.
        if total_original == 0:
            logger.error(
                f"Original dataset has no samples in column '{strata_column}'; cannot verify stratification."
            )
            return False

        # Calculate proportions; an explicit dtype avoids pandas' warning about
        # the default dtype of an empty Series.
        original_proportions = original_counts / total_original
        train_proportions = (
            train_counts / total_train if total_train > 0 else pd.Series(dtype=float)
        )
        eval_proportions = (
            eval_counts / total_eval if total_eval > 0 else pd.Series(dtype=float)
        )
        logger.debug(f"Original strata proportions:\n{original_proportions}")
        logger.debug(f"Training strata proportions:\n{train_proportions}")
        logger.debug(f"Evaluation strata proportions:\n{eval_proportions}")

        # Both checks always run so that problems in either split get logged
        train_check = self._check_proportions(
            original_proportions, train_proportions, "training", tolerance
        )
        eval_check = self._check_proportions(
            original_proportions, eval_proportions, "evaluation", tolerance
        )

        if train_check and eval_check:
            logger.info(
                f"Stratification integrity maintained for both training and evaluation datasets for tissue '{tissue}', batch={batch}."
            )
            return True
        logger.error(
            f"Stratification integrity check failed for tissue '{tissue}', batch={batch}."
        )
        return False

    def main(self):
        """
        Execute the stratified split and verification based on the external configuration.

        The method performs the following steps:
            1. Extracts parameters from ``config.Setup``.
            2. Performs the stratified split and saves the datasets.
            3. Verifies the integrity of the split datasets.

        Errors in step 2 are logged and abort the run before verification;
        errors in step 3 are logged. Nothing is returned or re-raised.
        """
        logger.info("Starting main execution of InitialSetup.")

        # Extract split parameters from the configuration
        setup_cfg = self.config_instance.Setup
        strata_column = setup_cfg.strata
        tissue = setup_cfg.tissue
        seed = setup_cfg.seed
        batch = setup_cfg.use_batch_corrected_data
        split_size = setup_cfg.split_size
        logger.info(
            f"Configuration - Strata: {strata_column}, Tissue: {tissue}, Seed: {seed}, Batch: {batch}, Split Size: {split_size}"
        )

        # Perform the stratified split and save the datasets
        try:
            self.stratified_split_and_save(
                strata_column=strata_column,
                tissue=tissue,
                random_seed=seed,
                batch=batch,
                split_size=split_size,
            )
            logger.info("Stratified split and save completed successfully.")
        except Exception as e:
            # Top-level boundary: log the traceback and stop, since there is
            # nothing to verify without a successful split.
            logger.exception(f"Error during stratified split and save: {e}")
            return

        # Perform verification
        try:
            verification_result = self.verify(
                tissue=tissue, batch=batch, strata_column=strata_column
            )
            if verification_result:
                logger.info("Verification successful.")
            else:
                logger.error("Verification failed.")
        except Exception as e:
            logger.exception(f"Error during verification: {e}")

        logger.info("Main execution of InitialSetup completed.")
if __name__ == "__main__":
    # Script entry point: build the helper and run the split followed by verification.
    InitialSetup().main()