Inference

Basic

This module implements the Basic method for computing the approximate loglikelihood \(\log p(Y_{0:M} \mid Z_{1:N})\).

Using \(\mu_{0:N|N} = E(X_{0:N} \mid Z_{1:N})\) from the rodeo solver, the approximate loglikelihood is computed as

\[\log p(Y_{0:M} \mid Z_{1:N}) = \sum_{i=0}^M \log p(Y_i \mid X_{n(i)} = \mu_{n(i)|N}).\]

If the observation time grid differs from the solver time grid, each observation uses the closest discretization time point.
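As a minimal sketch of this sum, assume a scalar state and a Gaussian observation density. Both are illustrative choices, and `gauss_logpdf`, `nearest_index`, and `basic_loglik` are hypothetical helpers, not part of rodeo. Each observation is matched to the closest solver time point and the log-densities are added:

```python
import math

def gauss_logpdf(y, mean, sd):
    """Log-density of Normal(mean, sd^2) at y (assumed observation model)."""
    z = (y - mean) / sd
    return -0.5 * z * z - math.log(sd) - 0.5 * math.log(2.0 * math.pi)

def nearest_index(t, grid):
    """Index n(i) of the solver time point closest to observation time t."""
    return min(range(len(grid)), key=lambda n: abs(grid[n] - t))

def basic_loglik(mu, grid, obs_data, obs_times, obs_sd):
    """Sum of log p(Y_i | X_{n(i)} = mu_{n(i)|N}) over all observations."""
    total = 0.0
    for y, t in zip(obs_data, obs_times):
        n = nearest_index(t, grid)
        total += gauss_logpdf(y, mu[n], obs_sd)
    return total

# Solver grid on [a, b] = [0, 1] with N = 4 steps, so dt = (b - a) / N.
a, b, n_steps = 0.0, 1.0, 4
dt = (b - a) / n_steps
grid = [a + n * dt for n in range(n_steps + 1)]
mu = [math.exp(-t) for t in grid]   # stand-in for the solver posterior mean
obs_times = [0.0, 0.3, 0.9]         # 0.3 maps to grid point 0.25, 0.9 to 1.0
obs_data = [1.02, 0.80, 0.35]       # made-up observations
loglik = basic_loglik(mu, grid, obs_data, obs_times, obs_sd=0.1)
```

In the library itself the posterior mean comes from the rodeo solver and the observation density is supplied through `obs_loglik`; the sketch only illustrates the time matching and the summation.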

rodeo.inference.basic.basic(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_loglik, kalman_type='standard', **params)[source]

Basic algorithm to compute the approximate loglikelihood of \(p(Y_{0:M} \mid Z_{1:N})\).

Parameters:
  • key (PRNGKey) – PRNG key.

  • ode_fun (Callable) – Higher order ODE function \(W X_t = F(X_t, t)\) taking arguments \(X\) and \(t\).

  • ode_weight (ndarray(n_block, n_bmeas, n_bstate)) – Weight matrix defining the measure prior; \(W\).

  • ode_init (ndarray(n_block, n_bstate)) – Initial value of the state variable \(X_t\) at time \(t = a\).

  • t_min (float) – First time point of the time interval to be evaluated; \(a\).

  • t_max (float) – Last time point of the time interval to be evaluated; \(b\).

  • n_steps (int) – Number of discretization points (\(N\)) of the time interval that is evaluated, such that discretization timestep is \(dt = (b-a)/N\).

  • interrogate (Callable) – Function defining the interrogation method.

  • prior_pars (tuple) – A tuple containing the weight matrix and the variance matrix defining the solution prior; \(Q, R\).

  • obs_data (ndarray(n_obs, n_bobs)) – Observed data; \(Y_{0:M}\).

  • obs_times (ndarray(n_obs)) – Observation time; \(0, \ldots, M\).

  • obs_loglik (Callable) – Observation loglikelihood function.

  • kalman_type (str) – Determine which type of Kalman (standard, square-root) to use.

  • params (kwargs) – Optional model parameters.

Returns:

The loglikelihood of \(p(Y_{0:M} \mid Z_{1:N})\).

Return type:

(float)

DALTON

This module implements the DALTON solver which gives an approximate likelihood of \(p(Y_{0:M} \mid Z_{1:N})\).

The model is

\[ \begin{align}\begin{aligned}X_0 = v\\X_n = c_n + Q_n X_{n-1} + R_n^{1/2} \epsilon_n\\Z_n = W_n X_n - f(X_n, t_n) + V_n^{1/2} \eta_n\\Y_m = g(X_m, \phi_m),\end{aligned}\end{align} \]

where \(g\) is a general distribution function. In the case that \(g\) is Gaussian, use dalton() for a better approximation. In other cases, use daltonng(). We assume that \(c_n = 0, Q_n = Q, R_n = R\), and \(W_n = W\) for all \(n\).

In the Gaussian case, we assume the observation model is

\[Y_m = D_m X_m + \Omega^{1/2}_m \epsilon_m.\]

We assume that \(M \leq N\), so that the observation step size is at least as large as the evaluation step size.
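To illustrate the Gaussian observation model, the sketch below conditions a scalar Gaussian state on a single observation and returns that observation's marginal log-density. It is the textbook Kalman measurement update, not the DALTON implementation; the function name and the numeric values are hypothetical.

```python
import math

def gaussian_obs_update(mean, var, y, d, omega):
    """Condition X ~ Normal(mean, var) on y = d * X + sqrt(omega) * eps.

    Returns the updated mean and variance of X together with log p(y),
    where marginally y ~ Normal(d * mean, d * var * d + omega).
    """
    pred_mean = d * mean                  # E[Y]
    pred_var = d * var * d + omega        # Var[Y]
    gain = var * d / pred_var             # Kalman gain
    resid = y - pred_mean
    new_mean = mean + gain * resid
    new_var = var - gain * d * var
    logp = -0.5 * (math.log(2.0 * math.pi * pred_var) + resid * resid / pred_var)
    return new_mean, new_var, logp

new_mean, new_var, logp = gaussian_obs_update(mean=0.0, var=1.0, y=1.0, d=1.0, omega=1.0)
```

With a unit-variance prior and unit observation noise, the update moves the mean halfway toward the observation and halves the variance.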

rodeo.inference.dalton.dalton(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var, kalman_type='standard', **params)[source]

Compute marginal loglikelihood of DALTON algorithm for Gaussian observations; \(p(Y_{0:M} \mid Z_{1:N})\).

Parameters:
  • key (PRNGKey) – PRNG key.

  • ode_fun (Callable) – Higher order ODE function \(W X_t = F(X_t, t)\) taking arguments \(X\) and \(t\).

  • ode_weight (ndarray(n_block, n_bmeas, n_bstate)) – Weight matrix defining the measure prior; \(W\).

  • ode_init (ndarray(n_block, n_bstate)) – Initial value of the state variable \(X_t\) at time \(t = a\).

  • t_min (float) – First time point of the time interval to be evaluated; \(a\).

  • t_max (float) – Last time point of the time interval to be evaluated; \(b\).

  • n_steps (int) – Number of discretization points (\(N\)) of the time interval that is evaluated, such that discretization timestep is \(dt = (b-a)/N\).

  • interrogate (Callable) – Function defining the interrogation method.

  • prior_pars (tuple) – A tuple containing the weight matrix and the variance matrix defining the solution prior; \(Q, R\).

  • obs_data (ndarray(n_obs, n_blocks, n_bobs)) – Observed data; \(Y_{0:M}\).

  • obs_times (ndarray(n_obs)) – Observation time; \(0, \ldots, M\).

  • obs_weight (ndarray(n_obs, n_blocks, n_bobs, n_bstate)) – Weight matrix in the observation model; \(D_{0:M}\).

  • obs_var (ndarray(n_obs, n_blocks, n_bobs, n_bobs)) – Variance matrix in the observation model; \(\Omega_{0:M}\).

  • kalman_type (str) – Determine which type of Kalman (standard, square-root) to use.

  • params (kwargs) – Optional model parameters.

Returns:

Loglikelihood of \(p(Y_{0:M} \mid Z_{1:N})\).

Return type:

(float)

rodeo.inference.dalton.daltonng(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_loglik_i, kalman_type='standard', **params)[source]

Compute marginal loglikelihood of DALTON algorithm for non-Gaussian observations; \(p(\hat Y_{0:M} \mid Z_{1:N})\).

Parameters:
  • key (PRNGKey) – PRNG key.

  • ode_fun (Callable) – Higher order ODE function \(W X_t = F(X_t, t)\) taking arguments \(X\) and \(t\).

  • ode_weight (ndarray(n_block, n_bmeas, n_bstate)) – Weight matrix defining the measure prior; \(W\).

  • ode_init (ndarray(n_block, n_bstate)) – Initial value of the state variable \(X_t\) at time \(t = a\).

  • t_min (float) – First time point of the time interval to be evaluated; \(a\).

  • t_max (float) – Last time point of the time interval to be evaluated; \(b\).

  • n_steps (int) – Number of discretization points (\(N\)) of the time interval that is evaluated, such that discretization timestep is \(dt = (b-a)/N\).

  • interrogate (Callable) – Function defining the interrogation method.

  • prior_pars (tuple) – A tuple containing the weight matrix and the variance matrix defining the solution prior; \(Q, R\).

  • obs_data (ndarray(n_obs, n_blocks, n_bobs)) – Observed data; \(Y_{0:M}\).

  • obs_times (ndarray(n_obs)) – Observation time; \(0, \ldots, M\).

  • obs_loglik_i (Callable) – Loglikelihood function for each observation.

  • kalman_type (str) – Determine which type of Kalman (standard, square-root) to use.

  • params (kwargs) – Optional model parameters.

Returns:

Loglikelihood of \(p(\hat Y_{0:M} \mid Z_{1:N})\).

Return type:

(float)

rodeo.inference.dalton.solve_mv(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var, kalman_type='standard', **params)[source]

DALTON algorithm to compute the mean and variance of \(p(X_{0:N} \mid Y_{0:M}, Z_{1:N})\) assuming Gaussian observations. Same arguments as dalton().

Returns:

  • mean_state_smooth (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process \(X_t\) at times \(t \in [a, b]\).

  • var_state_smooth (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times \(t \in [a, b]\).

Return type:

(tuple)

rodeo.inference.dalton.solve_mv_nn(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_loglik_i, kalman_type='standard', **params)[source]

DALTON algorithm to compute the mean and variance of \(p(X_{0:N} \mid \hat Y_{0:M}, Z_{1:N})\) assuming non-Gaussian observations. Same arguments as daltonng().

Returns:

  • mean_state_smooth (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process \(X_t\) at times \(t \in [a, b]\).

  • var_state_smooth (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times \(t \in [a, b]\).

Return type:

(tuple)

rodeo.inference.dalton.solve_sim(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var, kalman_type='standard', **params)[source]

DALTON algorithm to sample from \(p(X_{0:N} \mid Y_{0:M}, Z_{1:N})\) assuming Gaussian observations. Same arguments as dalton().

Returns:

  • mean_state_smooth (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process \(X_t\) at times \(t \in [a, b]\).

  • var_state_smooth (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times \(t \in [a, b]\).

Return type:

(tuple)

Fenrir

This module implements the Fenrir algorithm as described in Tronarp et al. (2022) for computing the approximate likelihood \(p(Y_{0:M} \mid Z_{1:N})\).

The forward pass model is

\[ \begin{align}\begin{aligned}X_0 = v\\X_n = c_n + Q_n X_{n-1} + R_n^{1/2} \epsilon_n\\Z_n = W_n X_n - f(X_n, t_n) + V_n^{1/2} \eta_n.\end{aligned}\end{align} \]

We assume that \(c_n = 0, Q_n = Q, R_n = R\), and \(W_n = W\) for all \(n\). Using the Kalman filtering recursions, the above model can be simulated via the reverse pass model

\[ \begin{align}\begin{aligned}X_N \sim \operatorname{Normal}(b_N, C_N)\\X_n = A_n X_{n+1} + b_n + C_n^{1/2} \epsilon_n.\end{aligned}\end{align} \]

Fenrir combines the observations

\[Y_m = D_m X_m + \Omega^{1/2}_m \eta_m,\]

with the reverse pass model to condition on data. Here \(\epsilon_n, \eta_m\) are standard normals.
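The reverse pass parameters follow from the filtering moments by standard Gaussian conditioning. The scalar sketch below assumes \(c_n = 0\) and uses the hypothetical names `mu_filt` and `var_filt` for the filtered mean and variance of \(X_n\); it shows the backward Markov kernel only and is not taken from the rodeo source.

```python
def backward_kernel(mu_filt, var_filt, Q, R):
    """Parameters (A_n, b_n, C_n) of X_n given X_{n+1} for a scalar model.

    With X_n ~ Normal(mu_filt, var_filt) and
    X_{n+1} = Q * X_n + sqrt(R) * eps, Gaussian conditioning gives
    X_n = A_n * X_{n+1} + b_n + sqrt(C_n) * eps_n.
    """
    mu_pred = Q * mu_filt                 # E[X_{n+1}]
    var_pred = Q * var_filt * Q + R       # Var[X_{n+1}]
    A = var_filt * Q / var_pred
    b = mu_filt - A * mu_pred
    C = var_filt - A * var_pred * A
    return A, b, C

A, b, C = backward_kernel(mu_filt=1.0, var_filt=0.5, Q=0.9, R=0.1)
```

A quick consistency check is that pushing the predicted moments of \(X_{n+1}\) through the kernel recovers the filtered moments of \(X_n\): `A * mu_pred + b` equals `mu_filt` and `A * var_pred * A + C` equals `var_filt`.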

rodeo.inference.fenrir.fenrir(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var, kalman_type='standard', **params)[source]

Fenrir algorithm to compute the approximate loglikelihood of \(p(Y_{0:M} \mid Z_{1:N})\).

Parameters:
  • key (PRNGKey) – PRNG key.

  • ode_fun (Callable) – Higher order ODE function \(W X_t = F(X_t, t)\) taking arguments \(X\) and \(t\).

  • ode_weight (ndarray(n_block, n_bmeas, n_bstate)) – Weight matrix defining the measure prior; \(W\).

  • ode_init (ndarray(n_block, n_bstate)) – Initial value of the state variable \(X_t\) at time \(t = a\).

  • t_min (float) – First time point of the time interval to be evaluated; \(a\).

  • t_max (float) – Last time point of the time interval to be evaluated; \(b\).

  • n_steps (int) – Number of discretization points (\(N\)) of the time interval that is evaluated, such that discretization timestep is \(dt = (b-a)/N\).

  • interrogate (Callable) – Function defining the interrogation method.

  • prior_pars (tuple) – A tuple containing the weight matrix and the variance matrix defining the solution prior; \(Q, R\).

  • obs_data (ndarray(n_obs, n_blocks, n_bobs)) – Observed data; \(Y_{0:M}\).

  • obs_times (ndarray(n_obs)) – Observation time; \(0, \ldots, M\).

  • obs_weight (ndarray(n_obs, n_blocks, n_bobs, n_bstate)) – Weight matrix in the observation model; \(D_{0:M}\).

  • obs_var (ndarray(n_obs, n_blocks, n_bobs, n_bobs)) – Variance matrix in the observation model; \(\Omega_{0:M}\).

  • kalman_type (str) – Determine which type of Kalman (standard, square-root) to use.

  • params (kwargs) – Optional model parameters.

Returns:

The loglikelihood of \(p(Y_{0:M} \mid Z_{1:N})\).

Return type:

(float)

rodeo.inference.fenrir.solve_mv(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps, interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var, kalman_type='standard', **params)[source]

Fenrir algorithm to compute the mean and variance of \(p(X_{0:N} \mid Z_{1:N}, Y_{0:M})\). Same arguments as fenrir().

Returns:

  • mean_state_smooth (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process \(X_t\) at times \(t \in [a, b]\).

  • var_state_smooth (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times \(t \in [a, b]\).

Return type:

(tuple)

MAGI

rodeo.inference.magi.magi_logdens(ode_data_subset, ode_expand, n_active, prior_pars, kalman_type, **params)[source]

Log-density of MAGI approximation.

Parameters:
  • ode_data_subset (ndarray(n_steps+1, n_block, n_deriv-1)) – Array specifying \(U_{0:N}\), the subset of the solution process needed to reconstruct the entire solution with ode_expand().

  • ode_expand (Callable) – Function taking inputs ode_data_subset and **params and returning the full solution process \(X_{0:N}\).

  • n_active (int) – Number of active derivatives – i.e., not those zero-padded – for the solution process.

  • prior_pars (tuple) – A tuple containing the weight matrix and the variance matrix defining the solution prior; \(Q, R\).

  • kalman_type (str) – Determine which type of Kalman (standard, square-root) to use.

  • **params (kwargs) – Parameters to pass to ode_expand.

Returns:

Value of the logdensity p(ode_data_subset, Z = 0 | params, prior_pars).

Return type:

(float)

Marginal MCMC

BlackJAX implementation of pseudomarginal MCMC with Random Walk kernels.

The API is nearly identical to that provided by blackjax.mcmc.random_walk.py. The main differences are:

  1. logdensity_fn takes two arguments: the PyTree defining the marginal random variables (i.e., the position) and a PRNG key used for any random sampling inside logdensity_fn to obtain the stochastic estimate of the log-density.

  2. The return value of logdensity_fn is a tuple of which the first element is the stochastic log-density estimate, and the second are auxiliary variables. Most commonly, these would be the latent variables which are stochastically marginalized over.

    In this implementation, the second return value is mandatory. If no auxiliary variables are needed then set this tuple element to None.
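The two-argument contract can be sketched as follows. To keep the example runnable with the standard library alone, a `random.Random` instance stands in for the JAX PRNG key; that substitution, the toy latent-variable model, and the constants `y_obs` and `n_draws` are all assumptions made for illustration.

```python
import math
import random

def logdensity_fn(position, key):
    """Stochastic log-density estimate following the two-argument contract.

    `position` is a dict with entry "theta"; `key` is a random.Random
    instance standing in for a JAX PRNG key (assumption of this sketch).
    Returns a tuple (estimate, auxiliary variables).
    """
    theta = position["theta"]
    y_obs, n_draws = 0.7, 200             # hypothetical data and sample size
    # Latent variables z_k ~ Normal(0, 1), marginalized by Monte Carlo.
    z = [key.gauss(0.0, 1.0) for _ in range(n_draws)]
    # log Normal(y_obs; theta + z_k, 1) for each latent draw.
    logw = [-0.5 * (y_obs - theta - zk) ** 2 - 0.5 * math.log(2.0 * math.pi)
            for zk in z]
    # Log of the Monte Carlo average, computed stably via log-sum-exp.
    m = max(logw)
    log_est = m + math.log(sum(math.exp(lw - m) for lw in logw) / n_draws)
    return log_est, z                     # the second element is mandatory

estimate, aux = logdensity_fn({"theta": 0.2}, random.Random(0))
```

Here the auxiliary output is the list of latent draws. A function with nothing auxiliary to report would return `(log_est, None)` instead.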

The remainder of this docstring is copied from blackjax.mcmc.random_walk.py.

Some interfaces are exposed here for convenience and for entry-level users, who might be familiar with simpler versions of the algorithms, but in all cases they are particular instantiations of the Random Walk Rosenbluth-Metropolis-Hastings (RMH) algorithm.

Let \(x_{t-1}\) denote the previous position and \(x_t\) the newly sampled one.

The variants offered are:

  1. Proposal distribution as addition of random noise to the previous position. This means \(x_t = x_{t-1} + \text{step}\).

    Function: additive_step()

  2. Independent proposal distribution: \(P(x_t)\) doesn’t depend on \(x_{t-1}\).

    Function: irmh()

  3. Proposal distribution using a symmetric function. That means \(P(x_t|x_{t-1}) = P(x_{t-1}|x_t)\). See ‘Metropolis Algorithm’ in [1].

    Function: rmh() without proposal_logdensity_fn.

  4. Asymmetric proposal distribution. See ‘Metropolis-Hastings’ Algorithm in [1].

    Function: rmh() with proposal_logdensity_fn.

Reference: [1] Andrew Gelman, John B. Carlin, Hal S. Stern, and Donald B. Rubin. Bayesian Data Analysis. Chapman and Hall/CRC, 2014. Section 11.2.
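The difference between the symmetric and asymmetric variants comes down to one correction term in the log acceptance ratio. The following is a plain, non-pseudomarginal sketch for a scalar position using only the standard library; `rmh_step` and `log_target` are hypothetical names, and the argument order of `proposal_logdensity_fn` is an assumption of this example rather than the library's convention.

```python
import math
import random

def rmh_step(rng, position, logdensity_fn, step_sd, proposal_logdensity_fn=None):
    """One Metropolis-Hastings step with an additive Gaussian proposal.

    `proposal_logdensity_fn(x_from, x_to)` returns log P(x_to | x_from);
    when it is None the proposal is treated as symmetric.
    """
    proposal = position + rng.gauss(0.0, step_sd)     # x_t = x_{t-1} + step
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    if proposal_logdensity_fn is not None:
        # Hastings correction: log P(x_{t-1} | x_t) - log P(x_t | x_{t-1}).
        log_ratio += (proposal_logdensity_fn(proposal, position)
                      - proposal_logdensity_fn(position, proposal))
    accept = rng.random() < math.exp(min(0.0, log_ratio))
    return (proposal if accept else position), accept

def log_target(x):
    """Unnormalized log-density of a standard normal target."""
    return -0.5 * x * x

rng = random.Random(0)
x, draws = 0.0, []
for _ in range(5000):
    x, _accepted = rmh_step(rng, x, log_target, step_sd=1.0)
    draws.append(x)
sample_mean = sum(draws) / len(draws)
```

For a symmetric proposal the correction term is zero, so passing the matching `proposal_logdensity_fn` leaves the chain unchanged.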

Example

The simplest case is:

random_walk = blackjax.additive_step_random_walk(logdensity_fn, blackjax.mcmc.random_walk.normal(sigma))
state = random_walk.init(position)
new_state, info = random_walk.step(rng_key, state)

In all cases we can JIT-compile the step function for better performance:

step = jax.jit(random_walk.step)
new_state, info = step(rng_key, state)

rodeo.inference.pseudo_marginal.additive_step_random_walk(logdensity_fn: Callable, random_step: Callable) → SamplingAlgorithm[source]

Implements the user interface for the Additive Step RMH.

Example

A new kernel can be initialized and used with the following code:

rw = blackjax.additive_step_random_walk(logdensity_fn, random_step)
state = rw.init(position)
new_state, info = rw.step(rng_key, state)

The specific case of a Gaussian random_step is already implemented, either with independent components when covariance_matrix is a one dimensional array or with dependent components if a two dimensional array:

rw_gaussian = blackjax.additive_step_random_walk.normal_random_walk(logdensity_fn, covariance_matrix)
state = rw_gaussian.init(position)
new_state, info = rw_gaussian.step(rng_key, state)

Parameters:
  • logdensity_fn (Callable) – Function to compute the log-probability density of the distribution.

  • random_step (Callable) – A function that generates a step to be added to the current state. This function takes a PRNG key and the current position as input and returns a new proposal step.

Returns:

A sampling algorithm with init and step methods to perform RMH sampling.

Return type:

(SamplingAlgorithm)

rodeo.inference.pseudo_marginal.build_additive_step()[source]

Build a Random Walk Rosenbluth-Metropolis-Hastings (RMH) kernel using an additive step proposal.

Returns:

A function that takes a random key and chain state, performs an RMH step, and returns the new state and transition info.

Return type:

(Callable)

rodeo.inference.pseudo_marginal.build_irmh() → Callable[source]

Build an Independent Random Walk Rosenbluth-Metropolis-Hastings (RMH) kernel.

This kernel uses a proposal distribution that is independent of the current state, i.e., the new proposed state is sampled independently of the particle’s current position.

Returns:

A function (kernel) that takes a PRNG key and a PyTree containing the current state of the chain and that returns a new state of the chain along with information about the transition.

Return type:

(Callable)
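For an independent proposal, the acceptance ratio compares the importance weights (target density over proposal density) of the proposed and current points. A minimal standard-library sketch, with a standard normal target and a wider normal proposal chosen purely for illustration (`irmh_step`, `log_target`, and `log_proposal` are hypothetical names):

```python
import math
import random

def log_target(x):
    """Unnormalized log-density of the target, here Normal(0, 1)."""
    return -0.5 * x * x

def log_proposal(x, sd=2.0):
    """Unnormalized log-density of the independent proposal Normal(0, sd^2)."""
    return -0.5 * (x / sd) ** 2

def irmh_step(rng, position, sd=2.0):
    """One independent Metropolis-Hastings step.

    The proposal is drawn without looking at `position`; only the
    accept/reject decision depends on the current point.
    """
    proposal = rng.gauss(0.0, sd)
    log_ratio = ((log_target(proposal) - log_proposal(proposal, sd))
                 - (log_target(position) - log_proposal(position, sd)))
    accept = rng.random() < math.exp(min(0.0, log_ratio))
    return (proposal if accept else position), accept

rng = random.Random(1)
x, draws = 0.0, []
for _ in range(5000):
    x, _accepted = irmh_step(rng, x)
    draws.append(x)
sample_mean = sum(draws) / len(draws)
```

Normalizing constants of both densities cancel in the ratio, which is why unnormalized log-densities suffice here.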

rodeo.inference.pseudo_marginal.build_rmh()[source]

Build a Rosenbluth-Metropolis-Hastings kernel.

Returns:

A function (kernel) that takes a PRNG key and a PyTree containing the current state of the chain and that returns a new state of the chain along with information about the transition.

Return type:

(Callable)

rodeo.inference.pseudo_marginal.build_rmh_transition_energy(proposal_logdensity_fn: Callable | None) → Callable[source]

rodeo.inference.pseudo_marginal.init(position: ArrayLikeTree, logdensity_fn: Callable, rng_key: Array) → RWAState[source]

Create an initial chain state from a given position.

Parameters:
  • position (ArrayLikeTree) – The initial position of the chain.

  • logdensity_fn (Callable) – Function to compute the log-probability density of the distribution.

  • rng_key (Array) – PRNG key for the random sampling inside logdensity_fn.

Returns:

The initialized state of the chain.

Return type:

(RWAState)

rodeo.inference.pseudo_marginal.irmh_as_top_level_api(logdensity_fn: Callable, proposal_distribution: Callable, proposal_logdensity_fn: Callable | None = None) → SamplingAlgorithm[source]

Implements the (basic) user interface for the independent RMH.

Example

A new kernel can be initialized and used with the following code:

rmh = blackjax.irmh(logdensity_fn, proposal_distribution)
state = rmh.init(position)
new_state, info = rmh.step(rng_key, state)

We can JIT-compile the step function for better performance:

step = jax.jit(rmh.step)
new_state, info = step(rng_key, state)

Parameters:
  • logdensity_fn (Callable) – The log-probability density function of the distribution to sample from.

  • proposal_distribution (Callable) – A function that takes a PRNG key and produces a new proposal. The proposal is independent of the current state of the sampler.

  • proposal_logdensity_fn (Optional[Callable]) – A function that returns the log-density of obtaining a given proposal, given the current state. This is required for non-symmetric proposals. If not provided, the proposal is assumed to be symmetric.

Returns:

An object containing init and step methods for performing Independent Random Walk Metropolis-Hastings sampling.

Return type:

(SamplingAlgorithm)

rodeo.inference.pseudo_marginal.normal_random_walk(logdensity_fn: Callable, sigma)[source]

Create a Gaussian additive step random walk Metropolis-Hastings sampler.

This method initializes a random walk sampler with Gaussian-distributed steps.

Parameters:
  • logdensity_fn (Callable) – Function to compute the log-probability density of the distribution.

  • sigma (ArrayLikeTree) – Standard deviation of the Gaussian distribution used for the proposal steps.

Returns:

An object with init and step methods to run the Gaussian RMH sampler.

Return type:

(SamplingAlgorithm)

rodeo.inference.pseudo_marginal.rmh_as_top_level_api(logdensity_fn: Callable, proposal_generator: Callable[[Array, ArrayLikeTree], ArrayTree], proposal_logdensity_fn: Callable[[ArrayLikeTree], ArrayTree] | None = None) → SamplingAlgorithm[source]

Implements the user interface for the RMH.

Example

A new kernel can be initialized and used with the following code:

rmh = blackjax.rmh(logdensity_fn, proposal_generator)
state = rmh.init(position)
new_state, info = rmh.step(rng_key, state)

We can JIT-compile the step function for better performance:

step = jax.jit(rmh.step)
new_state, info = step(rng_key, state)

Create a user interface for the Rosenbluth-Metropolis-Hastings (RMH) sampler.

This function returns a SamplingAlgorithm object that provides init and step methods for performing RMH sampling. The user can specify a custom proposal generator and an optional log-density function for non-symmetric proposals.

Parameters:
  • logdensity_fn (Callable) – The log-probability density function of the distribution to sample from.

  • proposal_generator (Callable) – A function that takes a random number generator key and the current state, then generates a new proposal.

  • proposal_logdensity_fn (Optional[Callable]) – The log-density function associated with the proposal generator. If the proposal distribution is non-symmetric (i.e., P(x_t | x_{t-1}) ≠ P(x_{t-1} | x_t)), this must be provided to apply the Metropolis-Hastings correction for detailed balance.

Returns:

An object containing init and step methods for running the RMH sampler.

Return type:

(SamplingAlgorithm)

rodeo.inference.pseudo_marginal.rmh_proposal(logdensity_fn: Callable, transition_distribution: Callable, compute_acceptance_ratio: Callable, sample_proposal: Callable = <function static_binomial_sampling>) → Callable[source]

Parameters:
  • logdensity_fn (Callable) – The log-probability density function of the distribution to sample from.

  • transition_distribution (Callable) – A function that takes a random number generator key and the current state, then generates a new proposal.

  • compute_acceptance_ratio (Callable) – A function to compute the acceptance ratio.

  • sample_proposal (Callable) – A function to generate the next sample given proposal and previous state.

Returns:

Generator for sample proposals.

Return type:

(Callable)