# Source code for rodeo.inference.magi

import jax
import jax.numpy as jnp
from rodeo.kalmantv import standard
from rodeo.kalmantv import square_root

def magi_logdens(ode_data_subset, ode_expand, n_active, prior_pars,
                 kalman_type, **params):
    """
    Log-density of MAGI approximation.

    Args:
        ode_data_subset (ndarray(n_steps+1, n_block, n_deriv-1)): Array
            specifying :math:`U_{0:N}`, the subset of the solution process
            needed to reconstruct the entire solution with `ode_expand()`.
        ode_expand (Callable): Function taking inputs `ode_data_subset` and
            `**params` and returning the full solution process
            :math:`X_{0:N}`.
        n_active (int): Number of active derivatives -- i.e., not those
            zero-padded -- for the solution process.
        prior_pars (tuple): A tuple containing the weight matrix and the
            variance matrix defining the solution prior; :math:`Q, R`.
        kalman_type (str): Determine which type of Kalman (standard,
            square-root) to use.
        **params (kwargs): Parameters to pass to `ode_expand`.

    Returns:
        (float): Value of the logdensity
        `p(ode_data_subset, Z = 0 | params, prior_pars)`.

    Raises:
        NotImplementedError: If `kalman_type` is neither ``"standard"`` nor
            ``"square-root"``.
    """
    # standard or square-root filter; an if/elif chain so that only the
    # selected backend module is ever looked up
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError(
            f"Unknown kalman_type {kalman_type!r}; "
            "expected 'standard' or 'square-root'."
        )

    # setup: expand the subset into the full solution process; the indexing
    # below treats it as (time, variable, derivative)
    n_vars = ode_data_subset.shape[1]
    ode_state = ode_expand(ode_data_subset, **params)
    n_deriv = ode_state.shape[2]

    # construct `*_meas` parameters, one copy per variable for `jax.vmap`.
    # The measurement model has zero mean and zero variance (noise-free) and
    # selects the first `n_active` components of each variable's state.
    wgt_meas = jnp.broadcast_to(
        jnp.eye(n_active, n_deriv), (n_vars, n_active, n_deriv)
    )
    mean_meas = jnp.zeros((n_vars, n_active))
    var_meas = jnp.zeros((n_vars, n_active, n_active))

    # construct remaining `*_state` parameters
    mean_state = jnp.zeros((n_vars, n_deriv))
    wgt_state, var_state = prior_pars

    def filter_scan(carry, x_meas):
        """Run one predict/forecast/update step for every variable.

        Adds the forecast log-density of `x_meas` to the running total.
        """
        mean_state_past, var_state_past = carry["state"]
        # kalman predict
        mean_state_pred, var_state_pred = jax.vmap(kalman_funs.predict)(
            mean_state_past=mean_state_past,
            var_state_past=var_state_past,
            mean_state=mean_state,
            wgt_state=wgt_state,
            var_state=var_state,
        )
        # kalman forecast (for logdens)
        mean_state_fore, var_state_fore = jax.vmap(kalman_funs.forecast)(
            mean_state_pred=mean_state_pred,
            var_state_pred=var_state_pred,
            mean_meas=mean_meas,
            wgt_meas=wgt_meas,
            var_meas=var_meas,
        )
        # NOTE(review): `var_state_fore` is passed to `logpdf` as a full
        # covariance. If `square_root.forecast` returns a Cholesky factor
        # instead, it would need to be squared (L @ L.T) here when
        # kalman_type == "square-root" -- confirm against rodeo.kalmantv.
        logdens = jax.vmap(jax.scipy.stats.multivariate_normal.logpdf)(
            x=x_meas, mean=mean_state_fore, cov=var_state_fore
        )
        # kalman update
        mean_state_next, var_state_next = jax.vmap(kalman_funs.update)(
            mean_state_pred=mean_state_pred,
            var_state_pred=var_state_pred,
            x_meas=x_meas,
            mean_meas=mean_meas,
            wgt_meas=wgt_meas,
            var_meas=var_meas,
        )
        # return a fresh carry instead of mutating the one scan passed in
        new_carry = {
            "state": (mean_state_next, var_state_next),
            "logdens": carry["logdens"] + jnp.sum(logdens),
        }
        return new_carry, None

    # the filter starts from ode_state[0] with zero variance, i.e. the
    # initial state is treated as known exactly
    filter_init = {
        "state": (ode_state[0], jnp.zeros((n_vars, n_deriv, n_deriv))),
        "logdens": 0.0,
    }
    # condition on the first `n_active` derivatives at steps 1..N
    res, _ = jax.lax.scan(
        f=filter_scan,
        init=filter_init,
        xs=ode_state[1:, :, :n_active],
    )
    return res["logdens"]