Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions code/bump_hunt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch
import random
import inspect
import awkward0
import numpy as np
import pandas as pd
import mplhep as hep
Expand All @@ -20,15 +19,10 @@
from glob import glob
from pathlib import Path
from sklearn import metrics
from torch.nn import MSELoss
from scipy.optimize import curve_fit
from torch.utils.data import random_split
from matplotlib.backends.backend_pdf import PdfPages
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataListLoader

import models.models as models
import models.emd_models as emd_models
from util.loss_util import LossFunction
from datagen.graph_data_gae import GraphDataset, collate
from util.plot_util import loss_distr, plot_reco_difference
Expand Down Expand Up @@ -182,14 +176,13 @@ def fit_function(x, p0, p1, p2, p3, p4):
bh.plot_bump(data=outlier_mass, bkg=nonoutlier_mass, filename=save_name+'.pdf', x_label=r'$m_{jj}$ [GeV]')
bh.plot_stat(show_Pval=True, filename=save_name+'_stat.pdf')

def process(data_loader, num_events, model_path, model, loss_ftn_obj, latent_dim, features):
def process(data_loader, model_path, model, loss_ftn_obj, latent_dim, features):
"""
Use the specified model to determine the reconstruction loss of each sample.
Also calculate the invariant mass of the jets.

Args:
data_loader (torch.utils.data.DataLoader): pytorch dataloader for loading in black boxes
num_events (int): how many events we're processing
model_path (str): path to saved model
model (str): name of model class
loss_ftn_obj (LossFunction): see loss_util.py
Expand All @@ -202,6 +195,7 @@ def process(data_loader, num_events, model_path, model, loss_ftn_obj, latent_dim
"""

