Commit dd3f2fa

fix and support sparse matrix for s2c and loo
1 parent b21db22 commit dd3f2fa

File tree

1 file changed: +43 -19 lines


pertbio/pertbio/dataset.py

Lines changed: 43 additions & 19 deletions
@@ -91,22 +91,33 @@ def s2c(cfg):
     ntrain = int(nvalid * cfg.validset_ratio)

     valid_pos = np.random.choice(range(nvalid), nvalid, replace=False)
-    pert_train = cfg.pert[~testidx]
-    train_data = cfg.expr[~testidx]
     dataset = {
         "node_index": cfg.node_index,
-        "pert_train": pert_train.iloc[valid_pos[:ntrain], :].values,
-        "pert_valid": pert_train.iloc[valid_pos[ntrain:], :].values,
-        "pert_test": cfg.pert[testidx],
         "pert_full": cfg.pert,
-        "train_data": train_data.iloc[valid_pos[:ntrain], :].values,
-        "valid_data": train_data.iloc[valid_pos[ntrain:], :].values,
-        "test_data": cfg.expr[testidx],
-        "train_pos": valid_pos[:ntrain].values,
-        "valid_pos": valid_pos[ntrain:].values,
+        "train_pos": valid_pos[:ntrain],
+        "valid_pos": valid_pos[ntrain:],
         "test_pos": testidx
     }

+    if cfg.sparse_data:
+        dataset.update({
+            "pert_train": npz_to_feedable_arrays(cfg.pert[~testidx][valid_pos[:ntrain]]),
+            "pert_valid": npz_to_feedable_arrays(cfg.pert[~testidx][valid_pos[ntrain:]]),
+            "pert_test": npz_to_feedable_arrays(cfg.pert[testidx]),
+            "expr_train": npz_to_feedable_arrays(cfg.expr[~testidx][valid_pos[:ntrain]]),
+            "expr_valid": npz_to_feedable_arrays(cfg.expr[~testidx][valid_pos[ntrain:]]),
+            "expr_test": npz_to_feedable_arrays(cfg.expr[testidx])
+        })
+    else:
+        dataset.update({
+            "pert_train": cfg.pert[~testidx].iloc[valid_pos[:ntrain], :].values,
+            "pert_valid": cfg.pert[~testidx].iloc[valid_pos[ntrain:], :].values,
+            "pert_test": cfg.pert[testidx],
+            "expr_train": cfg.expr[~testidx].iloc[valid_pos[:ntrain], :].values,
+            "expr_valid": cfg.expr[~testidx].iloc[valid_pos[ntrain:], :].values,
+            "expr_test": cfg.expr[testidx]
+        })
+
     # TODO: class Dataset of Sample instances

     return dataset
@@ -128,20 +139,33 @@ def loo(cfg, singles):
     ntrain = int(nvalid * cfg.validset_ratio)

     valid_pos = np.random.choice(range(nvalid), nvalid, replace=False)
-    pert_train = cfg.pert[~testidx]
-    train_data = cfg.expr[~testidx]
-
     dataset = {
         "node_index": cfg.node_index,
-        "pert_train": pert_train.iloc[valid_pos[:ntrain], :],
-        "pert_valid": pert_train.iloc[valid_pos[ntrain:], :],
-        "pert_test": cfg.pert[testidx],
         "pert_full": cfg.pert,
-        "train_data": train_data.iloc[valid_pos[:ntrain], :],
-        "valid_data": train_data.iloc[valid_pos[ntrain:], :],
-        "test_data": cfg.expr[testidx]
+        "train_pos": valid_pos[:ntrain],
+        "valid_pos": valid_pos[ntrain:],
+        "test_pos": testidx
     }

+    if cfg.sparse_data:
+        dataset.update({
+            "pert_train": npz_to_feedable_arrays(cfg.pert[~testidx][valid_pos[:ntrain]]),
+            "pert_valid": npz_to_feedable_arrays(cfg.pert[~testidx][valid_pos[ntrain:]]),
+            "pert_test": npz_to_feedable_arrays(cfg.pert[testidx]),
+            "expr_train": npz_to_feedable_arrays(cfg.expr[~testidx][valid_pos[:ntrain]]),
+            "expr_valid": npz_to_feedable_arrays(cfg.expr[~testidx][valid_pos[ntrain:]]),
+            "expr_test": npz_to_feedable_arrays(cfg.expr[testidx])
+        })
+    else:
+        dataset.update({
+            "pert_train": cfg.pert[~testidx].iloc[valid_pos[:ntrain], :].values,
+            "pert_valid": cfg.pert[~testidx].iloc[valid_pos[ntrain:], :].values,
+            "pert_test": cfg.pert[testidx],
+            "expr_train": cfg.expr[~testidx].iloc[valid_pos[:ntrain], :].values,
+            "expr_valid": cfg.expr[~testidx].iloc[valid_pos[ntrain:], :].values,
+            "expr_test": cfg.expr[testidx]
+        })
+
     return dataset

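The helper npz_to_feedable_arrays is called in both branches of the diff, but its definition is not part of this commit. Below is a minimal sketch of what such a helper could look like, assuming the sparse pert/expr matrices are scipy.sparse arrays loaded from .npz files and the model feeds TensorFlow-style sparse placeholders with (indices, values, dense_shape) arrays; only the helper's name comes from the diff, everything else is a hypothetical illustration. (The "fix" half of the commit message appears to be the removed .values calls on train_pos/valid_pos: valid_pos comes from np.random.choice and is already a NumPy array, so it has no .values attribute.)

import numpy as np
import scipy.sparse as sp

def npz_to_feedable_arrays_sketch(matrix):
    """Hypothetical stand-in for npz_to_feedable_arrays (not the real implementation).

    Converts a scipy.sparse matrix into the (indices, values, dense_shape)
    arrays that a tf.sparse_placeholder feed_dict typically expects.
    """
    coo = sp.coo_matrix(matrix)                     # normalize CSR/CSC/COO input to COO
    indices = np.stack([coo.row, coo.col], axis=1)  # (nnz, 2) row/col index pairs
    values = coo.data                               # (nnz,) nonzero entries
    dense_shape = np.array(coo.shape)               # (2,) full matrix shape
    return indices, values, dense_shape

# Toy usage: a 3x4 matrix with two nonzero entries
m = sp.csr_matrix(([1.0, 2.5], ([0, 2], [1, 3])), shape=(3, 4))
idx, vals, shape = npz_to_feedable_arrays_sketch(m)
# idx -> [[0 1], [2 3]], vals -> [1.0 2.5], shape -> [3 4]

Whatever the real helper returns, both branches populate the same keys (pert_train/valid/test, expr_train/valid/test), so downstream code only has to branch on cfg.sparse_data when building the feed, not when looking up the dataset dict. Note also that the sparse branch selects rows positionally (cfg.pert[~testidx][valid_pos[:ntrain]]), which scipy CSR matrices support, while the dense branch keeps the pandas .iloc indexing of the old code.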