Skip to content

Commit d81fa40

Browse files
changes for whoapsi talk'
1 parent 1f60fd7 commit d81fa40

File tree

7 files changed

+493
-27
lines changed

7 files changed

+493
-27
lines changed

snigdha/batch_fdr.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from copy import copy
2+
from argparse import Namespace
3+
4+
import numpy as np, pandas as pd
5+
6+
from compare_fdr import main as main_fdr
7+
from compare_intervals import main as main_intervals
8+
from compare_estimators import main as main_estimators
9+
10+
11+
base_opts = Namespace(all_methods=False,
12+
all_methods_noR=False,
13+
concat=False,
14+
cor_thresh=0.5,
15+
csvfile='estimators.csv',
16+
htmlfile='estimators.html',
17+
instance='AR_instance',
18+
level=0.2,
19+
list_instances=False,
20+
list_methods=False,
21+
methods=['liu_CV', 'lee_1se', 'randomized_lasso_half_1se', 'data_splitting_1se', 'randomized_BH'],
22+
n=3000,
23+
nsim=100,
24+
p=1000,
25+
rho=[0., 0.25, 0.5, 0.75][::-1],
26+
s=30,
27+
signal=[3.5],
28+
snr=1.,
29+
use_BH=True,
30+
verbose=True,
31+
wide_only=False)
32+
33+
# BH results
34+
35+
BH_opts = copy(base_opts)
36+
BH_opts.csvfile = 'fdr_BH.csv'
37+
BH_opts.htmlfile='fdr_BH.html'
38+
BH_opts.use_BH = True
39+
main_fdr(BH_opts)
40+
41+
# # estimator results
42+
43+
# estimator_opts = copy(base_opts)
44+
# estimator_opts.csvfile = 'estimation.csv'
45+
# estimator_opts.htmlfile='estimation.html'
46+
# estimator_opts.use_BH = True
47+
# main_estimators(estimator_opts)
48+
49+
# # interval results
50+
51+
# interval_opts = copy(base_opts)
52+
# interval_opts.csvfile = 'intervals.csv'
53+
# interval_opts.htmlfile='intervals.html'
54+
# interval_opts.use_BH = True
55+
# main_intervals(interval_opts)
56+
57+
# intervals = pd.read_csv('intervals_summary.csv')
58+
# estimation = pd.read_csv('estimation_summary.csv')
59+
# marginal = pd.read_csv('fdr_marginal_summary.csv')
60+
# BH = pd.read_csv('fdr_BH_summary.csv')
61+
62+
# half1 = pd.merge(marginal, BH, left_on=['snr', 'class_name'], right_on=['snr', 'class_name'],
63+
# suffixes=(' (marginal)', ' (BH)'))
64+
# half2 = pd.merge(estimation, intervals, left_on=['snr', 'class_name'], right_on=['snr', 'class_name'],
65+
# suffixes=(' (estimation)', ' (intervals)'))
66+
# full = pd.merge(half1, half2, left_on=['snr', 'class_name'], right_on=['snr', 'class_name'])
67+
68+
# full.to_csv('full_data.csv', index=False)
69+
70+
# columns_for_plots = pd.DataFrame({'median_length':full['Median Length'],
71+
# 'marginal_power':full['Full Model Power (marginal)'],
72+
# 'BH_power':full['Full Model Power (BH)'],
73+
# 'marginal_FDP':full['Full Model FDR (marginal)'],
74+
# 'BH_FDP':full['Full Model FDR (BH)'],
75+
# 'coverage':full['Coverage'],
76+
# 'snr':full['snr'],
77+
# 'class_name':full['class_name']})
78+
# columns_for_plots.to_csv('plotting_data.csv', index=False)

