diff --git a/flamedisx/non_asymptotic_inference.py b/flamedisx/non_asymptotic_inference.py index ea916968..fc0787ea 100644 --- a/flamedisx/non_asymptotic_inference.py +++ b/flamedisx/non_asymptotic_inference.py @@ -23,9 +23,11 @@ def __init__(self, likelihood): self.likelihood = likelihood def __call__(self, mu_test, signal_source_name, guess_dict, - asymptotic=False): + asymptotic=False, fix_dict_param=None): # To fix the signal RM in the conditional fit fix_dict = {f'{signal_source_name}_rate_multiplier': tf.cast(mu_test, fd.float_type())} + if fix_dict_param is not None: + fix_dict = {**fix_dict, **fix_dict_param} guess_dict_nuisance = guess_dict.copy() guess_dict_nuisance.pop(f'{signal_source_name}_rate_multiplier') @@ -180,7 +182,8 @@ def __init__( gaussian_constraint_widths: ty.Dict[str, float] = None, sample_other_constraints: ty.Dict[str, ty.Callable] = None, likelihood=None, - ntoys=1000): + ntoys=1000, + fix_dict=None): if gaussian_constraint_widths is None: gaussian_constraint_widths = dict() @@ -199,6 +202,7 @@ def __init__( self.expected_background_counts = expected_background_counts self.gaussian_constraint_widths = gaussian_constraint_widths self.sample_other_constraints = sample_other_constraints + self.fix_dict_param = fix_dict def run_routine(self, mus_test=None, save_fits=False, observed_data=None, @@ -206,7 +210,8 @@ def run_routine(self, mus_test=None, save_fits=False, generate_B_toys=False, simulate_dict_B=None, toy_data_B=None, constraint_extra_args_B=None, toy_batch=0, - asymptotic=False): + asymptotic=False, + vary_signal_dict=None): """If observed_data is passed, evaluate observed test statistics. Otherwise, obtain test statistic distributions (for both S+B and B-only). 
@@ -301,7 +306,8 @@ def run_routine(self, mus_test=None, save_fits=False, self.toy_test_statistic_dist(test_stat_dists_SB, test_stat_dists_B, test_stat_dists_SB_disco, mu_test, signal_source, likelihood, - save_fits=save_fits) + save_fits=save_fits, + vary_signal_dict=vary_signal_dict) if observed_data is not None: observed_test_stats_collection[signal_source] = observed_test_stats @@ -323,6 +329,11 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood): """ simulate_dict = dict() constraint_extra_args = dict() + if self.fix_dict_param is None: + fix_dict_param = dict() + else: + fix_dict_param = self.fix_dict_param + for background_source in self.background_source_names: # Case where we use the conditional best fits as constraint centers and simulated values if self.observed_test_stats is not None: @@ -349,6 +360,55 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood): simulate_dict[f'{background_source}_rate_multiplier'] = tf.cast(expected_background_counts, fd.float_type()) simulate_dict[f'{signal_source_name}_rate_multiplier'] = tf.cast(mu_test, fd.float_type()) + for param_name in likelihood.param_names: + # For all other parameters + if '_rate_multiplier' in param_name: + continue + + # Initialize default parameter and bounds. 
+ try: + param_expect = likelihood.param_defaults[param_name] + except Exception: + raise RuntimeError(f"Default value of parameter {param_name} not found") + try: + param_bounds = likelihood.default_bounds[param_name] + except Exception: + raise RuntimeError(f"Bounds of parameter {param_name} not found") + + # Case of the conditional best fits + if self.observed_test_stats is not None: + try: + param_expect = conditional_bfs_observed[mu_test][param_name] + except Exception: + raise RuntimeError(f"Could not find observed conditional best fits for parameter {param_name}") + + # Sample constraint centers (gaussian_constraint_widths branch below kept disabled, but as real comments: it was previously wrapped in a bare triple-quoted string, which is an expression statement, not a comment, and its opening quotes sat at an invalid indentation for the patched function) + # if param_name in self.gaussian_constraint_widths.keys(): + #     # Given the parameter center, use {gaussian_constraint_widths} to get draw + #     if param_name not in fix_dict_param.keys(): + #         draw = stats.norm.rvs(loc=param_expect, + #                               scale=self.gaussian_constraint_widths[param_name]) + #     else: + #         draw = fix_dict_param[param_name] + #     constraint_extra_args[param_name] = tf.cast(draw, fd.float_type()) + + if param_name in self.sample_other_constraints.keys(): + # Given the parameter center, use {sample_other_constraint} as draw + draw = self.sample_other_constraints[param_name](param_expect) + constraint_extra_args[param_name] = tf.cast(draw, fd.float_type()) + else: + # If not provided, use 10% of parameter bound as gaussian width + if param_name not in fix_dict_param.keys(): + param_range = tf.math.reduce_max(param_bounds) - tf.math.reduce_min(param_bounds) + draw = stats.norm.rvs(loc=param_expect, scale=0.1*param_range) + else: + draw = fix_dict_param[param_name] + constraint_extra_args[param_name] = tf.cast(draw, fd.float_type()) + + if param_name not in fix_dict_param.keys(): + simulate_dict[param_name] = tf.cast(param_expect, fd.float_type()) + + if self.observed_test_stats is not None: conditional_bfs_observed = self.observed_test_stats[signal_source_name].conditional_best_fits[mu_test] non_rate_params_added = [] @@ -368,7 +428,7 @@ def sample_data_constraints(self,
mu_test, signal_source_name, likelihood): def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B, test_stat_dists_SB_disco, mu_test, signal_source_name, likelihood, - save_fits=False): + save_fits=False, vary_signal_dict=None): """Internal function to get test statistic distribution. """ ts_values_SB = [] @@ -382,8 +442,12 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B, # Loop over toys for toy in tqdm(range(self.ntoys), desc='Doing toys'): + if vary_signal_dict is not None: + # use stats (already used throughout this module) rather than an un-imported sps alias + mu_sim = stats.norm.rvs(loc=mu_test, scale=vary_signal_dict[signal_source_name]) + else: + mu_sim = mu_test simulate_dict_SB, toy_data_SB, constraint_extra_args_SB = \ - self.sample_data_constraints(mu_test, signal_source_name, likelihood) + self.sample_data_constraints(mu_sim, signal_source_name, likelihood) # S+B toys @@ -413,8 +477,8 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B, if value < 0.1: guess_dict_SB[key] = 0.1 # Evaluate test statistics - ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB) - ts_result_SB_disco = test_statistic_SB(0., signal_source_name, guess_dict_SB) + ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB, fix_dict_param=self.fix_dict_param) + ts_result_SB_disco = test_statistic_SB(0., signal_source_name, guess_dict_SB, fix_dict_param=self.fix_dict_param) # Save test statistics, and possibly fits ts_values_SB.append(ts_result_SB[0]) ts_values_SB_disco.append(ts_result_SB_disco[0]) @@ -457,7 +521,7 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B, # Create test statistic test_statistic_B = self.test_statistic(likelihood) # Evaluate test statistic - ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B, fix_dict_param=self.fix_dict_param) # Save test statistic, and possibly fits ts_values_B.append(ts_result_B[0]) if
save_fits: diff --git a/flamedisx/templates.py b/flamedisx/templates.py index 6aa15b32..a6a18239 100644 --- a/flamedisx/templates.py +++ b/flamedisx/templates.py @@ -402,7 +402,7 @@ def simulate(self, n_events, fix_truth=None, full_annotate=False, # TODO: all other arguments are ignored, they make no sense # for this source. Should we warn about this? Remove them from def? - assert len(self.defaults) == 1 + # assert len(self.defaults) == 1  # NOTE(review): guard deliberately disabled to allow extra (non-rate) defaults — confirm intended template_weights = tfp.math.batch_interp_regular_1d_grid( x=params[next(iter(self.defaults))],