Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 73 additions & 9 deletions flamedisx/non_asymptotic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def __init__(self, likelihood):
self.likelihood = likelihood

def __call__(self, mu_test, signal_source_name, guess_dict,
asymptotic=False):
asymptotic=False, fix_dict_param=None):
# Fix the signal rate multiplier (RM) to mu_test in the conditional fit
fix_dict = {f'{signal_source_name}_rate_multiplier': tf.cast(mu_test, fd.float_type())}
if fix_dict_param is not None:
fix_dict = {**fix_dict, **fix_dict_param}

guess_dict_nuisance = guess_dict.copy()
guess_dict_nuisance.pop(f'{signal_source_name}_rate_multiplier')
Expand Down Expand Up @@ -180,7 +182,8 @@ def __init__(
gaussian_constraint_widths: ty.Dict[str, float] = None,
sample_other_constraints: ty.Dict[str, ty.Callable] = None,
likelihood=None,
ntoys=1000):
ntoys=1000,
fix_dict=None):

if gaussian_constraint_widths is None:
gaussian_constraint_widths = dict()
Expand All @@ -199,14 +202,16 @@ def __init__(
self.expected_background_counts = expected_background_counts
self.gaussian_constraint_widths = gaussian_constraint_widths
self.sample_other_constraints = sample_other_constraints
self.fix_dict_param = fix_dict

def run_routine(self, mus_test=None, save_fits=False,
observed_data=None,
observed_test_stats=None,
generate_B_toys=False,
simulate_dict_B=None, toy_data_B=None, constraint_extra_args_B=None,
toy_batch=0,
asymptotic=False):
asymptotic=False,
vary_signal_dict=None):
"""If observed_data is passed, evaluate observed test statistics. Otherwise,
obtain test statistic distributions (for both S+B and B-only).

Expand Down Expand Up @@ -301,7 +306,8 @@ def run_routine(self, mus_test=None, save_fits=False,
self.toy_test_statistic_dist(test_stat_dists_SB, test_stat_dists_B,
test_stat_dists_SB_disco,
mu_test, signal_source, likelihood,
save_fits=save_fits)
save_fits=save_fits,
vary_signal_dict=vary_signal_dict)

if observed_data is not None:
observed_test_stats_collection[signal_source] = observed_test_stats
Expand All @@ -323,6 +329,11 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood):
"""
simulate_dict = dict()
constraint_extra_args = dict()
if self.fix_dict_param is None:
fix_dict_param = dict()
else:
fix_dict_param = self.fix_dict_param

for background_source in self.background_source_names:
# Case where we use the conditional best fits as constraint centers and simulated values
if self.observed_test_stats is not None:
Expand All @@ -349,6 +360,55 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood):
simulate_dict[f'{background_source}_rate_multiplier'] = tf.cast(expected_background_counts, fd.float_type())
simulate_dict[f'{signal_source_name}_rate_multiplier'] = tf.cast(mu_test, fd.float_type())

for param_name in likelihood.param_names:
# Only non-rate-multiplier parameters are handled in this loop
if '_rate_multiplier' in param_name:
continue

# Initialize default parameter and bounds.
try:
param_expect = likelihood.param_defaults[param_name]
except Exception:
raise RuntimeError(f"Default value of parameter {param_name} not found")
try:
param_bounds = likelihood.default_bounds[param_name]
except Exception:
raise RuntimeError(f"Bounds of parameter {param_name} not found")

# Case of the conditional best fits
if self.observed_test_stats is not None:
try:
param_expect = conditional_bfs_observed[mu_test][param_name]
except Exception:
raise RuntimeError(f"Could not find observed conditional best fits for parameter {param_name}")

# Sample constraint centers
""" if param_name in self.gaussian_constraint_widths.keys():
# Given the parameter center, use {gaussian_constraint_widths} to get draw
if param_name not in fix_dict_param.keys():
draw = stats.norm.rvs(loc=param_expect,
scale=self.gaussian_constraint_widths[param_name])
else:
draw = fix_dict_param[param_name]
constraint_extra_args[param_name] = tf.cast(draw, fd.float_type()) """

if param_name in self.sample_other_constraints.keys():
# Given the parameter center, draw using the callable registered in sample_other_constraints
draw = self.sample_other_constraints[param_name](param_expect)
constraint_extra_args[param_name] = tf.cast(draw, fd.float_type())
else:
# If no sampler is provided, use 10% of the parameter range (max bound - min bound) as the Gaussian width
if param_name not in fix_dict_param.keys():
param_range = tf.math.reduce_max(param_bounds) - tf.math.reduce_min(param_bounds)
draw = stats.norm.rvs(loc=param_expect, scale=0.1*param_range)
else:
draw = fix_dict_param[param_name]
constraint_extra_args[param_name] = tf.cast(draw, fd.float_type())

if param_name not in fix_dict_param.keys():
simulate_dict[param_name] = tf.cast(param_expect, fd.float_type())


if self.observed_test_stats is not None:
conditional_bfs_observed = self.observed_test_stats[signal_source_name].conditional_best_fits[mu_test]
non_rate_params_added = []
Expand All @@ -368,7 +428,7 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood):
def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
test_stat_dists_SB_disco,
mu_test, signal_source_name, likelihood,
save_fits=False):
save_fits=False, vary_signal_dict=None):
"""Internal function to get test statistic distribution.
"""
ts_values_SB = []
Expand All @@ -382,8 +442,12 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,

# Loop over toys
for toy in tqdm(range(self.ntoys), desc='Doing toys'):
if vary_signal_dict is not None:
mu_sim = sps.norm.rvs(loc=mu_test, scale=vary_signal_dict[signal_source_name])
else:
mu_sim = mu_test
simulate_dict_SB, toy_data_SB, constraint_extra_args_SB = \
self.sample_data_constraints(mu_test, signal_source_name, likelihood)
self.sample_data_constraints(mu_sim, signal_source_name, likelihood)

# S+B toys

Expand Down Expand Up @@ -413,8 +477,8 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
if value < 0.1:
guess_dict_SB[key] = 0.1
# Evaluate test statistics
ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB)
ts_result_SB_disco = test_statistic_SB(0., signal_source_name, guess_dict_SB)
ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB, fix_dict_param = self.fix_dict_param)
ts_result_SB_disco = test_statistic_SB(0., signal_source_name, guess_dict_SB, fix_dict_param = self.fix_dict_param)
# Save test statistics, and possibly fits
ts_values_SB.append(ts_result_SB[0])
ts_values_SB_disco.append(ts_result_SB_disco[0])
Expand Down Expand Up @@ -457,7 +521,7 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
# Create test statistic
test_statistic_B = self.test_statistic(likelihood)
# Evaluate test statistic
ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B)
ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B, fix_dict_param = self.fix_dict_param)
# Save test statistic, and possibly fits
ts_values_B.append(ts_result_B[0])
if save_fits:
Expand Down
2 changes: 1 addition & 1 deletion flamedisx/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def simulate(self, n_events, fix_truth=None, full_annotate=False,
# TODO: all other arguments are ignored, they make no sense
# for this source. Should we warn about this? Remove them from def?

assert len(self.defaults) == 1
#assert len(self.defaults) == 1

template_weights = tfp.math.batch_interp_regular_1d_grid(
x=params[next(iter(self.defaults))],
Expand Down
Loading