
documentation: explain the backend design #217

@theorashid

Description


I found this type of page really helpful when extending GPJax in the past.

Explaining

  • Filter and Smoother interfaces, which are common throughout
  • build_filter is where the complexity comes in, as each method requires its own parts of a state space model
  • Why there are no explicit predict/update steps (textbook for Kalman) and filter_prepare is used instead (might be for Add "Conventions" page to docs #189)
    • And an explanation of how the parallelisation works

Below is a dump of some notes I made while going through the package, covering these three points. They might be useful for us to copy and paste from, so I'm putting them here. If not, you can stop reading.


Filter and Smoother

The two central inference objects are:

  1. Filter -- a NamedTuple with three callables + a flag:
    init_prepare(model_inputs, key) -- creates the initial state (time 0)
    filter_prepare(model_inputs, key) -- converts model inputs at time t into a prepared state
    filter_combine(state_1, state_2) -- combines a previous state with a prepared state to produce the filtered state
    associative: bool -- if True, filter_combine is associative, enabling parallel filtering via jax.lax.associative_scan

  2. Smoother -- same pattern but runs backwards in time:
    convert_filter_to_smoother_state -- converts the final filter state to seed the backward pass
    smoother_prepare(filter_state, model_inputs, key) -- prepares a state for the smoother
    smoother_combine(state_1, state_2) -- combines two smoother states
    associative: bool -- enables parallel smoothing
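As a deliberately toy illustration of the interface described above (the field names come from the description; everything else, including the toy instance, is hypothetical and not cuthbert's actual code):

```python
from typing import Any, Callable, NamedTuple

# Hypothetical sketch of the Filter interface -- field names match the
# description above, but the real cuthbert types will differ in detail.
class Filter(NamedTuple):
    init_prepare: Callable[..., Any]    # (model_inputs, key) -> state at time 0
    filter_prepare: Callable[..., Any]  # (model_inputs, key) -> prepared state at time t
    filter_combine: Callable[..., Any]  # (state_1, state_2) -> filtered state
    associative: bool                   # True => combine is associative, parallel scan OK

# Toy instance: a state is an (A, b) pair representing the scalar affine
# map x -> A * x + b, combined by function composition.
toy = Filter(
    init_prepare=lambda inputs, key: (0.0, inputs["mu0"]),
    filter_prepare=lambda inputs, key: (inputs["A"], inputs["b"]),
    filter_combine=lambda s1, s2: (s2[0] * s1[0], s2[0] * s1[1] + s2[1]),
    associative=True,
)
```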

Three families of inference methods plug into this interface:

  1. cuthbert.discrete - Hidden Markov Models (forward-backward / Baum-Welch)
  2. cuthbert.gaussian - Kalman filters/smoothers (standard, extended via Taylor linearisation, unscented/ensemble via moment transforms)
  3. cuthbert.smc - Sequential Monte Carlo / particle filters

build_filter

build_filter is different for each method (Kalman, EKF, particle filter, etc.) because each inference method requires different parts of the state space model:

  • Kalman methods work with means and covariances (first and second moments)
    • gaussian.kalman.build_filter: get_init_params, get_dynamics_params, get_observation_params (returning matrices F, H, etc.)
    • gaussian.taylor.build_filter: get_init_log_density, get_dynamics_log_density, get_observation_func (returning log density callables + linearisation points)
    • gaussian.moments.build_filter: similar but returning mean/chol_cov callables
  • Particle methods work with samples and weights
    • smc.particle_filter.build_filter: init_sample, propagate_sample, log_potential, plus tuning parameters (n_particles, resampling scheme, ESS threshold)
  • Discrete methods work with probability vectors and transition matrices

parallel (associative) Kalman filter

(Notation below follows Kevin Murphy's Advanced Topics in Probabilistic Machine Learning. I added these to a general Kalman filtering note.)

The standard Kalman filter is sequential – each step depends on the previous filtered mean and covariance, giving $O(T)$ serial complexity. The parallel Kalman filter reformulates the predict + update step as an affine map on the filtered mean, enabling a parallel scan in $O(\log T)$.

Reformulation as an affine map

Substituting the standard time update (predict) into the measurement update, the filtered mean at time $t$ is:

$$ \begin{align} \boldsymbol{\mu}_{t|t} &= \boldsymbol{\mu}_{t|t-1} + \mathbf{K}_t (\mathbf{y}_t - \mathbf{H}_t \boldsymbol{\mu}_{t|t-1} - \mathbf{d}_t) \\ &= (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t)(\mathbf{F}_t \boldsymbol{\mu}_{t-1|t-1} + \mathbf{c}_t) + \mathbf{K}_t (\mathbf{y}_t - \mathbf{d}_t) \\ &= (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t) \mathbf{F}_t \boldsymbol{\mu}_{t-1|t-1} + (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t) \mathbf{c}_t + \mathbf{K}_t (\mathbf{y}_t - \mathbf{d}_t) \end{align} $$

This has the form of an affine map on the previous filtered mean:

$$ \boldsymbol{\mu}_{t|t} = \mathbf{A}_t \boldsymbol{\mu}_{t-1|t-1} + \mathbf{b}_t $$

where:

$$ \begin{align} \mathbf{A}_t &= (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t) \mathbf{F}_t \\ \mathbf{b}_t &= (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t) \mathbf{c}_t + \mathbf{K}_t (\mathbf{y}_t - \mathbf{d}_t) \end{align} $$

The covariance update $\boldsymbol{\Sigma}_{t|t}$ does not depend on the mean – it only depends on the model parameters $(\mathbf{F}_t, \mathbf{Q}_t, \mathbf{H}_t, \mathbf{R}_t)$ and can be computed independently.
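A quick numerical sanity check of this identity (NumPy with arbitrary test matrices; this is illustration only, not cuthbert code):

```python
import numpy as np

# One explicit predict + update step should give the same filtered mean
# as the affine map A_t @ mu_prev + b_t derived above.
rng = np.random.default_rng(0)
n, m = 3, 2
F, c = rng.normal(size=(n, n)), rng.normal(size=n)  # dynamics
H, d = rng.normal(size=(m, n)), rng.normal(size=m)  # observation
Q, R = np.eye(n), np.eye(m)                         # noise covariances
mu_prev, Sigma_prev = rng.normal(size=n), np.eye(n)
y = rng.normal(size=m)

# explicit predict + update
mu_pred = F @ mu_prev + c
Sigma_pred = F @ Sigma_prev @ F.T + Q
K = Sigma_pred @ H.T @ np.linalg.inv(H @ Sigma_pred @ H.T + R)  # Kalman gain
mu_filt = mu_pred + K @ (y - H @ mu_pred - d)

# affine-map form
A = (np.eye(n) - K @ H) @ F
b = (np.eye(n) - K @ H) @ c + K @ (y - d)
assert np.allclose(mu_filt, A @ mu_prev + b)
```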

associative scan

The composition of two affine maps is itself an affine map:

$$ f_j \circ f_i(\boldsymbol{\mu}) = \mathbf{A}_j(\mathbf{A}_i \boldsymbol{\mu} + \mathbf{b}_i) + \mathbf{b}_j = (\mathbf{A}_j \mathbf{A}_i) \boldsymbol{\mu} + (\mathbf{A}_j \mathbf{b}_i + \mathbf{b}_j) $$

This composition is associative (it is function composition). To compute all $T$ filtered states, we need the prefix compositions:

$$ f_1, \quad f_2 \circ f_1, \quad f_3 \circ f_2 \circ f_1, \quad \ldots, \quad f_T \circ \cdots \circ f_1 $$
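Associativity is easy to check numerically (again illustration only, with affine maps stored as (A, b) pairs):

```python
import numpy as np

# compose(fi, fj) returns f_j o f_i, matching the composition formula above.
def compose(fi, fj):
    Ai, bi = fi
    Aj, bj = fj
    return (Aj @ Ai, Aj @ bi + bj)

rng = np.random.default_rng(1)
f1, f2, f3 = [(rng.normal(size=(2, 2)), rng.normal(size=2)) for _ in range(3)]

# Associativity: (f3 o f2) o f1 == f3 o (f2 o f1), so any grouping works.
left = compose(compose(f1, f2), f3)
right = compose(f1, compose(f2, f3))
assert np.allclose(left[0], right[0]) and np.allclose(left[1], right[1])
```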

The associative scan (parallel prefix scan) computes all of these in $O(\log T)$ parallel steps by composing pairs in a binary tree:

$$ \begin{align} \text{Level 0:} \quad & e_0, \quad e_1, \quad e_2, \quad e_3 \\ \text{Level 1:} \quad & e_0, \quad e_{0:1}, \quad e_2, \quad e_{2:3} \\ \text{Level 2:} \quad & e_0, \quad e_{0:1}, \quad e_{0:2}, \quad e_{0:3} \end{align} $$

where $e_{i:j}$ denotes the composed affine map from step $i$ to $j$. At each level, independent compositions run in parallel. Associativity guarantees that any grouping produces the correct result.
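A minimal pure-Python/NumPy sketch of the idea (a Hillis-Steele scan for clarity; jax.lax.associative_scan uses a work-efficient variant, and nothing below is cuthbert code):

```python
import numpy as np

def compose(fi, fj):
    # f_j o f_i for affine maps stored as (A, b) pairs
    return (fj[0] @ fi[0], fj[0] @ fi[1] + fj[1])

def parallel_prefix(elems, combine):
    # Hillis-Steele inclusive scan: at level l, element i absorbs element
    # i - 2^l. All combinations within a level are independent, so on
    # parallel hardware each level costs one step -- O(log T) levels total.
    elems = list(elems)
    d = 1
    while d < len(elems):
        elems = [combine(elems[i - d], e) if i >= d else e
                 for i, e in enumerate(elems)]
        d *= 2
    return elems

rng = np.random.default_rng(2)
T = 8
elems = [(rng.normal(size=(2, 2)), rng.normal(size=2)) for _ in range(T)]

# Compare against the sequential left fold f_t o ... o f_1.
seq, acc = [elems[0]], elems[0]
for e in elems[1:]:
    acc = compose(acc, e)
    seq.append(acc)

par = parallel_prefix(elems, compose)
assert all(np.allclose(s[0], p[0]) and np.allclose(s[1], p[1])
           for s, p in zip(seq, par))
```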

cuthbert implements this using [jax.lax.associative_scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html) when the parallel=True flag is set:

```python
from cuthbert import filter
from cuthbert.gaussian import kalman

filter_obj = kalman.build_filter(get_init_params, get_dynamics_params, get_observation_params)
states = filter(filter_obj, model_inputs, parallel=True)
```

cuthbert has three functions that map to the maths above:

  • init_prepare: creates the initial element with $\mathbf{A}_0 = \mathbf{0}$, $\mathbf{b}_0 = \boldsymbol{\mu}_0$
  • filter_prepare: encodes one time step's parameters $(\mathbf{F}_t, \mathbf{Q}_t, \mathbf{H}_t, \mathbf{R}_t, \mathbf{y}_t)$ into the affine scan element $(\mathbf{A}_t, \mathbf{b}_t, \ldots)$ via associative_params_single
  • filter_combine: composes two scan elements via the associative filtering_operator

In parallel mode, all elements are prepared independently via jax.vmap(filter_prepare), then composed via jax.lax.associative_scan(filter_combine, ...). In sequential mode, the same filter_prepare and filter_combine are called inside a jax.lax.scan loop.
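A toy analogue of the two execution modes, to show how one prepare/combine pair serves both (prepare and combine here are hypothetical stand-ins for filter_prepare and filter_combine; the real code uses jax.vmap, jax.lax.scan, and jax.lax.associative_scan rather than map and accumulate):

```python
from itertools import accumulate

# Scalar affine (A, b) elements, combined by composition.
def prepare(inputs):
    return (inputs["A"], inputs["b"])

def combine(s1, s2):
    return (s2[0] * s1[0], s2[0] * s1[1] + s2[1])

per_step_inputs = [{"A": 0.5 * t, "b": 1.0} for t in range(1, 6)]

# "parallel" mode: prepare every element independently (vmap analogue),
# then take prefix compositions (associative_scan analogue)
elems = list(map(prepare, per_step_inputs))
parallel_states = list(accumulate(elems, combine))

# sequential mode: interleave prepare and combine in one loop (scan analogue)
state = prepare(per_step_inputs[0])
sequential_states = [state]
for inputs in per_step_inputs[1:]:
    state = combine(state, prepare(inputs))
    sequential_states.append(state)

assert parallel_states == sequential_states
```

Both modes apply the same operations in the same order, which is why cuthbert can expose them behind a single parallel flag.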