# load corresponding model
print("Loading model")
if model == 'MetaLayerGAE':
model = models.GNNAutoEncoder()
else:
Expand All @@ -218,7 +212,9 @@ def process(data_loader, num_events, model_path, model, loss_ftn_obj, latent_dim
jets_proc_data = []
input_fts = []
reco_fts = []


print("Begin processing")

event = 0
# for each event in the dataset calculate the loss and inv mass for the leading 2 jets
with torch.no_grad():
Expand Down Expand Up @@ -249,7 +245,7 @@ def process(data_loader, num_events, model_path, model, loss_ftn_obj, latent_dim
for ib in torch.unique(batch):
if loss_ftn_obj.name == 'vae_loss':
losses[ib] = loss_ftn_obj.loss_ftn(jets_rec[batch==ib], jets_x[batch==ib], mu, log_var)
elif loss_ftn_obj.name == 'emd_loss':
elif loss_ftn_obj.name == 'emd_loss' or loss_ftn_obj.name == 'chamfer_loss' or loss_ftn_obj.name == 'hungarian_loss':
losses[ib] = loss_ftn_obj.loss_ftn(jets_rec[batch==ib], jets_x[batch==ib], torch.tensor(0).repeat(jets_rec[batch==ib].shape[0]))
else:
losses[ib] = loss_ftn_obj.loss_ftn(jets_rec[batch==ib], jets_x[batch==ib])
Expand All @@ -272,16 +268,15 @@ def process(data_loader, num_events, model_path, model, loss_ftn_obj, latent_dim
# return pytorch tensors
return torch.cat(jets_proc_data), torch.cat(input_fts), torch.cat(reco_fts)

def bump_hunt(df, cuts, model_fname, model, bb, save_path):
def bump_hunt(df, cuts, loss_name, bb, save_path):
"""
Loops and makes multiple cuts on the jet losses, and generates a graph for each cut by
delegating to make_bump_graph().

Args:
df (pd.DataFrame): output of process() transformed into a dataframe; has loss and mass of jets per event
cuts (list of floats): all the percentages to perform a cut on the loss
model_fname (str): name of saved model
model (str): name of model class
loss_name (str): name of loss function
bb (str): which black box the bump hunt is being performed on (e.g. 'bb1')
"""
losses = np.concatenate([df['loss1'], df['loss2']])
Expand Down Expand Up @@ -330,15 +325,19 @@ def bump_hunt(df, cuts, model_fname, model, bb, save_path):
bins = np.linspace(0, 1800, 51)
make_bump_graph(nonoutlier_m2_mass, outlier_m2_mass, x_lab, mj2_graph_name, bins)

if bb == 'rnd': # plot roc for rnd set
if 'rnd' in bb: # plot roc for rnd set
df['loss_sum'] = df['loss1']+df['loss2']
df['loss_min'] = np.minimum(df['loss1'],df['loss2'])
df['loss_max'] = np.maximum(df['loss1'],df['loss2'])

if model_fname != 'GNN_AE_EdgeConv_Finished':
if loss_name == 'chamfer_loss':
loss_type = '$D^{NN}$'
else:
elif loss_name == 'MSE':
loss_type = 'MSE'
elif loss_name == 'hungarian_loss':
loss_type = 'Hung.'
elif loss_name == 'emd_loss':
loss_type = 'emd-nn'

plt.figure(figsize=(8,6))
plt.style.use(hep.style.CMS)
Expand Down Expand Up @@ -391,19 +390,20 @@ def get_df(proc_jets):
return df

# read in dataset
bb_name = ["bb0", "bb1", "bb2", "bb3", "rnd"][box_num]
bb_name = ["bb0_xyz_pyg2", "bb1_xyz_pyg2", "bb2_xyz_pyg2", "bb3_xyz_pyg2", "rnd_xyz_pyg2"][box_num]
print("Plotting %s"%bb_name)

save_path = osp.join(output_dir,model_fname,'bump_hunt',bb_name)
Path(save_path).mkdir(parents=True,exist_ok=True) # make a subfolder

if not osp.isfile(osp.join(save_path,'df.pkl')) or overwrite:
print("Processing jet losses")
gdata = GraphDataset('/anomalyvol/data/lead_2/tiny', n_events=num_events, bb=box_num, features=features)
# gdata = GraphDataset('/anomalyvol/data/lead_2/%s/'%bb_name, n_events=num_events, bb=box_num, features=features)
print("Processing new jet losses")
# gdata = GraphDataset('/anomalyvol/data/lead_2/tiny', n_events=num_events, bb=box_num, features=features)
gdata = GraphDataset('/anomalyvol/data/lead_2/%s/'%bb_name, n_events=num_events, bb=box_num, features=features)
print("Dataset loaded")
bb_loader = DataListLoader(gdata, batch_size=1, pin_memory=True, shuffle=False)
bb_loader.collate_fn = collate
proc_jets, input_fts, reco_fts = process(bb_loader, num_events, model_path, model, loss_ftn_obj, latent_dim, features)
proc_jets, input_fts, reco_fts = process(bb_loader, model_path, model, loss_ftn_obj, latent_dim, features)
df = get_df(proc_jets)
df.to_pickle(osp.join(save_path,'df.pkl'))
torch.save(input_fts, osp.join(save_path,'input_fts.pt'))
Expand All @@ -414,7 +414,7 @@ def get_df(proc_jets):
input_fts = torch.load(osp.join(save_path,'input_fts.pt'))
reco_fts = torch.load(osp.join(save_path,'reco_fts.pt'))
plot_reco_difference(input_fts, reco_fts, model_fname, save_path)
bump_hunt(df, cuts, model_fname, model, bb_name, save_path)
bump_hunt(df, cuts, args.loss, bb_name, save_path)

if __name__ == "__main__":

Expand All @@ -427,7 +427,8 @@ def get_df(proc_jets):
parser.add_argument("--overwrite", action='store_true', help="Toggle overwrite of pkl. Default False.", default=False, required=False)
parser.add_argument("--num-events", type=int, help="How many events to process (multiple of 100). Default 1mil", default=1000000, required=False)
parser.add_argument("--latent-dim", type=int, help="How many units for the latent space (def=2)", default=2, required=False)
parser.add_argument("--loss", choices=["chamfer_loss","emd_loss","vae_loss","mse"], help="loss function", required=True)
parser.add_argument('--loss', choices=[m for m in dir(LossFunction) if not m.startswith('__')],
help='loss function', required=True)
parser.add_argument("--box-num", type=int, help="0=QCD-background; 1=bb1; 2=bb2; 4=rnd", required=True)
parser.add_argument("--features", choices=['xyz','relptetaphi'], help="Generate (px,py,pz) or relative (pt,eta,phi)", required=True)
args = parser.parse_args()
Expand Down
20 changes: 0 additions & 20 deletions code/util/loss_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,23 +152,3 @@ def hungarian_loss(self, x, y, batch):

total_loss = torch.mean(torch.stack(list(losses)))
return total_loss

def hungarian_loss(self, x, y):
"""heavily based on the function found in
https://github.com/Cyanogenoid/dspn/blob/be3703b470ead46d76b70b4fed656c2e5343aff6/dspn/utils.py#L6-L23"""
# x and y shape :: (n, c, s)
x, y = outer(x, y)
# squared_error shape :: (n, s, s)
squared_error = F.smooth_l1_loss(x, y.expand_as(x), reduction="none").mean(1)

squared_error_np = squared_error.detach().cpu().numpy()
indices = map(hungarian_loss_per_sample, squared_error_np)
losses = [
sample[row_idx, col_idx].mean()
for sample, (row_idx, col_idx) in zip(squared_error, indices)
]
total_loss = torch.mean(torch.stack(list(losses)))
return total_loss