Skip to content

Commit fa809e9

Browse files
authored
Merge pull request #1 from echaussidon/blinding_edmond
Blinding edmond
2 parents a0ad0c2 + d748e2f commit fa809e9

File tree

1 file changed

+126
-98
lines changed

1 file changed

+126
-98
lines changed

scripts/main/apply_blinding_main_fromfile_fcomp.py

+126-98
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import sys
1919
import os
20+
import logging
2021
import shutil
2122
import unittest
2223
from datetime import datetime
@@ -59,6 +60,18 @@
5960
sys.exit('NERSC_HOST not known (code only works on NERSC), not proceeding')
6061

6162

63+
try:
64+
mpicomm = pyrecon.mpi.COMM_WORLD # MPI version
65+
except AttributeError:
66+
mpicomm = None # non-MPI version
67+
sys.exit('The following script need to be run with the MPI version of pyrecon. Please use module swap pyrecon:mpi')
68+
root = mpicomm.rank == 0
69+
70+
71+
# to remove jax warning (from cosmoprimo)
72+
logging.getLogger("jax._src.lib.xla_bridge").addFilter(logging.Filter("No GPU/TPU found, falling back to CPU."))
73+
74+
6275
parser = argparse.ArgumentParser()
6376
parser.add_argument("--type", help="tracer type to be selected")
6477
parser.add_argument("--basedir_in", help="base directory for input, default is location for official catalogs",default='/global/cfs/cdirs/desi/survey/catalogs/')
@@ -80,110 +93,72 @@
8093
parser.add_argument("--maxr", help="maximum for random files, default is 1",default=1,type=int) #use 2 for abacus mocks
8194
parser.add_argument("--dorecon",help="if y, run the recon needed for RSD blinding",default='n')
8295
parser.add_argument("--rsdblind",help="if y, do the RSD blinding shift",default='n')
96+
parser.add_argument("--fnlblind",help="if y, do the fnl blinding",default='n')
8397

8498
parser.add_argument("--fiducial_f",help="fiducial value for f",default=0.8)
8599

86-
parser.add_argument("--visnz",help="whether to look at the original, blinded, and weighted n(z)",default='n')
87-
88-
89100
#parser.add_argument("--fix_monopole",help="whether to choose f such that the amplitude of the monopole is fixed",default='y')
90101

91-
92102
args = parser.parse_args()
93-
94-
try:
95-
mpicomm = pyrecon.mpi.COMM_WORLD # MPI version
96-
except AttributeError:
97-
mpicomm = None # non-MPI version
98-
root = mpicomm is None or mpicomm.rank == 0
99-
100-
101-
if root:
102-
print(args)
103+
if root: print(args)
103104

104105
type = args.type
105106
version = args.version
106107
specrel = args.verspec
107108

108-
notqso = ''
109-
if args.notqso == 'y':
110-
notqso = 'notqso'
111-
112-
if root:
113-
print('blinding catalogs for tracer type '+type+notqso)
114-
115-
116-
if type[:3] == 'BGS' or type == 'bright' or type == 'MWS_ANY':
117-
prog = 'BRIGHT'
118-
119-
else:
120-
prog = 'DARK'
109+
notqso = 'notqso' if (args.notqso == 'y') else ''
110+
if root: print('blinding catalogs for tracer type ' + type + notqso)
121111

112+
prog = 'BRIGHT' if (type[:3] == 'BGS' or type == 'bright' or type == 'MWS_ANY') else 'DARK'
122113
progl = prog.lower()
123114

124115
mainp = main(args.type)
125116
zmin = mainp.zmin
126117
zmax = mainp.zmax
127118
tsnrcol = mainp.tsnrcol
128119

