Skip to content

Commit a4c0ffb

Browse files
committed
Create no_opt factory function for synthetic data
1 parent dd3f2fa commit a4c0ffb

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

pertbio/pertbio/dataset.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,7 @@ def factory(cfg):
3434
raise Exception("Adding noise to sparse data format is yet to be supported")
3535
cfg.expr.iloc[:] = cfg.expr.values + np.random.normal(loc=0, scale=cfg.add_noise_level, size=cfg.expr.shape)
3636

37-
# prepare training placeholders
38-
cfg.l1_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l1_lambda')
39-
cfg.l2_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l2_lambda')
40-
cfg.lr = tf.compat.v1.placeholder(tf.float32, name='lr')
41-
42-
# Prepare dataset iterators
43-
dataset = tf.data.Dataset.from_tensor_slices((cfg.pert_in, cfg.expr_out))
44-
cfg.iter_train = tf.compat.v1.data.make_initializable_iterator(
45-
dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=True).batch(cfg.batchsize))
46-
cfg.iter_monitor = tf.compat.v1.data.make_initializable_iterator(
47-
dataset.repeat().shuffle(buffer_size=1024, reshuffle_each_iteration=True).batch(cfg.batchsize))
48-
cfg.iter_eval = tf.compat.v1.data.make_initializable_iterator(dataset.batch(cfg.batchsize))
37+
cfg = get_tensors(cfg)
4938

5039
# Data partition
5140
if cfg.experiment_type == 'random partition' or cfg.experiment_type == 'full data':
@@ -81,6 +70,21 @@ def factory(cfg):
8170
return cfg
8271

8372

73+
def get_tensors(cfg):
74+
# prepare training placeholders
75+
cfg.l1_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l1_lambda')
76+
cfg.l2_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l2_lambda')
77+
cfg.lr = tf.compat.v1.placeholder(tf.float32, name='lr')
78+
79+
# Prepare dataset iterators
80+
dataset = tf.data.Dataset.from_tensor_slices((cfg.pert_in, cfg.expr_out))
81+
cfg.iter_train = tf.compat.v1.data.make_initializable_iterator(
82+
dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=True).batch(cfg.batchsize))
83+
cfg.iter_monitor = tf.compat.v1.data.make_initializable_iterator(
84+
dataset.repeat().shuffle(buffer_size=1024, reshuffle_each_iteration=True).batch(cfg.batchsize))
85+
cfg.iter_eval = tf.compat.v1.data.make_initializable_iterator(dataset.batch(cfg.batchsize))
86+
return cfg
87+
8488
def s2c(cfg):
8589
"""single-to-combo trials"""
8690
double_idx = cfg.loo.all(axis=1)

0 commit comments

Comments
 (0)