diff --git a/docs/fft.rst b/docs/fft.rst new file mode 100644 index 00000000..66380e1d --- /dev/null +++ b/docs/fft.rst @@ -0,0 +1,121 @@ + +.. module:: enterprise + :noindex: + +.. note:: This tutorial was generated from a Jupyter notebook that can be + downloaded `here <_static/notebooks/mdc.ipynb>`_. + +.. _mdc: + +Red noise modeling +======================= + +In the beginning of Enterprise red noise modeling, the red noise prior was +always modeled as a diagonal matrix, meaning that the Fourier coefficients +were assumed to be uncorrelated. This model was introduced by Lentati et al. +(2013), and explained by van Haasteren and Vallisneri (2014). In practice this +has been a good-enough approximation, but it is not exact. + +As of early 2025 we now have a more realistic implementation of red noise +priors that include correlations between the basis functions. The `FFT` +method as it is called is a rank-reduced time-domain implementation, meaning +it does not rely on Fourier modes, but on regularly sampled coarse grained +time samples. Here we briefly explain how to use it. + + + +Red noise modeling +------------------- + +The traditional old-style way of modeling was done like: + +.. code:: python + + rn_pl = powerlaw(log10_A=rn_log10_A, gamma=rn_gamma) + rn_phi = gp_signals.FourierBasisGP(spectrum=rn_pl, components=n_components, Tspan=Tspan) + +For the FFT time-domain model, one would do: + +.. code:: python + + rn_pl = powerlaw(log10_A=rn_log10_A, gamma=rn_gamma) + rn_fft = gp_signals.FFTBasisGP(spectrum=rn_pl, components=n_components, oversample=3, cutoff=3, Tspan=Tspan, start_time=start_time) + +The same spectral function can be used. Free spectrum is NOT supported yet. +Instead of `components`, we can also pass `knots=`, where it is understood that +`knots=2*n_components+1`. This is because `components` actually means +frequencies. In the time-domain, the number of `knots` is the number of +`modes+1`, because we cannot just omit the DC term. + +The `oversample` parameter determines how densely the PSD is sampled in +frequencies. With `oversample=1` we would use frequencies at spacing of +`df=1/T`. With `oversample=3` (the default), the frequency spacing is +`df=1/(3T)`. Note that this is a way to numerically approximate the +Wiener-Khinchin integral. With oversample sufficiently large, the FFT is an +excellent approximation of the analytical integral. For powerlaw signals, +`oversample=3` seems a very reasonable choice. + +The `cutoff` parameter is used to specify below which frequency `fcut = 1 / +(cutoff*Tspan)` we set the PSD equal to zero. Note that this parameterization +(which is also in Discovery) is a bit ambiguous, as fcut may not correspond to +an actual bin of the FFT: especially if oversample is not a high number this +can cause a mismatch. In case of a mismatch, `fcut` is rounded up to the next +oversampled-FFT frequency bin. Instead of `cutoff`, the parameter `cutbins` can +also be used (this overrides cutoff). With cutbins the low-frequency cutoff is +set at: `fcut = cutbins / (oversample * Tspan)`, and its interpretation is less +ambiguous: it is the number of bins of the over-sampled PSD of the FFT that is +being zeroed out. + +Common signals +-------------- + +For common signals, instead of: + +.. code:: python + + gw_pl = powerlaw(log10_A=gw_log10_A, gamma=gw_gamma) + orf = utils.hd_orf() + crn_phi = gp_signals.FourierBasisCommonGP(gw_pl, orf, components=20, name='gw', Tspan=Tspan) + + +one would do: + +.. code:: python + + gw_pl = powerlaw(log10_A=gw_log10_A, gamma=gw_gamma) + orf = utils.hd_orf() + crn_fft = gp_signals.FFTBasisCommonGP(gw_pl, orf, components=20, name='gw', Tspan=Tspan, start_time=start_time) + +Chromatic signals +----------------- + +DM-variations and Chromatic noise can be similarly set up: + +.. code:: python + + nknots = 81 + dm_basis = utils.create_fft_time_basis_dm(nknots=nknots) + dm_pl = powerlaw(log10_A=dm_log10_A, gamma=dm_gamma) + dm_fft = gp_signals.FFTBasisGP(dm_pl, basis=dm_basis, nknots=nknots, name='dmgp') + + chrom_basis = utils.create_fft_time_basis_chromatic(nknots=nknots, idx=chrom_idx) + chrom_pl = powerlaw(log10_A=chrom_log10_A, gamma=chrom_gamma) + chrom_fft = gp_signals.FFTBasisGP(chrom_pl, basis=chrom_basis, nknots=nknots, name='chromgp') + +Subtleties +---------- + +Enterprise allows one to combine basis functions when they are the same. This +is especially useful when analyzing common signals which have the same basis as +a single-pulsar signal, such as one would have with red noise and a correlated +GWB. This can be done with the `combine=True` option in `FFTBasisGP` and +`FFTBasisCommonGP`. Default is `combine=True`. The subtlety is that modern PTA +datasets typically have large gaps, which causes some of the time-domain basis +functions to basically be all zeros. Therefore, some basis functions that you +would not expect to be identical will be combined. + +The above is not a bug. Combining such bases and the corresponding Phi matrix +does not matter, because the basis is zero, and that part of the signal has no +bearing on the data or the model. However, when doing signal reconstruction, +such as with `la_forge` or `utils.ConditionalGP`, make sure to set +`combine=False`. diff --git a/docs/index.rst b/docs/index.rst index f7d95823..7ff78a17 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -42,6 +42,7 @@ searches, and timing model analysis. mdc nano9 + fft .. toctree:: :maxdepth: 2 diff --git a/enterprise/signals/gp_bases.py b/enterprise/signals/gp_bases.py index f924be12..44840ebe 100644 --- a/enterprise/signals/gp_bases.py +++ b/enterprise/signals/gp_bases.py @@ -5,6 +5,7 @@ import numpy as np from enterprise.signals.parameter import function +import scipy.interpolate as sint ###################################### # Fourier-basis signal functions ##### @@ -12,12 +13,15 @@ __all__ = [ "createfourierdesignmatrix_red", + "create_fft_time_basis", "createfourierdesignmatrix_dm", + "create_fft_time_basis_dm", "createfourierdesignmatrix_dm_tn", "createfourierdesignmatrix_env", "createfourierdesignmatrix_ephem", "createfourierdesignmatrix_eph", "createfourierdesignmatrix_chromatic", + "create_fft_time_basis_chromatic", "createfourierdesignmatrix_general", ] @@ -91,6 +95,38 @@ def createfourierdesignmatrix_red( return F, Ffreqs +@function +def create_fft_time_basis(toas, nknots=30, Tspan=None, start_time=None, order=1): + """ + Construct coarse time-domain design matrix from eq 11 of Chrisostomi et al., 2025 + :param toas: vector of time series in seconds + :param nknots: number of coarse time samples to use (knots) + :param Tspan: option to some other Tspan + :param start_time: option to set some other start epoch of basis + :param order: order of the interpolation (1 = linear) + + :return B: coarse time-domain design matrix + :return t_coarse: timestamps of coarse time grid + """ + if start_time is None: + start_time = np.min(toas) + else: + if start_time > np.min(toas): + raise ValueError("Coarse time basis start must be earlier than earliest TOA.") + + if Tspan is None: + Tspan = np.max(toas) - start_time + else: + if start_time + Tspan < np.max(toas): + raise ValueError("Coarse time basis end must be later than latest TOA.") + + t_fine = toas + t_coarse = np.linspace(start_time, start_time + Tspan, nknots) + Bmat = sint.interp1d(t_coarse, np.identity(nknots), kind=order)(t_fine).T + + return Bmat, t_coarse + + @function def createfourierdesignmatrix_dm( toas, freqs, nmodes=30, Tspan=None, pshift=False, fref=1400, logf=False, fmin=None, fmax=None, modes=None @@ -127,6 +163,34 @@ def createfourierdesignmatrix_dm( return F * Dm[:, None], Ffreqs +@function +def create_fft_time_basis_dm(toas, freqs, nknots=30, Tspan=None, start_time=None, fref=1400, order=1): + """ + Construct DM-variation linear interpolation design matrix. Current + normalization expresses DM signal as a deviation [seconds] + at fref [MHz] + + :param toas: vector of time series in seconds + :param freqs: radio frequencies of observations [MHz] + :param nknots: number of coarse time samples to use (knots) + :param Tspan: option to some other Tspan + :param start_time: option to set some other start epoch of basis + :param fref: reference frequency [MHz] + :param order: order of the interpolation (1 = linear) + + :return B: coarse time-domain design matrix + :return t_coarse: timestamps of coarse time grid + """ + + # get base course time-domain matrix and times + Bmat, t_coarse = create_fft_time_basis(toas, nknots=nknots, Tspan=Tspan, start_time=start_time, order=order) + + # compute the DM-variation vectors + Dm = (fref / freqs) ** 2 + + return Bmat * Dm[:, None], t_coarse + + @function def createfourierdesignmatrix_dm_tn( toas, freqs, nmodes=30, Tspan=None, pshift=False, fref=1400, logf=False, fmin=None, fmax=None, idx=2, modes=None @@ -292,6 +356,35 @@ def createfourierdesignmatrix_chromatic( return F * Dm[:, None], Ffreqs +@function +def create_fft_time_basis_chromatic(toas, freqs, nknots=30, Tspan=None, start_time=None, fref=1400, idx=4, order=1): + """ + Construct scattering linear interpolation design matrix. Current + normalization expresses DM signal as a deviation [seconds] + at fref [MHz] + + :param toas: vector of time series in seconds + :param freqs: radio frequencies of observations [MHz] + :param nknots: number of coarse time samples to use (knots) + :param Tspan: option to some other Tspan + :param start_time: option to set some other start epoch of basis + :param fref: reference frequency [MHz] + :param idx: Index of chromatic effects + :param order: order of the interpolation (1 = linear) + + :return B: coarse time-domain design matrix + :return t_coarse: timestamps of coarse time grid + """ + + # get base course time-domain matrix and times + Bmat, t_coarse = create_fft_time_basis(toas, nknots=nknots, Tspan=Tspan, start_time=start_time, order=order) + + # compute the DM-variation vectors + Dm = (fref / freqs) ** idx + + return Bmat * Dm[:, None], t_coarse + + @function def createfourierdesignmatrix_general( toas, diff --git a/enterprise/signals/gp_priors.py b/enterprise/signals/gp_priors.py index e515552e..4d1a42ec 100644 --- a/enterprise/signals/gp_priors.py +++ b/enterprise/signals/gp_priors.py @@ -67,6 +67,15 @@ def t_process_adapt(f, log10_A=-15, gamma=4.33, alphas_adapt=None, nfreq=None): return powerlaw(f, log10_A=log10_A, gamma=gamma) * alpha_model +@function +def powerlaw_flat_tail(f, log10_A=-16, gamma=5, log10_kappa=-7, components=2): + """Powerlaw with a flat tail (similar to broken powerlaw)""" + df = np.diff(np.concatenate((np.array([0]), f[::components]))) + pl = (10**log10_A) ** 2 / 12.0 / np.pi**2 * const.fyr ** (gamma - 3) * f ** (-gamma) * np.repeat(df, components) + flat = 10 ** (2 * log10_kappa) + return np.maximum(pl, flat) + + def InvGammaPrior(value, alpha=1, gamma=1): """Prior function for InvGamma parameters.""" return scipy.stats.invgamma.pdf(value, alpha, scale=gamma) diff --git a/enterprise/signals/gp_signals.py b/enterprise/signals/gp_signals.py index a504a410..c048de49 100644 --- a/enterprise/signals/gp_signals.py +++ b/enterprise/signals/gp_signals.py @@ -82,6 +82,14 @@ def _do_selection(self, psr, priorfn, basisfn, coefficients, selection): self._coefficients[key] = cpar self._params[cpar.name] = cpar + @property + def prior_params(self): + """Get any varying prior parameters.""" + ret = [] + for prior in self._prior.values(): + ret.extend([pp.name for pp in prior.params]) + return ret + @property def basis_params(self): """Get any varying basis parameters.""" @@ -114,6 +122,12 @@ def _construct_basis(self, params={}): self._slices.update({key: slice(nctot, nn + nctot)}) nctot += nn + @signal_base.cache_call("prior_params") + def _construct_prior(self, params): + for key, slc in self._slices.items(): + phislc = self._prior[key](self._labels[key], params=params) + self._phi = self._phi.set(phislc, slc) + # this class does different things (and gets different method # definitions) if the user wants it to model GP coefficients # (e.g., for a hierarchical likelihood) or if they do not @@ -173,10 +187,8 @@ def get_basis(self, params={}): def get_phi(self, params): self._construct_basis(params) + self._construct_prior(params) - for key, slc in self._slices.items(): - phislc = self._prior[key](self._labels[key], params=params) - self._phi = self._phi.set(phislc, slc) return self._phi def get_phiinv(self, params): @@ -216,6 +228,87 @@ class FourierBasisGP(BaseClass): return FourierBasisGP +def FFTBasisGP( + spectrum, + basis=None, + coefficients=False, + combine=True, + components=20, + nknots=None, + selection=Selection(selections.no_selection), + oversample=3, + fmax_factor=1, + cutoff=None, + cutbins=1, + Tspan=None, + start_time=None, + interpolation_order=1, + name="red_noise", +): + """Function to return a BasisGP class with a coarse time basis.""" + + if nknots is None: + nknots = 2 * components + 1 + + elif nknots is not None and nknots % 2 == 0: + raise ValueError("Knots needs to be an odd number") + + if cutoff is not None: + # :param cutoff: frequency 1 / (cutoff * T) at which to do + # low-frequency cut-off of the PSD + cutbins = int(np.ceil(oversample / cutoff)) + + fmax_factor = int(fmax_factor) if fmax_factor >= 1 else 1 + + if basis is None: + basis = utils.create_fft_time_basis( + nknots=nknots, Tspan=Tspan, start_time=start_time, order=interpolation_order + ) + + BaseClass = BasisGP(spectrum, basis, coefficients=coefficients, combine=combine, selection=selection, name=name) + + class FFTBasisGP(BaseClass): + signal_type = "basis" + signal_name = "red noise" + signal_id = name + + @signal_base.cache_call("prior_params") + def _construct_prior(self, params): + for key, slc in self._slices.items(): + t_knots = self._labels[key] + + freqs = utils.knots_to_freqs(t_knots, oversample=oversample, fmax_factor=fmax_factor) + + # Hack, because Enterprise adds in f=0 and then calculates df, + # meaning we cannot simply start freqs from 0. Thus, we use + # a modified frequency spacing, such that: + # [0, f1, 0, f1, f2, f3] => [df, -df, df, df, df] + if cutbins == 0: + freqs_prior = np.concatenate([[freqs[1]], freqs]) + psd_prior = self._prior[key](freqs_prior, params=params, components=1) + psd = np.concatenate([[-psd_prior[1]], psd_prior[2:]]) + + else: + psd_prior = self._prior[key](freqs[1:], params=params, components=1) + psd = np.concatenate([np.zeros(cutbins), psd_prior[cutbins - 1 :]]) + + phislc = utils.psd2cov(t_knots, psd, fmax_factor=fmax_factor) + self._phi = self._phi.set(phislc, slc) + + if coefficients: + raise NotImplementedError("Coefficients not supported for FFTBasisGP") + + else: + + def get_phi(self, params): + self._construct_basis(params) + self._construct_prior(params) + + return self._phi + + return FFTBasisGP + + def get_timing_model_basis(use_svd=False, normed=True, idx_exclude=None): if use_svd: if normed is not True: @@ -334,6 +427,11 @@ def __init__(self, psr): self._coefficients[""] = cpar self._params[cpar.name] = cpar + @property + def prior_params(self): + """Get any varying prior parameters.""" + return [pp.name for pp in self._prior.params] + @property def basis_params(self): """Get any varying basis parameters.""" @@ -346,6 +444,10 @@ def basis_params(self): def _construct_basis(self, params={}): self._basis, self._labels = self._bases(params=params) + @signal_base.cache_call("prior_params") + def _construct_prior(self, params): + return BasisCommonGP._prior(self._labels, params=params) + if coefficients: def _get_coefficient_logprior(self, c, **params): @@ -395,8 +497,7 @@ def get_basis(self, params={}): def get_phi(self, params): self._construct_basis(params) - - prior = BasisCommonGP._prior(self._labels, params=params) + prior = self._construct_prior(params) orf = BasisCommonGP._orf(self._psrpos, self._psrpos, params=params) return prior * orf @@ -426,7 +527,6 @@ def FourierBasisCommonGP( pshift=False, pseed=None, ): - if coefficients and Tspan is None: raise ValueError( "With coefficients=True, FourierBasisCommonGP " + "requires that you specify Tspan explicitly." @@ -454,7 +554,7 @@ def __init__(self, psr): # since this function has side-effects, it can only be cached # with limit=1, so it will run again if called with params different # than the last time - @signal_base.cache_call("basis_params", 1) + @signal_base.cache_call("basis_params", limit=1) def _construct_basis(self, params={}): span = Tspan if Tspan is not None else max(FourierBasisCommonGP._Tmax) - min(FourierBasisCommonGP._Tmin) self._basis, self._labels = self._bases(params=params, Tspan=span) @@ -462,6 +562,121 @@ def _construct_basis(self, params={}): return FourierBasisCommonGP +def FFTBasisCommonGP( + spectrum, + orf, + coefficients=False, + combine=True, + components=20, + nknots=None, + Tspan=None, + start_time=None, + cutoff=None, + cutbins=1, + oversample=3, + fmax_factor=1, + interpolation_order=1, + name="common_fft", +): + if coefficients and (Tspan is None or start_time is None): + raise ValueError( + "With coefficients=True, FFTBasisCommonGP " + "requires that you specify Tspan/start_time explicitly." + ) + + if nknots is None: + nknots = 2 * components + 1 + + elif nknots is not None and nknots % 2 == 0: + raise ValueError("Knots needs to be an odd number") + + if cutoff is not None: + # :param cutoff: frequency 1 / (cutoff * T) at which to do + # low-frequency cut-off of the PSD + cutbins = int(np.ceil(oversample / cutoff)) + + fmax_factor = int(fmax_factor) if fmax_factor >= 1 else 1 + + basis = utils.create_fft_time_basis(nknots=nknots, Tspan=Tspan, start_time=start_time, order=interpolation_order) + BaseClass = BasisCommonGP(spectrum, basis, orf, coefficients=coefficients, combine=combine, name=name) + + class FFTBasisCommonGP(BaseClass): + signal_type = "common basis" + signal_name = "common red noise" + signal_id = name + + _Tmin, _Tmax = [], [] + + def __init__(self, psr): + super(FFTBasisCommonGP, self).__init__(psr) + + if start_time is None: + FFTBasisCommonGP._Tmin.append(psr.toas.min()) + + if Tspan is None: + FFTBasisCommonGP._Tmax.append(psr.toas.max()) + + # since this function has side-effects, it can only be cached + # with limit=1, so it will run again if called with params different + # than the last time + @signal_base.cache_call("basis_params", limit=1) + def _construct_basis(self, params={}): + start = start_time if start_time is not None else min(FFTBasisCommonGP._Tmin) + span = Tspan if Tspan is not None else max(FFTBasisCommonGP._Tmax) - start + self._basis, self._labels = self._bases(params=params, Tspan=span, start_time=start) + + self._t_knots = self._labels + freqs = utils.knots_to_freqs(self._t_knots, oversample=oversample, fmax_factor=fmax_factor) + self._freqs = freqs + + @signal_base.cache_call("prior_params") + def _construct_prior(self, params): + """ + Compute and cache the time-domain covariance ('phi') for *this* signal's basis. + """ + + # Hack, because Enterprise adds in f=0 and then calculates df, + # meaning we cannot simply start freqs from 0. Thus, we use + # a modified frequency spacing, such that: + # [0, f1, 0, f1, f2, f3] => [df, -df, df, df, df] + if cutbins == 0: + freqs_prior = np.concatenate([[self._freqs[1]], self._freqs]) + psd_prior = FFTBasisCommonGP._prior(freqs_prior, params=params, components=1) + psd = np.concatenate([[-psd_prior[1]], psd_prior[2:]]) + + else: + psd_prior = FFTBasisCommonGP._prior(self._freqs[1:], params=params, components=1) + psd = np.concatenate([np.zeros(cutbins), psd_prior[cutbins - 1 :]]) + + return utils.psd2cov(self._t_knots, psd) + + if coefficients: + raise NotImplementedError("Coefficients not supported for FFTBasisCommonGP") + + else: + + def get_phi(self, params): + """Over-load constructing Phi to deal with the FFT""" + self._construct_basis(params) + phi = self._construct_prior(params) + + orf = FFTBasisCommonGP._orf(self._psrpos, self._psrpos, params=params) + + return orf * phi + + @classmethod + def get_phicross(cls, signal1, signal2, params): + """Use Phi from signal1, ORF from signal1 vs signal2""" + + phi1 = signal1._construct_prior(params) + # phi2 = signal2._construct_prior(params) + + orf = FFTBasisCommonGP._orf(signal1._psrpos, signal2._psrpos, params=params) + + return phi1 * orf + + return FFTBasisCommonGP + + # for simplicity, we currently do not handle Tspan automatically def FourierBasisCommonGP_ephem(spectrum, components, Tspan, name="ephem_gp"): basis = utils.createfourierdesignmatrix_ephem(nmodes=components, Tspan=Tspan) diff --git a/enterprise/signals/signal_base.py b/enterprise/signals/signal_base.py index f6da48ac..2ec768fc 100644 --- a/enterprise/signals/signal_base.py +++ b/enterprise/signals/signal_base.py @@ -799,6 +799,7 @@ def _set_cache_parameters(self): self.white_params = [] self.basis_params = [] self.delay_params = [] + self.prior_params = [] for signal in self._signals: if signal.signal_type == "white noise": self.white_params.extend(signal.ndiag_params) @@ -807,6 +808,7 @@ def _set_cache_parameters(self): # for common GPs, which do not have coefficients yet self.delay_params.extend(getattr(signal, "delay_params", [])) self.basis_params.extend(signal.basis_params) + self.prior_params.extend(getattr(signal, "prior_params", [])) elif signal.signal_type in ["deterministic"]: self.delay_params.extend(signal.delay_params) else: diff --git a/enterprise/signals/utils.py b/enterprise/signals/utils.py index 333368d4..06912f1e 100644 --- a/enterprise/signals/utils.py +++ b/enterprise/signals/utils.py @@ -20,12 +20,15 @@ from enterprise import signals as sigs # noqa: F401 from enterprise.signals.gp_bases import ( # noqa: F401 createfourierdesignmatrix_red, + create_fft_time_basis, createfourierdesignmatrix_dm, + create_fft_time_basis_dm, createfourierdesignmatrix_dm_tn, createfourierdesignmatrix_env, createfourierdesignmatrix_ephem, createfourierdesignmatrix_eph, createfourierdesignmatrix_chromatic, + create_fft_time_basis_chromatic, createfourierdesignmatrix_general, ) from enterprise.signals.gp_priors import powerlaw, turnover # noqa: F401 @@ -839,6 +842,60 @@ def linear_interp_basis(toas, dt=30 * 86400): return M[:, idx], x[idx] +def psd2cov(t_knots, psd, fmax_factor=1): + """ + Convert a power spectral density function, defined by (freqs, psd), to a covariance matrix. + + :param t_knots: Timestamps of the coarse time grid. + :param psd: PSD values evaluated at frequencies from knots_to_freqs + (assumes *delta_f in psd, so units of [s^2]). + :param fmax_factor: Integer factor to scale up fmax. + + :return covmat: Covariance matrix at the coarse time grid. + """ + + def toeplitz(c): + c = np.asarray(c) + n = len(c) + i = np.arange(n).reshape(-1, 1) + j = np.arange(n).reshape(1, -1) + return c[np.abs(i - j)] + + # Create the full symmetric PSD (excluding duplicate Nyquist term) + fullpsd = np.concatenate([psd, psd[-2:0:-1]]) + + # Compute the inverse FFT + Cfreq = np.fft.ifft(fullpsd, norm="backward") + Ctau = Cfreq.real * len(fullpsd) / 2 + + # With fmax_factor > 1, the IFFT time grid is finer by that factor. + # Slice out every fmax_factor-th sample to match the coarse grid. + return toeplitz(Ctau[::fmax_factor][: len(t_knots)]) + + +def knots_to_freqs(t_knots, oversample=3, fmax_factor=1): + """ + Convert knots of coarse time grid to frequencies + + :param t_knots: Timestamps of the coarse time grid + :param oversample: amount by which to over-sample the frequency grid + :param fmax_factor: Integer factor to scale the maximum frequency. + + :return freqs: Frequencies, regularly sampled with + delta-f = 1/(oversample*T), fmax=1/(2*delta_t_knots) + """ + nmodes = len(t_knots) + Tspan = np.max(t_knots) - np.min(t_knots) + + if nmodes % 2 == 0: + raise ValueError("len(t_knots) must be odd.") + + n_freqs = int((nmodes - 1) / 2 * oversample * fmax_factor + 1) + fmax = (nmodes - 1) / (2 * Tspan) + + return np.linspace(0, fmax * fmax_factor, n_freqs) + + # overlap reduction functions diff --git a/tests/test_gp_signals.py b/tests/test_gp_signals.py index 24197488..43956ae3 100644 --- a/tests/test_gp_signals.py +++ b/tests/test_gp_signals.py @@ -15,7 +15,7 @@ import scipy.linalg as sl from enterprise.pulsar import Pulsar -from enterprise.signals import gp_signals, parameter, selections, signal_base, utils +from enterprise.signals import gp_signals, parameter, selections, signal_base, utils, white_signals from enterprise.signals.selections import Selection from tests.enterprise_test_data import datadir from tests.enterprise_test_data import LIBSTEMPO_INSTALLED, PINT_INSTALLED @@ -36,6 +36,23 @@ def se_kernel(etoas, log10_sigma=-7, log10_lam=np.log10(30 * 86400)): return 10 ** (2 * log10_sigma) * np.exp(-(tm**2) / 2 / 10 ** (2 * log10_lam)) + d +@signal_base.function +def psd_matern32(f, length_scale=365 * 86400.0, log10_sigma_sqr=-14, components=2): + df = np.diff(np.concatenate((np.array([0]), f[::components]))) + return ( + (10**log10_sigma_sqr) + * 24 + * np.sqrt(3) + * length_scale + / (3 + (2 * np.pi * f * length_scale) ** 2) ** 2 + * np.repeat(df, components) + ) + + +def matern32_kernel(tau, length_scale=365 * 86400.0, log10_sigma_sqr=-14): + return (10**log10_sigma_sqr) * (1 + np.sqrt(3) * tau / length_scale) * np.exp(-np.sqrt(3) * tau / length_scale) + + class TestGPSignals(unittest.TestCase): @classmethod def setUpClass(cls): @@ -355,6 +372,81 @@ def test_fourier_red_noise_backend(self): msg = "F matrix shape incorrect" assert rnm.get_basis(params).shape == F.shape, msg + def test_fft_red_noise(self): + """Test the FFT implementation of red noise signals""" + # set up signal parameter + mpsd = psd_matern32( + length_scale=parameter.Uniform(365 * 86400.0, 3650 * 86400.0), log10_sigma_sqr=parameter.Uniform(-17, -9) + ) + rn_cb0 = gp_signals.FFTBasisGP(spectrum=mpsd, components=15, oversample=3, cutbins=0) + rn_cb1 = gp_signals.FFTBasisGP(spectrum=mpsd, nknots=31, oversample=3, cutoff=3) + rnm0 = rn_cb0(self.psr) + rnm1 = rn_cb1(self.psr) + + # parameter values + length_scale, log10_sigma_sqr = 1.5 * 365 * 86400.0, -14.0 + params = { + "B1855+09_red_noise_length_scale": length_scale, + "B1855+09_red_noise_log10_sigma_sqr": log10_sigma_sqr, + } + + # basis matrix test + start_time = np.min(self.psr.toas) + Tspan = np.max(self.psr.toas) - start_time + B, tc = utils.create_fft_time_basis(self.psr.toas, nknots=31) + B1, _ = utils.create_fft_time_basis(self.psr.toas, nknots=31, Tspan=Tspan, start_time=start_time) + + msg = "B matrix incorrect for GP FFT signal." + assert np.allclose(B, rnm0.get_basis(params)), msg + assert np.allclose(B1, rnm1.get_basis(params)), msg + assert np.allclose(np.sum(B, axis=1), np.ones(B.shape[0])), msg + + # spectrum test + tau = np.abs(tc[:, None] - tc[None, :]) + phi_K = matern32_kernel(tau, length_scale, log10_sigma_sqr) + phi_E = rnm0.get_phi(params) + + msg = "Prior incorrect for GP FFT signal." + assert np.allclose(phi_K, phi_E), msg + + # spectrum test with low-frequency cut-off + freqs = utils.knots_to_freqs(tc, oversample=3) + psd = psd_matern32(freqs[1:], length_scale=length_scale, log10_sigma_sqr=log10_sigma_sqr, components=1) + psd = np.concatenate([[0.0], psd]) + phi_K = utils.psd2cov(tc, psd) + phi_E = rnm1.get_phi(params) + + msg = f"Prior incorrect for GP FFT signal." + assert np.allclose(phi_K, phi_E), msg + + def test_fft_common(self): + """Test the FFT implementation of common red noise signals""" + # set up signal parameters + log10_A, gamma = -14.5, 4.33 + params = {"B1855+09_red_noise_log10_A": log10_A, "B1855+09_red_noise_gamma": gamma} + pl = utils.powerlaw(log10_A=parameter.Uniform(-18, -12), gamma=parameter.Uniform(1, 7)) + orf = utils.hd_orf() + + # set up the basis and the model + start_time = np.min(self.psr.toas) + Tspan = np.max(self.psr.toas) - start_time + mn = white_signals.MeasurementNoise(efac=parameter.Constant(1.0), selection=Selection(selections.no_selection)) + crn = gp_signals.FFTBasisCommonGP( + pl, orf, nknots=31, name="gw", oversample=3, cutoff=3, Tspan=Tspan, start_time=start_time + ) + model = mn + crn + pta = signal_base.PTA([model(psr) for psr in [self.psr, self.psr]]) + + # test the prior matrices, including ORF + phi_full = pta.get_phi(params) + phi_1 = phi_full[:31, :31] + phi_12 = phi_full[31:, :31] + phi_2 = phi_full[31:, 31:] + + msg = f"Common mode FFT Prior not equal between pulsars." + assert np.allclose(phi_1, phi_2), msg + assert np.allclose(0.5 * phi_1, phi_12), msg + def test_red_noise_add(self): """Test that red noise addition only returns independent columns.""" # set up signals diff --git a/tests/test_utils.py b/tests/test_utils.py index f85f1a17..9b3ec31f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -35,6 +35,12 @@ def setUpClass(cls): cls.Fdm, _ = utils.createfourierdesignmatrix_dm(cls.psr.toas, freqs=cls.psr.freqs, nmodes=30) + cls.B, _ = utils.create_fft_time_basis(cls.psr.toas, nknots=30) + + cls.Bdm, _ = utils.create_fft_time_basis_dm(cls.psr.toas, freqs=cls.psr.freqs, nknots=30) + + cls.Bchr, _ = utils.create_fft_time_basis_chromatic(cls.psr.toas, freqs=cls.psr.freqs, nknots=30) + cls.Feph, cls.feph = utils.createfourierdesignmatrix_ephem(cls.psr.toas, cls.psr.pos, nmodes=30) cls.Mm = utils.create_stabletimingdesignmatrix(cls.psr.Mmat) @@ -51,12 +57,30 @@ def test_createfourierdesignmatrix_red(self, nf=30): msg = "Fourier design matrix shape incorrect" assert self.F.shape == (4005, 2 * nf), msg + def test_create_fft_time_basis(self, nk=30): + """Check FFT interpolation design matrix shape.""" + + msg = "FFT interpolation design matrix shape incorrect" + assert self.B.shape == (4005, nk), msg + def test_createfourierdesignmatrix_dm(self, nf=30): """Check DM-variation Fourier design matrix shape.""" msg = "DM-variation Fourier design matrix shape incorrect" assert self.Fdm.shape == (4005, 2 * nf), msg + def test_create_fft_time_basis_dm(self, nk=30): + """Check FFT interpolation design matrix shape.""" + + msg = "DM-variation FFT interpolation design matrix shape incorrect" + assert self.Bdm.shape == (4005, nk), msg + + def test_create_fft_time_basis_chromatic(self, nk=30): + """Check FFT interpolation design matrix shape.""" + + msg = "DM-variation FFT interpolation design matrix shape incorrect" + assert self.Bchr.shape == (4005, nk), msg + def test_createfourierdesignmatrix_ephem(self, nf=30): """Check x-axis ephemeris Fourier design matrix shape."""