diff --git a/flamedisx/__init__.py b/flamedisx/__init__.py index fa5169098..3b89e39a2 100644 --- a/flamedisx/__init__.py +++ b/flamedisx/__init__.py @@ -41,3 +41,6 @@ # Custom TFP files # Access through fd.tfp_files.xxx from . import tfp_files + +# TEMPORARY,I HOPE +from . import tfbspline diff --git a/flamedisx/inference.py b/flamedisx/inference.py index cfeb9ca82..67541e24d 100644 --- a/flamedisx/inference.py +++ b/flamedisx/inference.py @@ -176,10 +176,13 @@ def _dict_to_array(self, x: dict) -> np.array: def _array_to_dict(self, x: ty.Union[np.ndarray, tf.Tensor]) -> dict: """Convert from array/tensor to {parameter: value} dictionary""" + x = tf.cast(x, fd.float_type()) assert isinstance(x, (np.ndarray, tf.Tensor)) assert len(x) == len(self.arg_names) - return {k: x[i] - for i, k in enumerate(self.arg_names)} + param_dict = dict() + for i, k in enumerate(self.arg_names): + param_dict[k] = tf.gather(x, i) + return param_dict def normalize(self, x: ty.Union[dict, np.ndarray], diff --git a/flamedisx/likelihood.py b/flamedisx/likelihood.py index a2198f383..ff8c8067d 100644 --- a/flamedisx/likelihood.py +++ b/flamedisx/likelihood.py @@ -271,6 +271,8 @@ def set_data(self, UserWarning) for s in self.sources.values(): s.set_data(None) + s.n_batches = 0 + self.batch_info = None return batch_info = np.zeros((len(self.dsetnames), 3), dtype=int) @@ -340,7 +342,8 @@ def set_data(self, np.concatenate([[0], stop_idx[:-1]]), stop_idx]) - def simulate(self, fix_truth=None, **params): + def simulate(self, fix_truth=None, alter_source_mus=False, + **params): """Simulate events from sources. """ params = self.prepare_params(params, free_all_rates=True) @@ -352,6 +355,8 @@ def simulate(self, fix_truth=None, **params): rm = self._get_rate_mult(sname, params) mu = rm * s.mu_before_efficiencies( **self._filter_source_kwargs(params, sname)) + if alter_source_mus: + mu *= self.mu_estimators[sname](**self._filter_source_kwargs(params, sname)) # Simulate this many events from source n_to_sim = np.random.poisson(mu) if n_to_sim == 0: @@ -380,9 +385,6 @@ def log_likelihood(self, second_order=False, omit_grads=tuple(), **kwargs): params = self.prepare_params(kwargs) n_grads = len(self.param_defaults) - len(omit_grads) - ll = 0. - llgrad = np.zeros(n_grads, dtype=np.float64) - llgrad2 = np.zeros((n_grads, n_grads), dtype=np.float64) for dsetname in self.dsetnames: # Getting this from the batch_info tensor is much slower @@ -395,14 +397,18 @@ def log_likelihood(self, second_order=False, else: empty_batch = False + ll = {i_batch: 0. for i_batch in range(n_batches)} + llgrad = np.zeros(n_grads, dtype=np.float64) + llgrad2 = np.zeros((n_grads, n_grads), dtype=np.float64) + for i_batch in range(n_batches): # Iterating over tf.range seems much slower! if empty_batch: batch_data_tensor = None else: - batch_data_tensor = self.data_tensors[dsetname][i_batch] + batch_data_tensor = tf.gather(self.data_tensors[dsetname], i_batch) results = self._log_likelihood( - tf.constant(i_batch, dtype=fd.int_type()), + i_batch, dsetname=dsetname, data_tensor=batch_data_tensor, batch_info=self.batch_info, @@ -411,18 +417,18 @@ def log_likelihood(self, second_order=False, empty_batch=empty_batch, constraint_extra_args=self.constraint_extra_args, **params) - ll += results[0].numpy().astype(np.float64) + ll[i_batch] = results[0] if self.param_names: if results[1] is None: raise ValueError("TensorFlow returned None as gradient!") - llgrad += results[1].numpy().astype(np.float64) + llgrad += results[1] if second_order: - llgrad2 += results[2].numpy().astype(np.float64) + llgrad2 += results[2] if second_order: - return ll, llgrad, llgrad2 - return ll, llgrad, None + return np.sum(list(ll.values())), llgrad, llgrad2 + return np.sum(list(ll.values())), llgrad, None def minus2_ll(self, *, omit_grads=tuple(), **kwargs): result = self.log_likelihood(omit_grads=omit_grads, **kwargs) @@ -431,11 +437,6 @@ def minus2_ll(self, *, omit_grads=tuple(), **kwargs): return -2 * ll, -2 * grad, hess def prepare_params(self, kwargs, free_all_rates=False): - for k in kwargs: - if k not in self.param_defaults: - if k.endswith('_rate_multiplier') and free_all_rates: - continue - raise ValueError(f"Unknown parameter {k}") return {**self.param_defaults, **fd.values_to_constants(kwargs)} def _get_rate_mult(self, sname, kwargs): @@ -468,7 +469,9 @@ def mu(self, *, :param dataset_name: ... for just this dataset :param source_name: ... for just this source. You must provide either dsetname or source, since it makes no sense to - add events from multiple datasets + add events from multiple datasets. + For rate multipliers (always linear) add a 0 x r.m**2 term to give a 0 + hessian instead of None. """ kwargs = {**self.param_defaults, **kwargs} if dataset_name is None and source_name is None: @@ -481,8 +484,10 @@ def mu(self, *, if source_name is not None and sname != source_name: continue filtered_params = self._filter_source_kwargs(kwargs, sname) - mu += (self._get_rate_mult(sname, kwargs) - * self.mu_estimators[sname](**filtered_params)) + _rate_multiplier = self._get_rate_mult(sname, kwargs) + mu += (_rate_multiplier + * self.mu_estimators[sname](**filtered_params) + + tf.constant(0.,fd.float_type())*_rate_multiplier**2) return mu @tf.function @@ -521,10 +526,14 @@ def _log_likelihood(self, 0.) if dsetname == self.dsetnames[0]: if constraint_extra_args is None: - ll += self.log_constraint(**params_unstacked) + ll += tf.where(tf.equal(i_batch, tf.constant(0, dtype=fd.int_type())), + self.log_constraint(**params_unstacked), + 0.) else: kwargs = {**params_unstacked, **constraint_extra_args} - ll += self.log_constraint(**kwargs) + ll += tf.where(tf.equal(i_batch, tf.constant(0, dtype=fd.int_type())), + self.log_constraint(**kwargs), + 0.) # Autodifferentiation. This is why we use tensorflow: grad = tf.gradients(ll, grad_par_stack)[0] diff --git a/flamedisx/non_asymptotic_inference.py b/flamedisx/non_asymptotic_inference.py index 388750172..8fc741aa8 100644 --- a/flamedisx/non_asymptotic_inference.py +++ b/flamedisx/non_asymptotic_inference.py @@ -4,6 +4,8 @@ from tqdm.auto import tqdm import typing as ty +from copy import deepcopy + import tensorflow as tf export, __all__ = fd.exporter() @@ -22,15 +24,19 @@ def __init__(self, likelihood): def __call__(self, mu_test, signal_source_name, guess_dict): # To fix the signal RM in the conditional fit - fix_dict = {f'{signal_source_name}_rate_multiplier': mu_test} + fix_dict = {f'{signal_source_name}_rate_multiplier': tf.cast(mu_test, fd.float_type())} guess_dict_nuisance = guess_dict.copy() guess_dict_nuisance.pop(f'{signal_source_name}_rate_multiplier') # Conditional fit - bf_conditional = self.likelihood.bestfit(fix=fix_dict, guess=guess_dict_nuisance, suppress_warnings=True) + bf_conditional = self.likelihood.bestfit(fix=fix_dict, guess=guess_dict_nuisance, suppress_warnings=True, + allow_failure=True) + bf_conditional = {k: v.numpy() for k, v in bf_conditional.items()} # Uncnditional fit - bf_unconditional = self.likelihood.bestfit(guess=guess_dict, suppress_warnings=True) + bf_unconditional = self.likelihood.bestfit(guess=guess_dict, suppress_warnings=True, + allow_failure=True) + bf_unconditional = {k: v.numpy() for k, v in bf_unconditional.items()} # Return the test statistic, unconditional fit and conditional fit return self.evaluate(bf_unconditional, bf_conditional), bf_unconditional, bf_conditional @@ -135,9 +141,6 @@ class TSEvaluation(): - signal_source_names: tuple of names for signal sources (e.g. WIMPs of different masses) - background_source_names: tuple of names for background sources - - sources: dictionary {sourcename: class} of all signal and background source classes - - arguments: dictionary {sourcename: {kwarg1: value, ...}, ...}, for - passing keyword arguments to source constructors - expected_background_counts: dictionary of expected counts for background sources - gaussian_constraint_widths: dictionary giving the constraint width for all sources using Gaussian constraints for their rate nuisance parameters @@ -145,31 +148,19 @@ class TSEvaluation(): in the toys for any sources using non-Gaussian constraints for their rate nuisance parameters. Argument to the function will be either the prior expected counts, or the number of counts at the conditional MLE, depending on the mode - - rm_bounds: dictionary {sourcename: (lower, upper)} to set fit bounds on the rate multipliers - - log_constraint_fn: logarithm of the constraint function used in the likelihood. Any arguments - which aren't fit parameters, such as those determining constraint means for toys, will need - passing via the set_constraint_extra_args() function + - likelihood: BLAH - ntoys: number of toys that will be run to get test statistic distributions - - batch_size: batch size that will be used for the RM fits """ def __init__( self, test_statistic: TestStatistic.__class__, signal_source_names: ty.Tuple[str], background_source_names: ty.Tuple[str], - sources: ty.Dict[str, fd.Source.__class__], - arguments: ty.Dict[str, ty.Dict[str, ty.Union[int, float]]] = None, expected_background_counts: ty.Dict[str, float] = None, gaussian_constraint_widths: ty.Dict[str, float] = None, sample_other_constraints: ty.Dict[str, ty.Callable] = None, - rm_bounds: ty.Dict[str, ty.Tuple[float, float]] = None, - log_constraint_fn: ty.Callable = None, - ntoys=1000, - batch_size=10000): - - for key in sources.keys(): - if key not in arguments.keys(): - arguments[key] = dict() + likelihood=None, + ntoys=1000): if gaussian_constraint_widths is None: gaussian_constraint_widths = dict() @@ -177,34 +168,17 @@ def __init__( if sample_other_constraints is None: sample_other_constraints = dict() - if rm_bounds is None: - rm_bounds = dict() - else: - for bounds in rm_bounds.values(): - assert bounds[0] >= 0., 'Currently do not support negative rate multipliers' - - if log_constraint_fn is None: - def log_constraint_fn(**kwargs): - return 0. - self.log_constraint_fn = log_constraint_fn - else: - self.log_constraint_fn = log_constraint_fn - self.ntoys = ntoys - self.batch_size = batch_size + self.likelihood = likelihood self.test_statistic = test_statistic self.signal_source_names = signal_source_names self.background_source_names = background_source_names - self.sources = sources - self.arguments = arguments - self.expected_background_counts = expected_background_counts self.gaussian_constraint_widths = gaussian_constraint_widths self.sample_other_constraints = sample_other_constraints - self.rm_bounds = rm_bounds def run_routine(self, mus_test=None, save_fits=False, observed_data=None, @@ -212,7 +186,8 @@ def run_routine(self, mus_test=None, save_fits=False, generate_B_toys=False, simulate_dict_B=None, toy_data_B=None, constraint_extra_args_B=None, toy_batch=0, - discovery=False): + SB_toys=False, B_toys=False, discovery_TS=False, + sample_certain_nuisance=False): """If observed_data is passed, evaluate observed test statistics. Otherwise, obtain test statistic distributions (for both S+B and B-only). @@ -257,42 +232,34 @@ def run_routine(self, mus_test=None, save_fits=False, self.toy_batch = toy_batch observed_test_stats_collection = dict() + test_stat_dists_SB_collection = dict() + test_stat_dists_SB_disco_collection = dict() test_stat_dists_B_collection = dict() + test_stat_dists_B_disco_collection = dict() # Loop over signal sources for signal_source in self.signal_source_names: observed_test_stats = ObservedTestStatistics() + test_stat_dists_SB = TestStatisticDistributions() + test_stat_dists_SB_disco = TestStatisticDistributions() test_stat_dists_B = TestStatisticDistributions() - - sources = dict() - arguments = dict() - for background_source in self.background_source_names: - sources[background_source] = self.sources[background_source] - arguments[background_source] = self.arguments[background_source] - sources[signal_source] = self.sources[signal_source] - arguments[signal_source] = self.arguments[signal_source] - - # Create likelihood of TemplateSources - likelihood = fd.LogLikelihood(sources=sources, - arguments=arguments, - progress=False, - batch_size=self.batch_size, - free_rates=tuple([sname for sname in sources.keys()])) - - rm_bounds = dict() - if signal_source in self.rm_bounds.keys(): - rm_bounds[signal_source] = self.rm_bounds[signal_source] - for background_source in self.background_source_names: - if background_source in self.rm_bounds.keys(): - rm_bounds[background_source] = self.rm_bounds[background_source] - - # Pass rate multiplier bounds to likelihood - likelihood.set_rate_multiplier_bounds(**rm_bounds) - - # Pass constraint function to likelihood - likelihood.set_log_constraint(self.log_constraint_fn) + test_stat_dists_B_disco = TestStatisticDistributions() + + # Get likelihood + likelihood = deepcopy(self.likelihood) + + assert hasattr(likelihood, 'likelihoods'), 'Logic only currently works for combined likelihood' + for ll in likelihood.likelihoods.values(): + sources_remove = [] + params_remove = [] + for sname in ll.sources: + if (sname != signal_source) and (sname not in self.background_source_names): + sources_remove.append(sname) + params_remove.append(f'{sname}_rate_multiplier') + likelihood.rebuild(sources_remove=sources_remove, + params_remove=params_remove) # Where we want to generate B-only toys if generate_B_toys: @@ -300,7 +267,8 @@ def run_routine(self, mus_test=None, save_fits=False, constraint_extra_args_B_all = [] for i in tqdm(range(self.ntoys), desc='Background-only toys'): simulate_dict_B, toy_data_B, constraint_extra_args_B = \ - self.sample_data_constraints(0., signal_source, likelihood) + self.sample_data_constraints(0., signal_source, likelihood, + sample_certain_nuisance=sample_certain_nuisance) toy_data_B_all.append(toy_data_B) constraint_extra_args_B_all.append(constraint_extra_args_B) simulate_dict_B.pop(f'{signal_source}_rate_multiplier') @@ -315,28 +283,37 @@ def run_routine(self, mus_test=None, save_fits=False, mu_test, signal_source, likelihood, save_fits=save_fits) # Case where we want test statistic distributions else: - self.toy_test_statistic_dist(test_stat_dists_SB, test_stat_dists_B, + self.toy_test_statistic_dist(test_stat_dists_SB, test_stat_dists_SB_disco, + test_stat_dists_B, test_stat_dists_B_disco, mu_test, signal_source, likelihood, - save_fits=save_fits, discovery=discovery) + save_fits=save_fits, + SB_toys=SB_toys, B_toys=B_toys, discovery_TS=discovery_TS, + sample_certain_nuisance=sample_certain_nuisance) if observed_data is not None: observed_test_stats_collection[signal_source] = observed_test_stats else: test_stat_dists_SB_collection[signal_source] = test_stat_dists_SB + test_stat_dists_SB_disco_collection[signal_source] = test_stat_dists_SB_disco test_stat_dists_B_collection[signal_source] = test_stat_dists_B + test_stat_dists_B_disco_collection[signal_source] = test_stat_dists_B_disco if observed_data is not None: return observed_test_stats_collection else: - return test_stat_dists_SB_collection, test_stat_dists_B_collection + return test_stat_dists_SB_collection, test_stat_dists_SB_disco_collection, \ + test_stat_dists_B_collection, test_stat_dists_B_disco_collection - def sample_data_constraints(self, mu_test, signal_source_name, likelihood): + def sample_data_constraints(self, mu_test, signal_source_name, likelihood, + sample_certain_nuisance=False): """Internal function to sample the toy data and constraint central values following a frequentist procedure. Method taken depends on whether conditional best fits were passed. """ simulate_dict = dict() constraint_extra_args = dict() + + # For rate multipliers for background_source in self.background_source_names: # Case where we use the conditional best fits as constraint centers and simulated values if self.observed_test_stats is not None: @@ -346,7 +323,7 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood): conditional_bfs_observed[mu_test][f'{background_source}_rate_multiplier'] except Exception: raise RuntimeError("Could not find observed conditional best fits") - # Case where we use the prior expected counts as constraint centers and simualted values + # Case where we use the prior expected counts as constraint centers and simulated values else: expected_background_counts = self.expected_background_counts[background_source] @@ -360,97 +337,199 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood): draw = self.sample_other_constraints[background_source](expected_background_counts) constraint_extra_args[f'{background_source}_expected_counts'] = tf.cast(draw, fd.float_type()) - simulate_dict[f'{background_source}_rate_multiplier'] = expected_background_counts - simulate_dict[f'{signal_source_name}_rate_multiplier'] = mu_test + simulate_dict[f'{background_source}_rate_multiplier'] = tf.cast(expected_background_counts, fd.float_type()) + simulate_dict[f'{signal_source_name}_rate_multiplier'] = tf.cast(mu_test, fd.float_type()) + + # For all other parameters + for param_name in likelihood.param_names: + if '_rate_multiplier' in param_name: + continue + + if param_name not in self.sample_other_constraints.keys(): + continue + + # Case where we use the conditional best fits as constraint centers and simulated values + if self.observed_test_stats is not None: + try: + conditional_bfs_observed = self.observed_test_stats[signal_source_name].conditional_best_fits + param_val_expected = conditional_bfs_observed[mu_test][param_name] + except Exception: + raise RuntimeError("Could not find observed conditional best fits") + # Case where we use the prior expected counts as constraint centers and simulated values + else: + param_val_expected = likelihood.param_defaults[param_name] + + # Sample constraint centers + draw = self.sample_other_constraints[param_name](param_val_expected) + constraint_extra_args[f'{param_name}_expected'] = tf.cast(draw, fd.float_type()) + + if self.observed_test_stats is not None: + conditional_bfs_observed = self.observed_test_stats[signal_source_name].conditional_best_fits[mu_test] + non_rate_params_added = [] + for pname, fitval in conditional_bfs_observed.items(): + if (pname not in simulate_dict) and (pname in likelihood.param_defaults): + simulate_dict[pname] = fitval + non_rate_params_added.append(pname) + + if sample_certain_nuisance: + if 'combined_rate_scaling_expected' in constraint_extra_args: + simulate_dict['combined_rate_scaling'] = constraint_extra_args['combined_rate_scaling_expected'] + constraint_extra_args['combined_rate_scaling_expected'] = 0. toy_data = likelihood.simulate(**simulate_dict) + if self.observed_test_stats is not None: + for pname in non_rate_params_added: + simulate_dict.pop(pname) + return simulate_dict, toy_data, constraint_extra_args - def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B, + def toy_test_statistic_dist(self, + test_stat_dists_SB, test_stat_dists_SB_disco, + test_stat_dists_B, test_stat_dists_B_disco, mu_test, signal_source_name, likelihood, - save_fits=False, discovery=False): - """Internal function to get test statistic distribution. + save_fits=False, + SB_toys=False, B_toys=False, discovery_TS=False, + sample_certain_nuisance=False): + """ + Internal function to get test statistic distribution given a signal and POI value. | + test_stat_dists_SB: TestStatisticDistributions, t(mu_test|mu=mu_test) * | + test_stat_dists_SB_disco: TestStatisticDistributions, t(0.|mu=mu_test) * | + test_stat_dists_B: TestStatisticDistributions, t(mu_test|mu=0.) * | + test_stat_dists_B_disco: TestStatisticDistributions, t(0.|mu=0.) * | + mu_test: float, POI test value (usually signal counts). | + signal_source_name: string, the source that takes the POI. | + likelihood: LogLikelihood,the likelihood object. | + save_fits: bool, whether or not to save cond/uncond fits, stored in + TestStatisticDistributions. | + SB_toys: bool, whether or not to simulate S+B toys. | + B_toys: bool, whether or not to simulate B toys. | + discovery_TS: bool, wether to **only** evaluate test_stat_dists_SB_disco + and not test_stat_dists_SB. | + return: None, updates flamedisx TestStatisticDistributions objects in first + inputs (*). | """ ts_values_SB = [] + ts_values_SB_disco = [] ts_values_B = [] + ts_values_B_disco = [] + if save_fits: unconditional_bfs_SB = [] conditional_bfs_SB = [] + + unconditional_bfs_SB_disco = [] + conditional_bfs_SB_disco = [] + unconditional_bfs_B = [] conditional_bfs_B = [] + unconditional_bfs_B_disco = [] + conditional_bfs_B_disco = [] + # Loop over toys for toy in tqdm(range(self.ntoys), desc='Doing toys'): - simulate_dict_SB, toy_data_SB, constraint_extra_args_SB = \ - self.sample_data_constraints(mu_test, signal_source_name, likelihood) - # S+B toys - - # Shift the constraint in the likelihood based on the background RMs we drew - likelihood.set_constraint_extra_args(**constraint_extra_args_SB) - # Set data - likelihood.set_data(toy_data_SB) - # Create test statistic - test_statistic_SB = self.test_statistic(likelihood) - # Guesses for fit - guess_dict_SB = simulate_dict_SB.copy() - for key, value in guess_dict_SB.items(): - if value < 0.1: - guess_dict_SB[key] = 0.1 - # Evaluate test statistic - if discovery: - ts_result_SB = test_statistic_SB(0., signal_source_name, guess_dict_SB) - else: - ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB) - # Save test statistic, and possibly fits - ts_values_SB.append(ts_result_SB[0]) - if save_fits: - unconditional_bfs_SB.append(ts_result_SB[1]) - conditional_bfs_SB.append(ts_result_SB[2]) - - # B-only toys - - try: + if SB_toys: + simulate_dict_SB, toy_data_SB, constraint_extra_args_SB = \ + self.sample_data_constraints(mu_test, signal_source_name, likelihood, + sample_certain_nuisance) + + # Shift the constraint in the likelihood based on the background RMs we drew + likelihood.set_constraint_extra_args(**constraint_extra_args_SB) + # Set data + if hasattr(likelihood, 'likelihoods'): + for component, data in toy_data_SB.items(): + likelihood.set_data(data, component) + else: + likelihood.set_data(toy_data_SB) + # Create test statistic + test_statistic_SB = self.test_statistic(likelihood) # Guesses for fit - guess_dict_B = self.simulate_dict_B.copy() - guess_dict_B[f'{signal_source_name}_rate_multiplier'] = 0. - for key, value in guess_dict_B.items(): + guess_dict_SB = simulate_dict_SB.copy() + for key, value in guess_dict_SB.items(): if value < 0.1: - guess_dict_B[key] = 0.1 - toy_data_B = self.toy_data_B[toy+(self.toy_batch*self.ntoys)] - constraint_extra_args_B = self.constraint_extra_args_B[toy] - except Exception: - raise RuntimeError("Could not find background-only datasets") - - # Shift the constraint in the likelihood based on the background RMs we drew - likelihood.set_constraint_extra_args(**constraint_extra_args_B) - # Set data - likelihood.set_data(toy_data_B) - # Create test statistic - test_statistic_B = self.test_statistic(likelihood) - # Evaluate test statistic - if discovery: - ts_result_B = test_statistic_B(0., signal_source_name, guess_dict_B) - else: - ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B) - # Save test statistic, and possibly fits - ts_values_B.append(ts_result_B[0]) - if save_fits: - unconditional_bfs_B.append(ts_result_SB[1]) - conditional_bfs_B.append(ts_result_SB[2]) + guess_dict_SB[key] = 0.1 + # Evaluate and save test statistics + if discovery_TS: + # If we're doing significance, lots of toys + ts_result_SB_disco = test_statistic_SB(0., signal_source_name, guess_dict_SB) + ts_values_SB_disco.append(ts_result_SB_disco[0]) + else: + # If we're doing significance limits, can afford the extra eval. + ts_result_SB_disco = test_statistic_SB(0., signal_source_name, guess_dict_SB) + ts_values_SB_disco.append(ts_result_SB_disco[0]) + ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB) + ts_values_SB.append(ts_result_SB[0]) + # Possibly save fits + if save_fits: + if discovery_TS: + unconditional_bfs_SB_disco.append(ts_result_SB_disco[1]) + conditional_bfs_SB_disco.append(ts_result_SB_disco[2]) + else: + unconditional_bfs_SB.append(ts_result_SB[1]) + conditional_bfs_SB.append(ts_result_SB[2]) + + # B-only toys + if B_toys: + try: + # Guesses for fit + guess_dict_B = self.simulate_dict_B.copy() + guess_dict_B[f'{signal_source_name}_rate_multiplier'] = 0. + for key, value in guess_dict_B.items(): + if value < 0.1: + guess_dict_B[key] = 0.1 + toy_data_B = self.toy_data_B[toy+(self.toy_batch*self.ntoys)] + constraint_extra_args_B = self.constraint_extra_args_B[toy+(self.toy_batch*self.ntoys)] + except Exception: + raise RuntimeError("Could not find background-only datasets") + + # Shift the constraint in the likelihood based on the background RMs we drew + likelihood.set_constraint_extra_args(**constraint_extra_args_B) + # Set data + if hasattr(likelihood, 'likelihoods'): + for component, data in toy_data_B.items(): + likelihood.set_data(data, component) + else: + likelihood.set_data(toy_data_B) + # Create test statistic + test_statistic_B = self.test_statistic(likelihood) + # Evaluate and save test statistics + if discovery_TS: + ts_result_B_disco = test_statistic_B(0., signal_source_name, guess_dict_B) + ts_values_B_disco.append(ts_result_B_disco[0]) + else: + ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B) + ts_values_B.append(ts_result_B[0]) + # Possibly save fits + if save_fits: + if discovery_TS: + unconditional_bfs_B_disco.append(ts_result_B_disco[1]) + conditional_bfs_B_disco.append(ts_result_B_disco[2]) + else: + unconditional_bfs_B.append(ts_result_B[1]) + conditional_bfs_B.append(ts_result_B[2]) # Add to the test statistic distributions test_stat_dists_SB.add_ts_dist(mu_test, ts_values_SB) + test_stat_dists_SB_disco.add_ts_dist(mu_test, ts_values_SB_disco) test_stat_dists_B.add_ts_dist(mu_test, ts_values_B) + test_stat_dists_B_disco.add_ts_dist(mu_test, ts_values_B_disco) # Possibly save the fits if save_fits: test_stat_dists_SB.add_unconditional_best_fit(mu_test, unconditional_bfs_SB) test_stat_dists_SB.add_conditional_best_fit(mu_test, conditional_bfs_SB) + + test_stat_dists_SB_disco.add_unconditional_best_fit(mu_test, unconditional_bfs_SB_disco) + test_stat_dists_SB_disco.add_conditional_best_fit(mu_test, conditional_bfs_SB_disco) + test_stat_dists_B.add_unconditional_best_fit(mu_test, unconditional_bfs_B) test_stat_dists_B.add_conditional_best_fit(mu_test, conditional_bfs_B) + test_stat_dists_B_disco.add_unconditional_best_fit(mu_test, unconditional_bfs_B_disco) + test_stat_dists_B_disco.add_conditional_best_fit(mu_test, conditional_bfs_B_disco) + def get_observed_test_stat(self, observed_test_stats, observed_data, mu_test, signal_source_name, likelihood, save_fits=False): """Internal function to evaluate observed test statistic. @@ -464,16 +543,21 @@ def get_observed_test_stat(self, observed_test_stats, observed_data, likelihood.set_constraint_extra_args(**constraint_extra_args) # Set data - likelihood.set_data(observed_data) + if hasattr(likelihood, 'likelihoods'): + for component, data in observed_data.items(): + likelihood.set_data(data, component) + else: + likelihood.set_data(observed_data) + # Create test statistic test_statistic = self.test_statistic(likelihood) # Guesses for fit - guess_dict = {f'{signal_source_name}_rate_multiplier': mu_test} + guess_dict = {f'{signal_source_name}_rate_multiplier': tf.cast(0.1, fd.float_type())} for background_source in self.background_source_names: - guess_dict[f'{background_source}_rate_multiplier'] = self.expected_background_counts[background_source] + guess_dict[f'{background_source}_rate_multiplier'] = tf.cast(self.expected_background_counts[background_source], fd.float_type()) for key, value in guess_dict.items(): if value < 0.1: - guess_dict[key] = 0.1 + guess_dict[key] = tf.cast(0.1, fd.float_type()) # Evaluate test statistic ts_result = test_statistic(mu_test, signal_source_name, guess_dict) @@ -509,12 +593,14 @@ def __init__( signal_source_names: ty.Tuple[str], observed_test_stats: ObservedTestStatistics, test_stat_dists_SB: TestStatisticDistributions, - test_stat_dists_B: TestStatisticDistributions): + test_stat_dists_B: TestStatisticDistributions, + test_stat_dists_SB_disco: TestStatisticDistributions=None): self.signal_source_names = signal_source_names self.observed_test_stats = observed_test_stats self.test_stat_dists_SB = test_stat_dists_SB self.test_stat_dists_B = test_stat_dists_B + self.test_stat_dists_SB_disco = test_stat_dists_SB_disco @staticmethod def interp_helper(x, y, crossing_points, crit_val, @@ -585,6 +671,7 @@ def get_interval(self, conf_level=0.1, pcl_level=0.16, lower_lim_all = dict() upper_lim_all = dict() + upper_lim_all_raw = dict() # Loop over signal sources for signal_source in self.signal_source_names: these_p_sb = p_sb[signal_source] @@ -616,6 +703,7 @@ def get_interval(self, conf_level=0.1, pcl_level=0.16, # Take the highest decreasing crossing point, and interpolate to get an upper limit upper_lim = self.interp_helper(mus, p_vals, upper_lims, conf_level, rising_edge=False, inverse=True) + upper_lim_raw = upper_lim if use_CLs is False: M0 = self.interp_helper(mus, pws, upper_lims, upper_lim, @@ -629,22 +717,32 @@ def get_interval(self, conf_level=0.1, pcl_level=0.16, lower_lim_all[signal_source] = lower_lim upper_lim_all[signal_source] = upper_lim + upper_lim_all_raw[signal_source] = upper_lim_raw if use_CLs is False: - return lower_lim_all, upper_lim_all, p_sb, powers + return lower_lim_all, upper_lim_all, upper_lim_all_raw, p_sb, powers else: return lower_lim_all, upper_lim_all, p_sb, p_b def upper_lims_bands(self, pval_curve, mus, conf_level): - upper_lims = np.argwhere(np.diff(np.sign(pval_curve - np.ones_like(pval_curve) * conf_level)) < 0.).flatten() - return self.interp_helper(mus, pval_curve, upper_lims, conf_level, - rising_edge=False, inverse=True) + try: + upper_lims = np.argwhere(np.diff(np.sign(pval_curve - np.ones_like(pval_curve) * conf_level)) < 0.).flatten() + return self.interp_helper(mus, pval_curve, upper_lims, conf_level, + rising_edge=False, inverse=True) + except Exception: + return 0. + + def critical_disco_value(self, disco_pot_curve, mus, discovery_sigma): + crossing_point = np.argwhere(np.diff(np.sign(disco_pot_curve - np.ones_like(disco_pot_curve) * discovery_sigma)) > 0.).flatten() + return self.interp_helper(mus, disco_pot_curve, crossing_point, discovery_sigma, + rising_edge=True, inverse=True) def get_bands(self, conf_level=0.1, quantiles=[0, 1, -1, 2, -2], - use_CLs=False): + use_CLs=False, return_toy_indices=False): """ """ bands = dict() + toy_indices = dict() # Loop over signal sources for signal_source in self.signal_source_names: @@ -670,35 +768,72 @@ def get_bands(self, conf_level=0.1, quantiles=[0, 1, -1, 2, -2], p_val_curves = np.transpose(np.stack(p_val_curves, axis=0)) upper_lims_bands = np.apply_along_axis(self.upper_lims_bands, 1, p_val_curves, mus, conf_level) + upper_lims_bands_all = upper_lims_bands + if len(upper_lims_bands[upper_lims_bands == 0.]) > 0.: + print(f'Found {len(upper_lims_bands[upper_lims_bands == 0.])} failed toy for {signal_source}; removing...') + upper_lims_bands = upper_lims_bands[upper_lims_bands > 0.] + these_bands = dict() + these_toy_indices = dict() for quantile in quantiles: these_bands[quantile] = np.quantile(np.sort(upper_lims_bands), stats.norm.cdf(quantile)) + + nearest_index = np.argmin(np.abs(upper_lims_bands_all - these_bands[quantile])) + these_toy_indices[quantile] = nearest_index + bands[signal_source] = these_bands + toy_indices[signal_source] = these_toy_indices + + if return_toy_indices: + return bands, toy_indices + else: + return bands + + def get_disco_sig(self): + """ + """ + disco_sigs = dict() + + # Loop over signal sources + for signal_source in self.signal_source_names: + # Get observed (mu = 0) test statistic and B (m = 0) test statistic distribition + try: + observed_test_stat = self.observed_test_stats[signal_source].test_stats[0.] + test_stat_dist_B = self.test_stat_dists_B[signal_source].ts_dists[0.] + except Exception: + raise RuntimeError("Error: did you scan over mu = 0?") + + p_val = (100. - stats.percentileofscore(test_stat_dist_B, + observed_test_stat, + kind='weak')) / 100. + disco_sig = stats.norm.ppf(1. - p_val) + disco_sig = np.where(disco_sig > 0., disco_sig, 0.) + disco_sigs[signal_source] = disco_sig - return bands + return disco_sigs - def get_bands_discovery(self, quantiles=[0, 1, -1]): + def get_median_disco_asymptotic(self, sigma_level=3): """ """ - bands = dict() + medians = dict() # Loop over signal sources for signal_source in self.signal_source_names: # Get test statistic distribitions - test_stat_dists_SB = self.test_stat_dists_SB[signal_source] - test_stat_dists_B = self.test_stat_dists_B[signal_source] + test_stat_dists_SB_disco = self.test_stat_dists_SB_disco[signal_source] - assert len(test_stat_dists_SB.ts_dists.keys()) == 1, 'Currently only support a single signal strength' + mus = [] + disco_sig_curves = [] + # Loop over signal rate multipliers + for mu_test, ts_values in test_stat_dists_SB_disco.ts_dists.items(): + these_disco_sigs = np.sqrt(ts_values) - these_p_vals = (100. - stats.percentileofscore(list(test_stat_dists_B.ts_dists.values())[0], - list(test_stat_dists_SB.ts_dists.values())[0], - kind='weak')) / 100. - these_p_vals = these_p_vals[these_p_vals > 0.] - these_disco_sigs = stats.norm.ppf(1. - these_p_vals) + mus.append(mu_test) + disco_sig_curves.append(these_disco_sigs) - these_bands = dict() - for quantile in quantiles: - these_bands[quantile] = np.quantile(np.sort(these_disco_sigs), stats.norm.cdf(quantile)) - bands[signal_source] = these_bands + disco_sig_curves = np.stack(disco_sig_curves, axis=0) + median_disco_sigs = [np.median(disco_sigs) for disco_sigs in disco_sig_curves] + median_crossing_point = self.critical_disco_value(median_disco_sigs, mus, sigma_level) + medians[signal_source] = median_crossing_point - return bands + return medians diff --git a/flamedisx/templates.py b/flamedisx/templates.py index 610b9be59..6aa15b324 100644 --- a/flamedisx/templates.py +++ b/flamedisx/templates.py @@ -9,8 +9,12 @@ import tensorflow as tf import tensorflow_probability as tfp +from flamedisx.tfbspline import bspline + import flamedisx as fd +from copy import deepcopy + export, __all__ = fd.exporter() @@ -98,7 +102,9 @@ def differential_rates_numpy(self, data): if self._interpolator: # transpose since RegularGridInterpolator expects (n_points, n_dims) - return self._interpolator(data.T) + interp_diff_rates = self._interpolator(data.T) + lookup_diff_rates = self._mh_diff_rate.lookup(*data) + return np.where(interp_diff_rates <= 0., lookup_diff_rates, interp_diff_rates) else: return self._mh_diff_rate.lookup(*data) @@ -197,10 +203,12 @@ class MultiTemplateSource(fd.Source): def __init__( self, params_and_templates: ty.Tuple[ty.Dict[str, float], ty.Any], + params_and_normalisations:ty.Tuple[ty.Dict[str, float], float], bin_edges=None, axis_names=None, events_per_bin=False, interpolate=False, + _skip_tf_init=False, *args, **kwargs): @@ -208,6 +216,11 @@ def __init__( TemplateWrapper( template, bin_edges, axis_names, events_per_bin, interpolate) for _, template in params_and_templates] + assert len(params_and_templates[0][0]) == 1, "This implementation currently only supports moprhing of 1 parameter" + self.param_name = list(params_and_templates[0][0].keys())[0] + + # We will include mu variation separately + self.mu = self._templates[0].mu # Grab parameter names. Promote first set of values to defaults. self.n_templates = n_templates = len(self._templates) @@ -224,9 +237,9 @@ def __init__( # # When evaluated at the exact location of a template, the result has 1 # in the corresponding template's position, and zeros elsewhere. - _template_weights = scipy.interpolate.LinearNDInterpolator( - points=np.asarray([list(params.values()) for params, _ in params_and_templates]), - values=np.eye(n_templates)) + _template_weights = scipy.interpolate.interp1d( + x=np.asarray([list(params.values())[0] for params, _ in params_and_templates]), + y=np.eye(n_templates)) # Unfortunately TensorFlow has no equivalent of LinearNDInterpolator, # only interpolators that work on rectilinear grids. Thus, instead of @@ -240,31 +253,12 @@ def __init__( for params, _ in params_and_templates))) for param in defaults]) _full_grid_coordinates = np.meshgrid(*_grid_coordinates, indexing='ij') - n_grid_points = np.prod([len(x) for x in _grid_coordinates]) # Evaluate our irregular-grid scipy-interpolator on the grid. # This gives an array of shape (n_templates, ngrid_dim0, ngrid_dim1, ...) # for use in tensorflow interpolation. _grid_weights = _template_weights(*_full_grid_coordinates) - # The expected number of events must also be interpolated. - # For consistency, it must be done in the same way (first interpolate - # to a regular grid, then linearly from there). - # (n_templates,) array - _template_mus = np.asarray([ - template.mu for template in self._templates]) - self._grid_mus = np.average( - # numpy won't let us get away with a size-1 axis here, we have to - # actually repeat the values. (If we had jax we could just vmap...) - np.repeat(_template_mus[:, None], n_grid_points, axis=1), - axis=0, - weights=_grid_weights.reshape(n_templates, n_grid_points)) - assert self._grid_mus.shape == (n_templates,) - self._mu_interpolator = scipy.interpolate.RegularGridInterpolator( - points=_grid_coordinates, - values=self._grid_mus.reshape(_full_grid_coordinates[0].shape), - method='linear') - # Generate a random column name to use to store the diff rates # of observed events under every template self.column = ( @@ -274,21 +268,37 @@ def __init__( # ... this column will hold an array, with one entry per template self.array_columns = ((self.column, n_templates),) - # This source has parameters but no model functions, so we can't do the - # usual Source.scan_model_functions. - self.f_dims = dict() - self.f_params = dict() - self.defaults = defaults - # This is needed in tensorflow, so convert it now self._grid_coordinates = tuple([fd.np_to_tf(np.asarray(g)) for g in _grid_coordinates]) self._grid_weights = fd.np_to_tf(_grid_weights) + param_vals = np.asarray([list(params.values())[0] for params, _ in params_and_templates]) + self.pmin = tf.constant(min(param_vals), fd.float_type()) + self.pmax = tf.constant(max(param_vals), fd.float_type()) + pvals = tf.convert_to_tensor(param_vals, fd.float_type()) + + normalisations = np.array([norm for _, norm in params_and_normalisations]) + self.normalisations = tf.convert_to_tensor(normalisations / normalisations[0], + fd.float_type()) + + # Assume equi-spacing! + self.dstep = pvals[1] - pvals[0] + # Need to pad domain.. four might be excessive + try: + self.pvals = list(np.arange(pvals[0] - 4. * self.dstep, pvals[-1] + 4. * self.dstep, self.dstep)) + assert len(self.pvals) == len(pvals) + 8, "Something went wrong with the padding!" + except: + self.pvals = list(np.arange(pvals[0] - 4. * self.dstep, pvals[-1] + 5. * self.dstep, self.dstep)) + assert len(self.pvals) == len(pvals) + 8, "Something went wrong with the padding!" + + self.array_columns = ((self.column, n_templates+8),) + super().__init__(*args, **kwargs) - def scan_model_functions(self): - # Don't do anything here, already set defaults etc. in __init__ above - pass + self.defaults = {**self.defaults,**{k: tf.cast(v, fd.float_type()) for k, v in defaults.items()}} + self.parameter_index = fd.index_lookup_dict(self.defaults.keys()) + if not _skip_tf_init: + self.trace_differential_rate() def extra_needed_columns(self): return super().extra_needed_columns() + [self.column] @@ -296,45 +306,118 @@ def extra_needed_columns(self): def _annotate(self): """Add columns needed in inference to self.data """ - # Get array of differential rates for each template. - # Outer list() is to placate pandas, which does not like array columns.. + #construct tensor of knots + #requires a tensor of elements + #data is stored as [[d_evt1^h1,d_evt1^h2..],[d_evt2^h1,d_evt2^h2..]] + # so just need to construct and x-values object and let data column handle y-values + #with some padding for the domain! + Nk=len(self.pvals) + knot_range=self.pvals[-1]-self.pvals[0] + linear_shift=2*self.dstep/knot_range + start=min(self.pvals) + end=max(self.pvals) + self.original_range=tf.constant(end-start,dtype=fd.float_type()) + self.max_pos=tf.constant(Nk- 2,dtype=fd.float_type()) + + self.start=tf.constant(start,dtype=fd.float_type()) + self.linear_shift=tf.constant(linear_shift,dtype=fd.float_type()) + self.linear_shift_shift=tf.constant(knot_range/2,dtype=fd.float_type()) + self.data[self.column] = list(np.asarray([ template.differential_rates_numpy(self.data) for template in self._templates]).T) + linear_interp_padded_diff_rates=[] + for diff_rate_per_hist in self.data[self.column]: + + if np.sum(diff_rate_per_hist[:2])>0: + left_edge=scipy.interpolate.interp1d( + self.pvals[4:6],diff_rate_per_hist[:2], + kind='linear',fill_value="extrapolate", + bounds_error=False)(self.pvals[:4]) + else: + left_edge=list(np.repeat(diff_rate_per_hist[0],4)) + + if np.sum(diff_rate_per_hist[-2:])>0: + right_edge=scipy.interpolate.interp1d( + self.pvals[-6:-4],diff_rate_per_hist[-2:], + kind='linear',fill_value="extrapolate", + bounds_error=False)(self.pvals[-4:]) + else: + right_edge=list(np.repeat(diff_rate_per_hist[-1],4)) + + linear_interp_padded_diff_rates.append(np.concatenate([left_edge,diff_rate_per_hist,right_edge])) + + self.data[self.column]=linear_interp_padded_diff_rates + self.tensor_xvals=tf.convert_to_tensor([self.pvals for _ in range(self.batch_size)],dtype=fd.float_type()) + def mu_before_efficiencies(self, **params): - return self.estimate_mu(self, **params) + return self.mu + + def estimate_mu(self, n_trials=None, **params): + norm = tfp.math.batch_interp_regular_1d_grid( + x=params[self.param_name], + x_ref_min=self.pmin, + x_ref_max=self.pmax, + y_ref=self.normalisations, + ) + + return tf.reshape(norm, shape=[]) * self.mu + + def bspline_interpolate_per_bin(self, param,knots): + def interp(knots_for_event): + #second order non-cyclical b-spline with varying knots + #returns [x,y] so ignore x + #hackiest shit ever + return tf.reduce_sum(bspline.interpolate(knots_for_event, + self.max_pos*(param-self.start)/self.original_range+self.linear_shift*(param -self.linear_shift_shift), 2, False) \ + * tf.constant([0,1],dtype=fd.float_type())) + #vectorized map over all events + y=tf.vectorized_map(interp,elems=knots) + return y - def estimate_mu(self, **params): - """Estimate the number of events expected from the template source. + def _differential_rate(self, data_tensor, ptensor): + norm = tfp.math.batch_interp_regular_1d_grid( + x=self._fetch_param(self.param_name, ptensor), + x_ref_min=self.pmin, + x_ref_max=self.pmax, + y_ref=self.normalisations, + ) + + knots_per_event=tf.convert_to_tensor([self.tensor_xvals, self._fetch(self.column, data_tensor)],dtype=fd.float_type()) + bspline_diff_rates=self.bspline_interpolate_per_bin(self._fetch_param(self.param_name, ptensor), tf.transpose(knots_per_event,perm=[1,0,2])) + dr=tf.squeeze(norm)*bspline_diff_rates + + return dr + + def simulate(self, n_events, fix_truth=None, full_annotate=False, + keep_padding=False, **params): + """Simulate n events. """ - # TODO: maybe need .item or something here? - return self._mu_interpolator([ - params.get(param, default) - for param, default in self.defaults.items()]) + if fix_truth: + raise NotImplementedError("TemplateSource does not yet support fix_truth") + assert isinstance(n_events, (int, float)), \ + f"n_events must be an int or float, not {type(n_events)}" + + # TODO: all other arguments are ignored, they make no sense + # for this source. Should we warn about this? Remove them from def? + + assert len(self.defaults) == 1 + + template_weights = tfp.math.batch_interp_regular_1d_grid( + x=params[next(iter(self.defaults))], + x_ref_min=self._grid_coordinates[0][0], + x_ref_max=self._grid_coordinates[0][-1], + y_ref=self._grid_weights, + ) - def _differential_rate(self, data_tensor, ptensor): - # Compute template weights at this parameter point - # (n_templates,) tensor - # (The axis order is weird here. It seems to work...) - permutation = ( - [self._grid_weights.ndim - 1] - + list(range(0, self._grid_weights.ndim - 1))) - template_weights = tfp.math.batch_interp_rectilinear_nd_grid( - x=ptensor[None, :], - x_grid_points=self._grid_coordinates, - y_ref=tf.transpose(self._grid_weights, permutation), - axis=1, - )[:, 0] - # Ensure template weights sum to one. template_weights /= tf.reduce_sum(template_weights) - # Fetch precomputed diff rates for each template. - # (n_events, n_templates) tensor - template_diffrates = self._fetch(self.column, data_tensor) + template_epb = [template._mh_events_per_bin for template in self._templates] + template_epb_combine = deepcopy(template_epb[0]) + template_epb_combine.histogram = np.sum([template.histogram * weight for template, weight in + zip(template_epb, template_weights)], axis=0) - # Compute weighted average of diff rates - # (n_events,) tensor - return tf.reduce_sum( - template_diffrates * template_weights[None, :], - axis=1) + return pd.DataFrame(dict(zip( + self._templates[0].axis_names, + template_epb_combine.get_random(n_events).T))) \ No newline at end of file diff --git a/flamedisx/tfbspline/__init__.py b/flamedisx/tfbspline/__init__.py new file mode 100644 index 000000000..e5c1a7763 --- /dev/null +++ b/flamedisx/tfbspline/__init__.py @@ -0,0 +1 @@ +from .bspline import * diff --git a/flamedisx/tfbspline/bspline.py b/flamedisx/tfbspline/bspline.py new file mode 100644 index 000000000..6ae91d0e7 --- /dev/null +++ b/flamedisx/tfbspline/bspline.py @@ -0,0 +1,237 @@ +import enum + +import tensorflow as tf + +class Degree(enum.IntEnum): + """Defines valid degrees for B-spline interpolation.""" + CONSTANT = 0 + LINEAR = 1 + QUADRATIC = 2 + CUBIC = 3 + QUARTIC = 4 + + +def _constant(position: tf.Tensor) -> tf.Tensor: + """B-Spline basis function of degree 0 for positions in the range [0, 1].""" + # A piecewise constant spline is discontinuous at the knots. + return tf.expand_dims(tf.clip_by_value(1.0 + position, 1.0, 1.0), axis=-1) + + +def _linear(position: tf.Tensor) -> tf.Tensor: + """B-Spline basis functions of degree 1 for positions in the range [0, 1].""" + # Piecewise linear splines are C0 smooth. + return tf.stack((1.0 - position, position), axis=-1) + + +def _quadratic(position: tf.Tensor) -> tf.Tensor: + """B-Spline basis functions of degree 2 for positions in the range [0, 1].""" + # We pre-calculate the terms that are used multiple times. + pos_sq = tf.pow(position, 2.0) + + # Piecewise quadratic splines are C1 smooth. + return tf.stack((tf.pow(1.0 - position, 2.0) / 2.0, -pos_sq + position + 0.5, + pos_sq / 2.0), + axis=-1) + + +def _cubic(position: tf.Tensor) -> tf.Tensor: + """B-Spline basis functions of degree 3 for positions in the range [0, 1].""" + # We pre-calculate the terms that are used multiple times. + neg_pos = 1.0 - position + pos_sq = tf.pow(position, 2.0) + pos_cb = tf.pow(position, 3.0) + + # Piecewise cubic splines are C2 smooth. + return tf.stack( + (tf.pow(neg_pos, 3.0) / 6.0, (3.0 * pos_cb - 6.0 * pos_sq + 4.0) / 6.0, + (-3.0 * pos_cb + 3.0 * pos_sq + 3.0 * position + 1.0) / 6.0, + pos_cb / 6.0), + axis=-1) + + +def _quartic(position: tf.Tensor) -> tf.Tensor: + """B-Spline basis functions of degree 4 for positions in the range [0, 1].""" + # We pre-calculate the terms that are used multiple times. + neg_pos = 1.0 - position + pos_sq = tf.pow(position, 2.0) + pos_cb = tf.pow(position, 3.0) + pos_qt = tf.pow(position, 4.0) + + # Piecewise quartic splines are C3 smooth. + return tf.stack( + (tf.pow(neg_pos, 4.0) / 24.0, + (-4.0 * tf.pow(neg_pos, 4.0) + 4.0 * tf.pow(neg_pos, 3.0) + + 6.0 * tf.pow(neg_pos, 2.0) + 4.0 * neg_pos + 1.0) / 24.0, + (pos_qt - 2.0 * pos_cb - pos_sq + 2.0 * position) / 4.0 + 11.0 / 24.0, + (-4.0 * pos_qt + 4.0 * pos_cb + 6.0 * pos_sq + 4.0 * position + 1.0) / + 24.0, pos_qt / 24.0), + axis=-1) + + +def knot_weights( + positions, + num_knots, + degree: int, + cyclical: bool, + sparse_mode: bool = False, + name: str = "bspline_knot_weights" +): + """Function that converts cardinal B-spline positions to knot weights. + + Note: + In the following, A1 to An are optional batch dimensions. + + Args: + positions: A tensor with shape `[A1, .. An]`. Positions must be between + `[0, C - D)` for non-cyclical and `[0, C)` for cyclical splines, where `C` + is the number of knots and `D` is the spline degree. + num_knots: A strictly positive `int` describing the number of knots in the + spline. + degree: An `int` describing the degree of the spline, which must be smaller + than `num_knots`. + cyclical: A `bool` describing whether the spline is cyclical. + sparse_mode: A `bool` describing whether to return a result only for the + knots with nonzero weights. If set to True, the function returns the + weights of only the `degree` + 1 knots that are non-zero, as well as the + indices of the knots. + name: A name for this op. Defaults to "bspline_knot_weights". + + Returns: + A tensor with dense weights for each control point, with the shape + `[A1, ... An, C]` if `sparse_mode` is False. + Otherwise, returns a tensor of shape `[A1, ... An, D + 1]` that contains the + non-zero weights, and a tensor with the indices of the knots, with the type + tf.int32. + + Raises: + ValueError: If degree is greater than 4 or num_knots - 1, or less than 0. + InvalidArgumentError: If positions are not in the right range. + """ + with tf.name_scope(name): + positions = tf.convert_to_tensor(value=positions) + + all_basis_functions = { + # Maps valid degrees to functions. + Degree.CONSTANT: _constant, + Degree.LINEAR: _linear, + Degree.QUADRATIC: _quadratic, + Degree.CUBIC: _cubic, + Degree.QUARTIC: _quartic + } + basis_functions = all_basis_functions[degree] + + if not cyclical and num_knots - degree == 1: + # In this case all weights are non-zero and we can just return them. + if not sparse_mode: + return basis_functions(positions) + else: + shift = tf.zeros_like(positions, dtype=tf.int32) + return basis_functions(positions), shift + + # shape_batch = positions.shape.as_list() + shape_batch = tf.shape(input=positions) + positions = tf.reshape(positions, shape=(-1,)) + + # Calculate the nonzero weights from the decimal parts of positions. + shift = tf.floor(positions) + sparse_weights = basis_functions(positions - shift) + shift = tf.cast(shift, tf.int32) + + if sparse_mode: + # Returns just the weights and the shift amounts, so that tf.gather_nd on + # the knots can be used to sparsely activate knots if needed. + shape_weights = tf.concat( + (shape_batch, tf.constant((degree + 1,), dtype=tf.int32)), axis=0) + sparse_weights = tf.reshape(sparse_weights, shape=shape_weights) + shift = tf.reshape(shift, shape=shape_batch) + return sparse_weights, shift + + num_positions = tf.size(input=positions) + ind_row, ind_col = tf.meshgrid( + tf.range(num_positions, dtype=tf.int32), + tf.range(degree + 1, dtype=tf.int32), + indexing="ij") + + tiled_shifts = tf.reshape( + tf.tile(tf.expand_dims(shift, axis=-1), multiples=(1, degree + 1)), + shape=(-1,)) + ind_col = tf.reshape(ind_col, shape=(-1,)) + tiled_shifts + if cyclical: + ind_col = tf.math.mod(ind_col, num_knots) + indices = tf.stack((tf.reshape(ind_row, shape=(-1,)), ind_col), axis=-1) + shape_indices = tf.concat((tf.reshape( + num_positions, shape=(1,)), tf.constant( + (degree + 1, 2), dtype=tf.int32)), + axis=0) + indices = tf.reshape(indices, shape=shape_indices) + shape_scatter = tf.concat((tf.reshape( + num_positions, shape=(1,)), tf.constant((num_knots,), dtype=tf.int32)), + axis=0) + weights = tf.scatter_nd(indices, sparse_weights, shape_scatter) + shape_weights = tf.concat( + (shape_batch, tf.constant((num_knots,), dtype=tf.int32)), axis=0) + return tf.reshape(weights, shape=shape_weights) + + +def interpolate_with_weights( + knots, + weights, + name: str = "bspline_interpolate_with_weights") -> tf.Tensor: + """Interpolates knots using knot weights. + + Note: + In the following, A1 to An, and B1 to Bk are optional batch dimensions. + + Args: + knots: A tensor with shape `[B1, ..., Bk, C]` containing knot values, where + `C` is the number of knots. + weights: A tensor with shape `[A1, ..., An, C]` containing dense weights for + the knots, where `C` is the number of knots. + name: A name for this op. Defaults to "bspline_interpolate_with_weights". + + Returns: + A tensor with shape `[A1, ..., An, B1, ..., Bk]`, which is the result of + spline interpolation. + + Raises: + ValueError: If the last dimension of knots and weights is not equal. + """ + with tf.name_scope(name): + knots = tf.convert_to_tensor(value=knots) + weights = tf.convert_to_tensor(value=weights) + + return tf.tensordot(weights, knots, (-1, -1)) + + +def interpolate(knots, + positions, + degree: int, + cyclical: bool, + name: str = "bspline_interpolate") -> tf.Tensor: + """Applies B-spline interpolation to input control points (knots). + + Note: + In the following, A1 to An, and B1 to Bk are optional batch dimensions. + + Args: + knots: A tensor with shape `[B1, ..., Bk, C]` containing knot values, where + `C` is the number of knots. + positions: Tensor with shape `[A1, .. An]`. Positions must be between + `[0, C - D)` for non-cyclical and `[0, C)` for cyclical splines, where `C` + is the number of knots and `D` is the spline degree. + degree: An `int` between 0 and 4, or an enumerated constant from the Degree + class, which is the degree of the splines. + cyclical: A `bool`, whether the splines are cyclical. + name: A name for this op. Defaults to "bspline_interpolate". + + Returns: + A tensor of shape `[A1, ... An, B1, ..., Bk]`, which is the result of spline + interpolation. + """ + with tf.name_scope(name): + knots = tf.convert_to_tensor(value=knots) + positions = tf.convert_to_tensor(value=positions) + + num_knots = knots.get_shape().as_list()[-1] + weights = knot_weights(positions, num_knots, degree, cyclical, False, name) + return interpolate_with_weights(knots, weights)