Commit fdbde57

Add suzuki method
1 parent 11beedc commit fdbde57


3 files changed: +872 −0 lines changed

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
__merge__: ../../api/comp_method.yaml
name: suzuki_mlp
label: Suzuki MLP
summary: Shuji Suzuki's winning neural network solution for multimodal single-cell integration.
description: |
  A hierarchical neural network encoder-decoder model based on Shuji Suzuki's 1st place solution
  in the Open Problems Multimodal Single-Cell Integration competition. The model uses task-specific
  preprocessing, SVD dimensionality reduction, and hierarchical MLP blocks with residual connections
  for learning cross-modal mappings.
references:
  doi: 10.1038/s41592-022-01652-z
info:
  documentation_url: https://github.com/shu65/open-problems-multimodal
  repository_url: https://github.com/shu65/open-problems-multimodal
  preferred_normalization: log_cp10k
arguments:
  # Task configuration
  - name: "--task_type"
    type: "string"
    default: "auto"
    description: Task type - 'auto' for automatic detection, 'cite' for CITE-seq, 'multi' for multiome.

  # Preprocessing arguments
  - name: "--inputs_n_components"
    type: "integer"
    default: 128
    description: Number of SVD components for input modality dimensionality reduction.
  - name: "--targets_n_components"
    type: "integer"
    default: 128
    description: Number of SVD components for target modality dimensionality reduction.

  # Model architecture arguments
  - name: "--encoder_h_dim"
    type: "integer"
    default: 512
    description: Hidden dimension size for the encoder.
  - name: "--decoder_h_dim"
    type: "integer"
    default: 512
    description: Hidden dimension size for the decoder.
  - name: "--n_encoder_block"
    type: "integer"
    default: 3
    description: Number of encoder blocks.
  - name: "--n_decoder_block"
    type: "integer"
    default: 3
    description: Number of decoder blocks.
  - name: "--dropout_p"
    type: "double"
    default: 0.1
    description: Dropout probability.
  - name: "--activation"
    type: "string"
    default: "relu"
    description: Activation function (relu or gelu).
  - name: "--norm"
    type: "string"
    default: "layer_norm"
    description: Normalization type (layer_norm or batch_norm).
  - name: "--use_skip_connections"
    type: "boolean"
    default: true
    description: Whether to use skip connections in blocks.

  # Training arguments
  - name: "--learning_rate"
    type: "double"
    default: 1e-4
    description: Learning rate for training.
  - name: "--weight_decay"
    type: "double"
    default: 1e-6
    description: Weight decay for regularization.
  - name: "--epochs"
    type: "integer"
    default: 40
    description: Number of training epochs.
  - name: "--batch_size"
    type: "integer"
    default: 64
    description: Batch size for training.
  - name: "--use_residual_connections"
    type: "boolean"
    default: true
    description: Whether to use residual connections for the multiome task.

resources:
  - type: python_script
    path: script.py
  - path: utils.py
engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [hightime, highmem, midcpu, gpu]
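
Note: the config's description refers to hierarchical MLP blocks with residual connections, parameterized above by h_dim, dropout_p, activation, norm, and use_skip_connections. The actual block classes (MLPBModule, HierarchicalMLPBModule) live in utils.py, which is not shown in this excerpt. The snippet below is a minimal illustrative sketch of such a block under those assumptions; the class name ResidualMLPBlock is hypothetical and not the repository's implementation.

import torch
import torch.nn as nn

class ResidualMLPBlock(nn.Module):
    """One MLP block of the kind the config parameterizes:
    Linear -> norm -> activation -> dropout, with an optional skip connection."""

    def __init__(self, h_dim: int, dropout_p: float = 0.1,
                 activation: str = "relu", norm: str = "layer_norm", skip: bool = True):
        super().__init__()
        act = nn.ReLU() if activation == "relu" else nn.GELU()
        norm_layer = nn.LayerNorm(h_dim) if norm == "layer_norm" else nn.BatchNorm1d(h_dim)
        self.skip = skip
        self.body = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            norm_layer,
            act,
            nn.Dropout(dropout_p),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.body(x)
        # The skip connection adds the block input back to its output,
        # which helps gradients flow through deep stacks of blocks.
        return x + out if self.skip else out

if __name__ == "__main__":
    block = ResidualMLPBlock(h_dim=512, dropout_p=0.1, activation="relu", norm="layer_norm")
    x = torch.randn(8, 512)
    print(block(x).shape)  # torch.Size([8, 512])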

src/methods/suzuki_mlp/script.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
import sys
import logging
import anndata as ad
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.decomposition import TruncatedSVD, PCA
from sklearn.preprocessing import StandardScaler
from scipy import sparse
import gc
import warnings
warnings.filterwarnings('ignore')

## VIASH START
par = {
    'input_train_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/train_mod1.h5ad',
    'input_train_mod2': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/train_mod2.h5ad',
    'input_test_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/test_mod1.h5ad',
    'output': 'output.h5ad',
    'task_type': 'auto',
    'inputs_n_components': 128,
    'targets_n_components': 128,
    'encoder_h_dim': 512,
    'decoder_h_dim': 512,
    'n_encoder_block': 3,
    'n_decoder_block': 3,
    'dropout_p': 0.1,
    'activation': 'relu',
    'norm': 'layer_norm',
    'use_skip_connections': True,
    'learning_rate': 0.0001,
    'weight_decay': 0.000001,
    'epochs': 40,
    'batch_size': 64,
    'use_residual_connections': True,
}
meta = {
    'name': 'suzuki_mlp'
}
## VIASH END

# Import utils functions
import os
sys.path.append(meta["resources_dir"])

from utils import (
    determine_task_type, preprocess_data, train_model,
    MLPBModule, HierarchicalMLPBModule, SuzukiEncoderDecoderModule
)

def main():
    # Enable logging
    logging.basicConfig(level=logging.INFO)

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}", flush=True)

    # Read input files
    print("Reading input files", flush=True)
    adata_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
    adata_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
    adata_test_mod1 = ad.read_h5ad(par['input_test_mod1'])

    # Determine task type
    if par['task_type'] == 'auto':
        task_type = determine_task_type(adata_train_mod1, adata_train_mod2)
        print(f"Auto-detected task type: {task_type}", flush=True)
    else:
        task_type = par['task_type']

    print(f"Task type: {task_type}", flush=True)
    print(f"Modality 1: {adata_train_mod1.uns.get('modality', 'Unknown')}, n_features: {adata_train_mod1.n_vars}")
    print(f"Modality 2: {adata_train_mod2.uns.get('modality', 'Unknown')}, n_features: {adata_train_mod2.n_vars}")

    # Preprocess data
    print("Preprocessing data", flush=True)
    data = preprocess_data(
        adata_train_mod1=adata_train_mod1,
        adata_train_mod2=adata_train_mod2,
        adata_test_mod1=adata_test_mod1,
        task_type=task_type,
        inputs_n_components=par['inputs_n_components'],
        targets_n_components=par['targets_n_components']
    )

    X_train = data['X_train']
    y_train = data['y_train']
    X_test = data['X_test']
    metadata_train = data['metadata_train']
    metadata_test = data['metadata_test']
    targets_decomposer_components = data['targets_decomposer_components']
    targets_global_median = data['targets_global_median']
    y_statistic = data['y_statistic']

    print(f"Training data shape: X={X_train.shape}, y={y_train.shape}")
    print(f"Test data shape: X={X_test.shape}")

    # Build model
    print("Building model", flush=True)
    input_dim = X_train.shape[1]
    output_dim = y_train.shape[1]

    # Create encoder
    encoder = MLPBModule(
        input_dim=None,  # Will be set in the main module
        output_dim=par['encoder_h_dim'],
        n_block=par['n_encoder_block'],
        h_dim=par['encoder_h_dim'],
        skip=par['use_skip_connections'],
        dropout_p=par['dropout_p'],
        activation=par['activation'],
        norm="layer_norm"
    )

    # Create hierarchical decoder
    decoder = HierarchicalMLPBModule(
        input_dim=par['encoder_h_dim'],
        output_dim=None,  # Will create multiple outputs
        n_block=par['n_decoder_block'],
        h_dim=par['decoder_h_dim'],
        skip=par['use_skip_connections'],
        dropout_p=par['dropout_p'],
        activation=par['activation'],
        norm="layer_norm"
    )

    # Create main model
    model = SuzukiEncoderDecoderModule(
        x_dim=input_dim,
        y_dim=output_dim,
        y_statistic=y_statistic,
        encoder_h_dim=par['encoder_h_dim'],
        decoder_h_dim=par['decoder_h_dim'],
        n_decoder_block=par['n_decoder_block'],
        targets_decomposer_components=targets_decomposer_components,
        targets_global_median=targets_global_median,
        encoder=encoder,
        decoder=decoder,
        task_type=task_type,
        use_residual_connections=par['use_residual_connections']
    ).to(device)

    # Train model
    print("Training model", flush=True)
    trained_model = train_model(
        model=model,
        X_train=X_train,
        y_train=y_train,
        metadata_train=metadata_train,
        device=device,
        lr=par['learning_rate'],
        weight_decay=par['weight_decay'],
        epochs=par['epochs'],
        batch_size=par['batch_size'],
        task_type=task_type
    )

    # Predict on test data
    print("Predicting on test data", flush=True)
    trained_model.eval()
    predictions = []

    with torch.no_grad():
        # Handle metadata safely for test data
        if 'gender' in metadata_test.columns:
            gender_values = metadata_test['gender'].values
            if gender_values.dtype == object:
                # pd.to_numeric returns an ndarray for array input, so wrap in a Series to use fillna
                gender_values = pd.to_numeric(pd.Series(gender_values), errors='coerce').fillna(0).astype(int).to_numpy()
            gender_test = torch.LongTensor(gender_values)
        else:
            gender_test = torch.LongTensor(np.zeros(len(X_test), dtype=int))

        info_test = torch.FloatTensor(np.zeros((len(X_test), 1)))

        test_dataset = torch.utils.data.TensorDataset(
            torch.FloatTensor(X_test),
            gender_test,
            info_test
        )
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=par['batch_size'], shuffle=False)

        for batch_x, batch_gender, batch_info in test_loader:
            batch_x = batch_x.to(device)
            batch_gender = batch_gender.to(device)
            batch_info = batch_info.to(device)

            pred = trained_model.predict(batch_x, batch_gender, batch_info)
            predictions.append(pred.cpu().numpy())

    y_pred = np.vstack(predictions)

    # Create output AnnData object
    print("Creating output", flush=True)
    adata_pred = ad.AnnData(
        obs=adata_test_mod1.obs.copy(),
        var=adata_train_mod2.var.copy(),
        layers={
            'normalized': y_pred
        },
        uns={
            'dataset_id': adata_train_mod1.uns.get('dataset_id', 'unknown'),
            'method_id': meta['name']
        }
    )

    # Write output
    print("Writing output to file", flush=True)
    adata_pred.write_h5ad(par['output'], compression='gzip')

    print("Done!", flush=True)

if __name__ == '__main__':
    main()
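
Note: the script delegates auto-detection of the task type to determine_task_type in utils.py, which is not part of this excerpt. The sketch below is only an illustration of what such detection might look like, assuming the modality label is stored in .uns['modality'] (as the script's own logging suggests); the function name and the feature-count fallback threshold are hypothetical, not the repository's implementation.

import anndata as ad

def determine_task_type_sketch(adata_mod1: ad.AnnData, adata_mod2: ad.AnnData) -> str:
    """Illustrative detection: CITE-seq pairs GEX with ADT (protein),
    multiome pairs GEX with ATAC."""
    modalities = {
        str(adata_mod1.uns.get("modality", "")).upper(),
        str(adata_mod2.uns.get("modality", "")).upper(),
    }
    if "ADT" in modalities:
        return "cite"
    if "ATAC" in modalities:
        return "multi"
    # Fallback heuristic (assumption): ADT panels are small, ATAC peak sets are large.
    n_vars = (adata_mod1.n_vars, adata_mod2.n_vars)
    return "cite" if min(n_vars) < 1000 else "multi"

if __name__ == "__main__":
    import numpy as np
    a1 = ad.AnnData(np.zeros((10, 2000)), uns={"modality": "GEX"})
    a2 = ad.AnnData(np.zeros((10, 140)), uns={"modality": "ADT"})
    print(determine_task_type_sketch(a1, a2))  # cite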
