@@ -34,18 +34,7 @@ def factory(cfg):
34
34
raise Exception ("Adding noise to sparse data format is yet to be supported" )
35
35
cfg .expr .iloc [:] = cfg .expr .values + np .random .normal (loc = 0 , scale = cfg .add_noise_level , size = cfg .expr .shape )
36
36
37
- # prepare training placeholders
38
- cfg .l1_lambda_placeholder = tf .compat .v1 .placeholder (tf .float32 , name = 'l1_lambda' )
39
- cfg .l2_lambda_placeholder = tf .compat .v1 .placeholder (tf .float32 , name = 'l2_lambda' )
40
- cfg .lr = tf .compat .v1 .placeholder (tf .float32 , name = 'lr' )
41
-
42
- # Prepare dataset iterators
43
- dataset = tf .data .Dataset .from_tensor_slices ((cfg .pert_in , cfg .expr_out ))
44
- cfg .iter_train = tf .compat .v1 .data .make_initializable_iterator (
45
- dataset .shuffle (buffer_size = 1024 , reshuffle_each_iteration = True ).batch (cfg .batchsize ))
46
- cfg .iter_monitor = tf .compat .v1 .data .make_initializable_iterator (
47
- dataset .repeat ().shuffle (buffer_size = 1024 , reshuffle_each_iteration = True ).batch (cfg .batchsize ))
48
- cfg .iter_eval = tf .compat .v1 .data .make_initializable_iterator (dataset .batch (cfg .batchsize ))
37
+ cfg = get_tensors (cfg )
49
38
50
39
# Data partition
51
40
if cfg .experiment_type == 'random partition' or cfg .experiment_type == 'full data' :
@@ -81,6 +70,21 @@ def factory(cfg):
81
70
return cfg
82
71
83
72
73
+ def get_tensors (cfg ):
74
+ # prepare training placeholders
75
+ cfg .l1_lambda_placeholder = tf .compat .v1 .placeholder (tf .float32 , name = 'l1_lambda' )
76
+ cfg .l2_lambda_placeholder = tf .compat .v1 .placeholder (tf .float32 , name = 'l2_lambda' )
77
+ cfg .lr = tf .compat .v1 .placeholder (tf .float32 , name = 'lr' )
78
+
79
+ # Prepare dataset iterators
80
+ dataset = tf .data .Dataset .from_tensor_slices ((cfg .pert_in , cfg .expr_out ))
81
+ cfg .iter_train = tf .compat .v1 .data .make_initializable_iterator (
82
+ dataset .shuffle (buffer_size = 1024 , reshuffle_each_iteration = True ).batch (cfg .batchsize ))
83
+ cfg .iter_monitor = tf .compat .v1 .data .make_initializable_iterator (
84
+ dataset .repeat ().shuffle (buffer_size = 1024 , reshuffle_each_iteration = True ).batch (cfg .batchsize ))
85
+ cfg .iter_eval = tf .compat .v1 .data .make_initializable_iterator (dataset .batch (cfg .batchsize ))
86
+ return cfg
87
+
84
88
def s2c (cfg ):
85
89
"""single-to-combo trials"""
86
90
double_idx = cfg .loo .all (axis = 1 )
0 commit comments