129-
130120
#share basedir location '/global/cfs/cdirs/desi/survey/catalogs'
131121
if 'mock' not in args.verspec:
132-
maindir = args.basedir_in +'/'+args.survey+'/LSS/'
133-
134-
ldirspec = maindir+specrel+'/'
135-
136-
dirin = ldirspec+'LSScats/'+version+'/'
137-
LSSdir = ldirspec+'LSScats/'
122+
maindir = args.basedir_in + '/' + args.survey + '/LSS/'
123+
ldirspec = maindir + specrel + '/'
124+
dirin = ldirspec + 'LSScats/' + version + '/'
138125
tsnrcut = mainp.tsnrcut
139126
dchi2 = mainp.dchi2
140127
randens = 2500.
141128
nzmd = 'data'
142129
elif 'Y1/mock' in args.verspec: #e.g., use 'mocks/FirstGenMocks/AbacusSummit/Y1/mock1' to get the 1st mock with fiberassign
143-
dirin = args.basedir_in +'/'+args.survey+'/'+args.verspec+'/LSScats/'+version+'/'
144-
LSSdir = args.basedir_in +'/'+args.survey+'/'+args.verspec+'/LSScats/'
130+
dirin = args.basedir_in + '/' + args.survey + '/' + args.verspec + '/LSScats/' + version + '/'
145131
dchi2=None
146132
tsnrcut=0
147133
randens = 10460.
148134
nzmd = 'mock'
149-
150135
else:
151-
sys.exit('verspec '+args.verspec+' not supported')
152-
153-
154-
dirout = args.basedir_out+'/LSScats/'+version+'/blinded/'
155-
136+
sys.exit('verspec ' + args.verspec + ' not supported')
156137

138+
dirout = args.basedir_out + '/LSScats/' + version + '/blinded/'
157139

140+
if root and (not os.path.exists(dirout)):
141+
os.makedirs(dirout)
142+
print('made '+dirout)
158143

159144
tp2z = {'LRG':0.8,'ELG':1.1,'QSO':1.6}
160145
tp2bias = {'LRG':2.,'ELG':1.3,'QSO':2.3}
161-
ztp = tp2z[args.type]
162-
bias = tp2bias[args.type]
163-
164146

165147
if root:
166-
if not os.path.exists(dirout):
167-
os.makedirs(dirout)
168-
print('made '+dirout)
169-
148+
ztp = tp2z[args.type]
149+
bias = tp2bias[args.type]
170150

171151
w0wa = np.loadtxt('/global/cfs/cdirs/desi/survey/catalogs/Y1/LSS/w0wa_initvalues_zeffcombined_1000realisations.txt')
172152

173153
if args.get_par_mode == 'random':
174-
#if args.type != 'LRG':
175-
# sys.exit('Only do LRG in random mode, read from LRG file for other tracers')
154+
if args.type != 'LRG':
155+
sys.exit('Only do LRG in random mode, read from LRG file for other tracers')
176156
ind = int(random()*1000)
177157
[w0_blind,wa_blind] = w0wa[ind]
178158

179-
if args.get_par_mode == 'from_file' and root:
180-
fn = LSSdir + 'filerow.txt'
181-
if not os.path.isfile(fn):
182-
ind_samp = int(random()*1000)
183-
fo = open(fn,'w')
184-
fo.write(str(ind_samp)+'\n')
185-
fo.close()
186-
ind = int(np.loadtxt(fn))
159+
if args.get_par_mode == 'from_file':
160+
hd = fitsio.read_header(dirout+ 'LRG_full.dat.fits',ext='LSS')
161+
ind = hd['FILEROW']
187162
[w0_blind,wa_blind] = w0wa[ind]
188163

189164
#choose f_shift to compensate shift in monopole amplitude
@@ -196,32 +171,26 @@
196171
DM_shift = cosmo_shift.comoving_angular_distance(ztp)
197172
DH_shift = 1./cosmo_shift.hubble_function(ztp)
198173

199-
200-
vol_fac = (DM_shift**2*DH_shift)/(DM_fid**2*DH_fid)
174+
vol_fac = (DM_shift**2 * DH_shift) / (DM_fid**2 * DH_fid)
201175

202176
#a, b, c for quadratic formula
203-
a = 0.2/bias**2.
204-
b = 2/(3*bias)
205-
c = 1-(1+0.2*(args.fiducial_f/bias)**2.+2/3*args.fiducial_f/bias)/vol_fac
206-
207-
f_shift = (-b+np.sqrt(b**2.-4.*a*c))/(2*a)
208-
209-
dfper = (f_shift-args.fiducial_f)/args.fiducial_f
177+
a = 0.2 / bias**2
178+
b = 2 / (3 * bias)
179+
c = 1 - (1 + 0.2 * (args.fiducial_f / bias)**2. + 2/3 * args.fiducial_f / bias) / vol_fac
210180

181+
f_shift = (-b + np.sqrt(b**2. - 4.*a*c))/(2*a)
182+
dfper = (f_shift - args.fiducial_f)/args.fiducial_f
211183
maxfper = 0.1
212184
if abs(dfper) > maxfper:
213185
dfper = maxfper*dfper/abs(dfper)
214186
f_shift = (1+dfper)*args.fiducial_f
215-
216187
fgrowth_blind = f_shift
217188