snigdha/batch_intervals.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from copy import copy
2+
from argparse import Namespace
3+
4+
import numpy as np, pandas as pd
5+
6+
from compare_fdr import main as main_fdr
7+
from compare_intervals import main as main_intervals
8+
from compare_estimators import main as main_estimators
9+
10+
11+
base_opts = Namespace(all_methods=False,
12+
all_methods_noR=False,
13+
concat=False,
14+
cor_thresh=0.5,
15+
csvfile='estimators.csv',
16+
htmlfile='estimators.html',
17+
instance='AR_instance',
18+
level=0.2,
19+
list_instances=False,
20+
list_methods=False,
21+
methods=['liu_CV', 'lee_1se', 'randomized_lasso_half_1se', 'data_splitting_1se', 'randomized_BH'],
22+
n=3000,
23+
nsim=100,
24+
p=1000,
25+
rho=[0., 0.25, 0.5, 0.75][::-1],
26+
s=30,
27+
signal=[3.5],
28+
snr=1.,
29+
use_BH=True,
30+
verbose=True,
31+
wide_only=False,
32+
confidence=0.9)
33+
34+
# # BH results
35+
36+
# BH_opts = copy(base_opts)
37+
# BH_opts.csvfile = 'fdr_BH.csv'
38+
# BH_opts.htmlfile='fdr_BH.html'
39+
# BH_opts.use_BH = True
40+
# main_fdr(BH_opts)
41+
42+
# # estimator results
43+
44+
# estimator_opts = copy(base_opts)
45+
# estimator_opts.csvfile = 'estimation.csv'
46+
# estimator_opts.htmlfile='estimation.html'
47+
# estimator_opts.use_BH = True
48+
# main_estimators(estimator_opts)
49+
50+
# interval results
51+
52+
interval_opts = copy(base_opts)
53+
interval_opts.csvfile = 'intervals.csv'
54+
interval_opts.htmlfile='intervals.html'
55+
interval_opts.use_BH = True
56+
interval_opts.level = 0.1
57+
main_intervals(interval_opts)
58+
59+
# intervals = pd.read_csv('intervals_summary.csv')
60+
# estimation = pd.read_csv('estimation_summary.csv')
61+
# marginal = pd.read_csv('fdr_marginal_summary.csv')
62+
# BH = pd.read_csv('fdr_BH_summary.csv')
63+
64+
# half1 = pd.merge(marginal, BH, left_on=['snr', 'class_name'], right_on=['snr', 'class_name'],
65+
# suffixes=(' (marginal)', ' (BH)'))
66+
# half2 = pd.merge(estimation, intervals, left_on=['snr', 'class_name'], right_on=['snr', 'class_name'],
67+
# suffixes=(' (estimation)', ' (intervals)'))
68+
# full = pd.merge(half1, half2, left_on=['snr', 'class_name'], right_on=['snr', 'class_name'])
69+
70+
# full.to_csv('full_data.csv', index=False)
71+
72+
# columns_for_plots = pd.DataFrame({'median_length':full['Median Length'],
73+
# 'marginal_power':full['Full Model Power (marginal)'],
74+
# 'BH_power':full['Full Model Power (BH)'],
75+
# 'marginal_FDP':full['Full Model FDR (marginal)'],
76+
# 'BH_FDP':full['Full Model FDR (BH)'],
77+
# 'coverage':full['Coverage'],
78+
# 'snr':full['snr'],
79+
# 'class_name':full['class_name']})
80+
# columns_for_plots.to_csv('plotting_data.csv', index=False)

snigdha/compare_estimators.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def main(opts):
143143
else:
144144
snr_vals = [None]
145145

146+
if opts.signal is not None: # looping over snr strengths
147+
signal_vals = np.atleast_1d(opts.signal)
148+
else:
149+
signal_vals = [None]
150+
146151
new_opts = copy(opts)
147152
prev_rho = np.nan
148153

@@ -157,15 +162,17 @@ def main(opts):
157162
if opts.wide_only: # only allow methods that are ok if p>n
158163
new_opts.methods = [m for m in new_opts.methods if m.wide_OK]
159164

160-
for rho, snr in product(np.atleast_1d(opts.rho),
161-
snr_vals):
165+
for rho, snr, signal in product(np.atleast_1d(opts.rho),
166+
snr_vals,
167+
signal_vals):
162168

163169
# try to save some time on setup of knockoffs
164170

165171
method_setup = rho != prev_rho
166172
prev_rho = rho
167173

168174
new_opts.snr = snr
175+
new_opts.signal = signal
169176
new_opts.rho = rho
170177

171178
try:
@@ -187,10 +194,19 @@ def main(opts):
187194
if snr is not None: # here is where snr_fac can be ignored
188195
instance.snr = new_opts.snr
189196

197+
if signal is not None: # here is where snr_fac can be ignored
198+
instance.signal = new_opts.signal
199+
190200
if opts.csvfile is not None:
191-
new_opts.csvfile = (os.path.splitext(opts.csvfile)[0] +
192-
"_snr%0.1f_rho%0.2f.csv" % (new_opts.snr,
193-
new_opts.rho))
201+
if snr is not None:
202+
new_opts.csvfile = (os.path.splitext(opts.csvfile)[0] +
203+
"_snr%0.1f_rho%0.2f.csv" % (new_opts.snr,
204+
new_opts.rho))
205+
elif signal is not None:
206+
new_opts.csvfile = (os.path.splitext(opts.csvfile)[0] +
207+
"_signal%0.1f_rho%0.2f.csv" % (new_opts.signal,
208+
new_opts.rho))
209+
194210
csvfiles.append(new_opts.csvfile)
195211
summaryfiles.append(new_opts.csvfile.replace('.csv', '_summary.csv'))
196212

