Source code for rodeo.inference.basic

r"""
This module implements the Basic method for computing the approximate loglikelihood :math:`\log p(Y_{0:M} \mid Z_{1:N})`.

Using :math:`\mu_{0:N|N} = E(X_{0:N} \mid Z_{1:N})` from the rodeo solver, the approximate loglikelihood is computed as

.. math:: \log p(Y_{0:M} \mid Z_{1:N}) \approx \sum_{i=0}^M \log p(Y_i \mid X_{n(i)} = \mu_{n(i)|N}).
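
As an illustration of the summand :math:`\log p(Y_i \mid X_{n(i)} = \mu_{n(i)|N})`, here is a minimal sketch of an ``obs_loglik`` callable with the call signature that ``basic`` uses, ``obs_loglik(obs_data, ode_data, **params)``. The name ``gaussian_obs_loglik``, the parameter ``obs_sd``, and the observation model (independent Gaussian noise on the first state component of each block, with ``obs_data`` of shape ``(n_obs, n_block)``) are hypothetical assumptions, not part of rodeo.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def gaussian_obs_loglik(obs_data, ode_data, obs_sd=0.1, **params):
    # Hypothetical observation model (not part of rodeo):
    # Y_i[k] ~ Normal(X_{n(i)}[k, 0], obs_sd^2) independently for each block k,
    # i.e. only the first state component of every block is observed.
    # obs_data: (n_obs, n_block); ode_data: (n_obs, n_block, n_bstate).
    mean = ode_data[:, :, 0]
    # Sum the per-entry log-densities over observations i = 0..M and blocks.
    return jnp.sum(norm.logpdf(obs_data, loc=mean, scale=obs_sd))


# Usage: three observations of a two-block state with three components per block.
ode_data = jnp.zeros((3, 2, 3))
obs_data = jnp.zeros((3, 2))
loglik = gaussian_obs_loglik(obs_data, ode_data, obs_sd=1.0, theta=5.0)
```

Because ``basic`` forwards ``**params`` to both the solver and ``obs_loglik``, the sketch accepts extra keyword arguments (such as ``theta`` above) and ignores them.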

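Here :math:`n(i)` is the index of the solver time point matched to observation :math:`Y_i`: the first discretization time point at or after the observation time, found with ``jnp.searchsorted``. When the observation and solver time grids coincide, this is the observation time itself; otherwise it is not necessarily the closest discretization time point.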
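
The sketch below contrasts the solver index computed with ``jnp.searchsorted`` (as in ``basic``) with a nearest-grid-point index. The interval, number of steps, and observation times are made-up values chosen only for illustration.

```python
import jax.numpy as jnp

t_min, t_max, n_steps = 0.0, 1.0, 10
# Solver grid with N + 1 points, built the same way as in `basic`.
sim_times = jnp.linspace(t_min, t_max, n_steps + 1)
# Illustrative observation times, mostly off the solver grid.
obs_times = jnp.array([0.0, 0.31, 0.38, 0.93])

# Index used by `basic`: first grid point >= each observation time.
idx_after = jnp.searchsorted(sim_times, obs_times)

# For comparison only: index of the nearest grid point.
idx_nearest = jnp.argmin(jnp.abs(sim_times[None, :] - obs_times[:, None]), axis=1)
```

With :math:`dt = 0.1`, the observation at 0.31 is matched to the solver point at 0.4 rather than 0.3, so the matched time can lag the observation time by up to one timestep.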

"""
import jax
import jax.numpy as jnp
from rodeo.solve import solve_mv


def basic(key, ode_fun, ode_weight, ode_init,
          t_min, t_max, n_steps,
          interrogate, prior_pars,
          obs_data, obs_times, obs_loglik,
          kalman_type="standard",
          **params):
    r"""
    Basic algorithm to compute the approximate loglikelihood :math:`\log p(Y_{0:M} \mid Z_{1:N})`.

    Args:
        key (PRNGKey): PRNG key.
        ode_fun (Callable): Higher order ODE function :math:`W X_t = F(X_t, t)` taking arguments :math:`X` and :math:`t`.
        ode_weight (ndarray(n_block, n_bmeas, n_bstate)): Weight matrix defining the measure prior; :math:`W`.
        ode_init (ndarray(n_block, n_bstate)): Initial value of the state variable :math:`X_t` at time :math:`t = a`.
        t_min (float): First time point of the time interval to be evaluated; :math:`a`.
        t_max (float): Last time point of the time interval to be evaluated; :math:`b`.
        n_steps (int): Number of discretization steps :math:`N` of the time interval, so that the solver grid has :math:`N+1` points and the timestep is :math:`dt = (b-a)/N`.
        interrogate (Callable): Function defining the interrogation method.
        prior_pars (tuple): A tuple containing the weight matrix and the variance matrix defining the solution prior; :math:`Q, R`.
        obs_data (ndarray(n_obs, n_bobs)): Observed data; :math:`Y_{0:M}`.
        obs_times (ndarray(n_obs)): Observation times of :math:`Y_0, \ldots, Y_M`.
        obs_loglik (Callable): Observation loglikelihood function, called as ``obs_loglik(obs_data, ode_data, **params)``.
        kalman_type (str): Type of Kalman filter to use (standard or square-root).
        params (kwargs): Optional model parameters, passed to both the solver and ``obs_loglik``.

    Returns:
        (tuple): A tuple ``(loglik, Xt)`` containing the approximate loglikelihood :math:`\log p(Y_{0:M} \mid Z_{1:N})` and the posterior mean :math:`\mu_{0:N|N}` computed by ``solve_mv``.

    """
    # Posterior mean of the state at the N + 1 solver time points.
    Xt, _ = solve_mv(
        key=key,
        ode_fun=ode_fun,
        ode_weight=ode_weight,
        ode_init=ode_init,
        t_min=t_min,
        t_max=t_max,
        n_steps=n_steps,
        interrogate=interrogate,
        prior_pars=prior_pars,
        kalman_type=kalman_type,
        **params
    )
    # Match each observation time to the first solver time point at or after it.
    sim_times = jnp.linspace(t_min, t_max, n_steps+1)
    ode_data = Xt[jnp.searchsorted(sim_times, obs_times)]
    return obs_loglik(obs_data, ode_data, **params), Xt