218-
219189
#if args.reg_md == 'NS':
220190
regl = ['_S','_N']
221191
#if args.reg_md == 'GC':
222192
gcl = ['_SGC','_NGC']
223193

224-
225194
fb_in = dirin+type+notqso
226195
fcr_in = fb_in+'_1_full.ran.fits'
227196
fcd_in = fb_in+'_full.dat.fits'
@@ -238,7 +207,7 @@
238207
dz = 0.01
239208
#zmin = 0.01
240209
#zmax = 1.6
241-
210+
242211
if type[:3] == 'LRG':
243212
P0 = 10000
244213
#zmin = 0.4
@@ -286,7 +255,8 @@
286255
fd['WEIGHT_SYS'] *= wl
287256
common.write_LSS(fd,fcd_out)
288257

289-
if args.visnz == 'y':
258+
259+
if nzmd == 'mock':
290260
print('min/max of weights for nz:')
291261
print(np.min(wl),np.max(wl))
292262
fdin = fitsio.read(fcd_in)
@@ -295,8 +265,7 @@
295265
c = plt.hist(fd['Z'][gz],bins=100,range=(zmin,zmax),histtype='step',weights=fd['WEIGHT_SYS'][gz],label='blinded+reweight')
296266
plt.legend()
297267
plt.show()
298-
299-
268+
300269

301270
if args.type == 'LRG':
302271
hdul = fits.open(fcd_out,mode='update')
@@ -305,12 +274,10 @@
305274
hdtest = fitsio.read_header(dirout+ 'LRG_full.dat.fits', ext='LSS')['FILEROW']
306275
if hdtest != ind:
307276
sys.exit('ERROR writing/reading row from blind file')
308-
309-
310277

311278

312279
if args.mkclusdat == 'y':
313-
ct.mkclusdat(dirout+type+notqso,tp=type,dchi2=dchi2,tsnrcut=tsnrcut,zmin=zmin,zmax=zmax)
280+
ct.mkclusdat(dirout + type + notqso, tp=type, dchi2=dchi2, tsnrcut=tsnrcut, zmin=zmin, zmax=zmax)
314281

315282

316283
if args.mkclusran == 'y':
@@ -327,27 +294,24 @@
327294
ranfm = dirout+args.type+notqso+reg+'_'+str(rannum-1)+'_clustering.ran.fits'
328295
os.system('mv '+ranf+' '+ranfm)
329296

330-
reg_md = args.reg_md
297+
if args.split_GC == 'y':
298+
fb = dirout+args.type+notqso+'_'
299+
ct.clusNStoGC(fb,args.maxr-args.minr)
331300

332-
if args.split_GC == 'y' and root:
333-
fb = dirout+args.type+notqso+'_'
334-
ct.clusNStoGC(fb,args.maxr-args.minr)
335-
336-
sys.stdout.flush()
301+
sys.stdout.flush()
337302

338303
if args.dorecon == 'y':
339-
nran = args.maxr-args.minr
340-
304+
nran = args.maxr - args.minr
305+
306+
if root: print('on est la')
307+
341308
distance = TabulatedDESI().comoving_radial_distance
342309

343310
f, bias = rectools.get_f_bias(args.type)
344311
from pyrecon import MultiGridReconstruction
345-
Reconstruction = MultiGridReconstruction
346-
347-
setup_logging()
312+
Reconstruction = MultiGridReconstruction
348313

349-
350-
if reg_md == 'NS':
314+
if args.reg_md == 'NS':
351315
regions = ['N','S']
352316
else:
353317
regions = ['NGC','SGC']
@@ -358,22 +322,86 @@
358322
randoms_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, name='randoms')
359323
data_rec_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, rec_type='MGrsd', name='data')
360324
randoms_rec_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, rec_type='MGrsd', name='randoms')
361-
rectools.run_reconstruction(Reconstruction, distance, data_fn, randoms_fn, data_rec_fn, randoms_rec_fn, f=f, bias=bias, convention='rsd', dtype='f8', zlim=(zmin, zmax),mpicomm=mpicomm)
325+
rectools.run_reconstruction(Reconstruction, distance, data_fn, randoms_fn, data_rec_fn, randoms_rec_fn, f=f, bias=bias, convention='rsd', dtype='f8', zlim=(zmin, zmax), mpicomm=mpicomm)
326+
327+
if root and (args.rsdblind == 'y'):
362328