@@ -244,6 +260,9 @@ def main(opts):
244260
parser.add_argument('--snr', type=float, nargs='+',
245261
dest='snr',
246262
help='snr strength to override instance default (default value: None)')
263+
parser.add_argument('--signal', type=float, nargs='+',
264+
dest='signal',
265+
help='signal strength to override instance default (default value: None)')
247266
parser.add_argument('--rho', nargs='+', type=float,
248267
default=0.,
249268
dest='rho',
@@ -252,9 +271,9 @@ def main(opts):
252271
help='How many repetitions?')
253272
parser.add_argument('--verbose', action='store_true',
254273
dest='verbose')
255-
parser.add_argument('--htmlfile', help='HTML file to store results for one (snr, rho). When looping over (snr, rho) this HTML file tracks the current progress.',
274+
parser.add_argument('--htmlfile', help='HTML file to store results for one (snr, signal, rho). When looping over (snr, signal, rho) this HTML file tracks the current progress.',
256275
dest='htmlfile')
257-
parser.add_argument('--csvfile', help='CSV file to store results looped over (snr, rho). Serves as a file base for individual (snr, rho) pairs.',
276+
parser.add_argument('--csvfile', help='CSV file to store results looped over (snr, signal, rho). Serves as a file base for individual (snr, signal, rho) pairs.',
258277
dest='csvfile')
259278
parser.add_argument('--all_methods', help='Run all methods.',
260279
default=False,

snigdha/compare_fdr.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ def main(opts):
161161
else:
162162
snr_vals = [None]
163163

164+
if opts.signal is not None: # looping over snr strengths
165+
signal_vals = np.atleast_1d(opts.signal)
166+
else:
167+
signal_vals = [None]
168+
164169
new_opts = copy(opts)
165170
prev_rho = np.nan
166171

@@ -175,15 +180,17 @@ def main(opts):
175180
if opts.wide_only: # only allow methods that are ok if p>n
176181
new_opts.methods = [m for m in new_opts.methods if m.wide_OK]
177182

178-
for rho, snr in product(np.atleast_1d(opts.rho),
179-
snr_vals):
183+
for rho, snr, signal in product(np.atleast_1d(opts.rho),
184+
snr_vals,
185+
signal_vals):
180186

181187
# try to save some time on setup of knockoffs
182188

183189
method_setup = rho != prev_rho
184190
prev_rho = rho
185191

186192
new_opts.snr = snr
193+
new_opts.signal = signal
187194
new_opts.rho = rho
188195

189196
try:
@@ -205,11 +212,18 @@ def main(opts):
205212

206213
if snr is not None: # here is where snr_fac can be ignored
207214
instance.snr = new_opts.snr
215+
if signal is not None:
216+
instance.signal = new_opts.signal
208217

209218
if opts.csvfile is not None:
210-
new_opts.csvfile = (os.path.splitext(opts.csvfile)[0] +
211-
"_snr%0.1f_rho%0.2f.csv" % (new_opts.snr,
212-
new_opts.rho))
219+
if snr is not None:
220+
new_opts.csvfile = (os.path.splitext(opts.csvfile)[0] +
221+
"_snr%0.1f_rho%0.2f.csv" % (new_opts.snr,
222+
new_opts.rho))
223+
elif signal is not None:
224+
new_opts.csvfile = (os.path.splitext(opts.csvfile)[0] +
225+
"_signal%0.1f_rho%0.2f.csv" % (new_opts.signal,
226+
new_opts.rho))
213227
csvfiles.append(new_opts.csvfile)
214228
summaryfiles.append(new_opts.csvfile.replace('.csv', '_summary.csv'))
215229

@@ -272,6 +286,9 @@ def main(opts):
272286
parser.add_argument('--snr', type=float, nargs='+',
273287
dest='snr',
274288
help='snr strength to override instance default (default value: None)')
289+
parser.add_argument('--signal', type=float, nargs='+',
290+
dest='signal',
291+
help='signal strength to override instance default (default value: None)')
275292
parser.add_argument('--rho', nargs='+', type=float,
276293
default=0.,
277294
dest='rho',
@@ -286,9 +303,9 @@ def main(opts):
286303
help='How many repetitions?')
287304
parser.add_argument('--verbose', action='store_true',
288305
dest='verbose')
289-
parser.add_argument('--htmlfile', help='HTML file to store results for one (snr, rho). When looping over (snr, rho) this HTML file tracks the current progress.',
306+
parser.add_argument('--htmlfile', help='HTML file to store results for one (snr, signal, rho). When looping over (snr, signalm rho) this HTML file tracks the current progress.',
290307
dest='htmlfile')
291-
parser.add_argument('--csvfile', help='CSV file to store results looped over (snr, rho). Serves as a file base for individual (snr, rho) pairs.',
308+
parser.add_argument('--csvfile', help='CSV file to store results looped over (snr, signal, rho). Serves as a file base for individual (snr, signal, rho) tuples.',
292309
dest='csvfile')
293310
parser.add_argument('--all_methods', help='Run all methods.',
294311
default=False,

0 commit comments

Comments
 (0)