363-
if args.rsdblind == 'y' and root:
364-
if reg_md == 'NS':
329+
if root: print('on est ici')
330+
331+
if args.reg_md == 'NS':
365332
cl = regl
366-
if reg_md == 'GC':
333+
if args.reg_md == 'GC':
367334
cl = gcl
368335
for reg in cl:
369-
fnd = dirout+type+notqso+reg+'_clustering.dat.fits'
370-
fndr = dirout+type+notqso+reg+'_clustering.MGrsd.dat.fits'
336+
fnd = dirout + type + notqso + reg + '_clustering.dat.fits'
337+
fndr = dirout + type + notqso + reg + '_clustering.MGrsd.dat.fits'
371338
data = Table(fitsio.read(fnd))
372339
data_real = Table(fitsio.read(fndr))
373340

374341
out_file = fnd
375-
blind.apply_zshift_RSD(data,data_real,out_file,
342+
blind.apply_zshift_RSD(data, data_real, out_file,
376343
fgrowth_fid=args.fiducial_f,
377344
fgrowth_blind=fgrowth_blind)#,
378345
#comments=f"f_blind: {fgrowth_blind}, w0_blind: {w0_blind}, wa_blind: {wa_blind}")
379346

347+
if args.fnlblind == 'y':
348+
from mockfactory.blinding import get_cosmo_blind, CutskyCatalogBlinding
349+
350+
if root: print('on est ici')
351+
352+
if root:
353+
f_blind = fgrowth_blind
354+
# generate blinding value from the choosen index above
355+
np.random.seed(ind)
356+
fnl_blind = np.random.uniform(low=-15, high=15, size=1)[0]
357+
if not root:
358+
w0_blind, wa_blind, f_blind, fnl_blind = None, None, None, None
359+
w0_blind = mpicomm.bcast(w0_blind, root=0)
360+
wa_blind = mpicomm.bcast(wa_blind, root=0)
361+
f_blind = mpicomm.bcast(f_blind, root=0)
362+
fnl_blind = mpicomm.bcast(fnl_blind, root=0)
363+
364+
# collect effective redshift and bias for the considered tracer
365+
zeff = tp2z[args.type]
366+
bias = tp2bias[args.type]
367+
368+
# build blinding cosmology
369+
cosmo_blind = get_cosmo_blind('DESI', z=zeff)
370+
cosmo_blind.params['w0_fld'] = w0_blind
371+
cosmo_blind.params['wa_fld'] = wa_blind
372+
cosmo_blind._derived['f'] = f_blind
373+
cosmo_blind._derived['fnl'] = fnl_blind # on fixe la valeur pour de bon
374+
blinding = CutskyCatalogBlinding(cosmo_fid='DESI', cosmo_blind=cosmo_blind, bias=bias, z=zeff, position_type='rdz', mpicomm=mpicomm, mpiroot=0)
375+
376+
# loop over the different region of the sky
377+
if args.reg_md == 'NS':
378+
cl = regl
379+
if args.reg_md == 'GC':
380+
cl = gcl
381+
for reg in cl:
382+
# path of data and randoms:
383+
catalog_kwargs = dict(tracer=args.type, region=region, ctype='clustering', nrandoms=nran)
384+
data_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, name='data')
385+
randoms_fn = catalog_fn(**catalog_kwargs, cat_dir=dirout, name='randoms')
386+
if np.ndim(randoms_fn) == 0: randoms_fn = [randoms_fn]
387+
388+
data_positions, data_weights = None, None
389+
randoms_positions, randoms_weights = None, None
390+
if root:
391+
print('Loading {}.'.format(data_fn))
392+
data = Table.read(data_fn)
393+
data_positions, data_weights = [data['RA'], data['DEC'], data['Z']], data['WEIGHT']
394+
395+
print('Loading {}'.format(randoms_fn))
396+
randoms = vstack([Table.read(fn) for fn in randoms_fn])
397+
randoms_positions, randoms_weights = [randoms['RA'], randoms['DEC'], randoms['Z']], randoms['WEIGHT']
398+
399+
# add fnl blinding weight to the data weight
400+
new_data_weights = blinding.png(data_positions, data_weights=data_weights,
401+
randoms_positions=randoms_positions, randoms_weights=randoms_weights,
402+
method='data_weights', shotnoise_correction=True)
403+
404+
# overwrite the data!
405+
if root:
406+
data['WEIGHT'] = new_data_weights
407+
common.write_LSS(data, data_fn)

0 commit comments

Comments
 (0)