# Source code for rodeo.inference.dalton

r"""
This module implements the DALTON solver which gives an approximate likelihood of :math:`p(Y_{0:M} \mid Z_{1:N})`.

The model is

.. math::

    x_0 = v

    X_n = c_n + Q_n X_{n-1} + R_n^{1/2} \epsilon_n

    Z_n = W_n X_n - f(X_n, t_n) + V_n^{1/2} \eta_n.
    
    Y_m = g(X_m, \phi_m)

where :math:`g` is a general distribution function. In the case that :math:`g` is Gaussian, use :func:`dalton` for a better approximation. In other cases, use :func:`daltonng`. We assume that :math:`c_n = 0, Q_n = Q, R_n = R`, and :math:`W_n = W` for all :math:`n`.

In the Gaussian case, we assume the observation model is

.. math::

    Y_m = D_m X_m + \Omega^{1/2}_m \epsilon_m.

We assume that :math:`M \leq N`, so that the observation step size is at least as large as the evaluation step size.

"""
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from rodeo.kalmantv import standard
from rodeo.kalmantv import square_root
from rodeo.inference.fenrir import _forecast_update
from rodeo.utils import multivariate_normal_logpdf
from rodeo.solve import _solve_filter as _solve_filter_ode

# --- loglikelihood -----------------------------------------------------------


def dalton(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps,
           interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var,
           kalman_type="standard", **params):
    r"""
    Compute marginal loglikelihood of DALTON algorithm for Gaussian observations; :math:`p(Y_{0:M} \mid Z_{1:N})`.

    Two Kalman filters are advanced together over the solver grid: a joint one
    that conditions on both the ODE pseudo-measurements :math:`Z` and the
    observations :math:`Y`, and a marginal one that conditions on :math:`Z`
    only.  The returned value is the difference of their accumulated
    log-densities.

    Args:
        key (PRNGKey): PRNG key. If ``None``, zeros are passed to ``interrogate`` in place of keys.
        ode_fun (Callable): Higher order ODE function :math:`W X_t = F(X_t, t)` taking arguments :math:`X` and :math:`t`.
        ode_weight (ndarray(n_block, n_bmeas, n_bstate)): Weight matrix defining the measure prior; :math:`W`.
        ode_init (ndarray(n_block, n_bstate)): Initial value of the state variable :math:`X_t` at time :math:`t = a`.
        t_min (float): First time point of the time interval to be evaluated; :math:`a`.
        t_max (float): Last time point of the time interval to be evaluated; :math:`b`.
        n_steps (int): Number of discretization points (:math:`N`) of the time interval that is evaluated, such that discretization timestep is :math:`dt = (b-a)/N`.
        interrogate (Callable): Function defining the interrogation method. It is called with keywords ``key, ode_fun, ode_weight, t, mean_state_pred, var_state_pred`` plus ``**params`` and must return ``(wgt_meas, mean_meas, var_meas)``.
        prior_pars (tuple): A tuple containing the weight matrix and the variance matrix defining the solution prior; :math:`Q, R`.
        obs_data (ndarray(n_obs, n_blocks, n_bobs)): Observed data; :math:`Y_{0:M}`.
        obs_times (ndarray(n_obs)): Observation time; :math:`0, \ldots, M`.
        obs_weight (ndarray(n_obs, n_blocks, n_bobs, n_bstate)): Weight matrix in the observation model; :math:`D_{0:M}`.
        obs_var (ndarray(n_obs, n_blocks, n_bobs, n_bobs)): Variance matrix in the observation model; :math:`\Omega_{0:M}`.
        kalman_type (str): Determine which type of Kalman (standard, square-root) to use.
        params (kwargs): Optional model parameters.

    Returns:
        (float): Loglikelihood of :math:`p(Y_{0:M} \mid Z_{1:N})`.

    Raises:
        NotImplementedError: If ``kalman_type`` is neither ``"standard"`` nor ``"square-root"``.

    """
    # Dimensions of block, measure and state variables, and of the observations
    n_block, n_bmeas, n_bstate = ode_weight.shape
    n_bobs = obs_weight.shape[2]

    # Kalman backend: standard or square-root
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError

    prior_weight, prior_var = prior_pars

    # position of every observation on the solver time grid
    grid = jnp.linspace(t_min, t_max, n_steps + 1)
    obs_ind = jnp.searchsorted(grid, obs_times)

    # constant ingredients of the Kalman recursions
    zero_meas = jnp.zeros((n_block, n_bmeas))
    obs_mean = jnp.zeros((n_block, n_bobs))
    zero_state = jnp.zeros((n_block, n_bstate))
    var_state_init = jnp.zeros((n_block, n_bstate, n_bstate))

    def forecast_update(mean_state_pred, var_state_pred,
                        x_meas, mean_meas, wgt_meas, var_meas):
        # bind the Kalman backend so the function can be vmapped over blocks
        return _forecast_update(mean_state_pred, var_state_pred,
                                x_meas, mean_meas, wgt_meas, var_meas,
                                kalman_funs)

    def predict(mean_filt, var_filt):
        # one-step prediction under the (time-homogeneous) solution prior
        return jax.vmap(kalman_funs.predict)(
            mean_state_past=mean_filt,
            var_state_past=var_filt,
            mean_state=zero_state,
            wgt_state=prior_weight,
            var_state=prior_var
        )

    def linearize(step_key, ode_time, mean_pred, var_pred):
        # interrogate the ODE and fold the linearization into the measure weight
        wgt_meas, mean_meas, var_meas = interrogate(
            key=step_key,
            ode_fun=ode_fun,
            ode_weight=ode_weight,
            t=ode_time,
            mean_state_pred=mean_pred,
            var_state_pred=var_pred,
            **params
        )
        return ode_weight + wgt_meas, mean_meas, var_meas

    # accumulate log p(Z_{1:N}, Y_{0:M}) and log p(Z_{1:N}) in a single sweep
    def step(carry, xs):
        i_obs = carry["i"]
        step_keys = xs["key"]
        t_next = xs["t"] + 1
        ode_time = t_min + (t_max-t_min)*t_next/n_steps

        # ---- joint filter ---------------------------------------------------
        mean_pred_zy, var_pred_zy = predict(*carry["state_filt_joint"])
        W_zy, mean_meas_zy, var_meas_zy = linearize(
            step_keys[0], ode_time, mean_pred_zy, var_pred_zy
        )

        # both z and y are observed: stack the two measurement models
        def update_zy():
            logp, mean_next, var_next = jax.vmap(forecast_update)(
                mean_state_pred=mean_pred_zy,
                var_state_pred=var_pred_zy,
                x_meas=jnp.concatenate([zero_meas, obs_data[i_obs]], axis=1),
                mean_meas=jnp.concatenate([mean_meas_zy, obs_mean], axis=1),
                wgt_meas=jnp.concatenate([W_zy, obs_weight[i_obs]], axis=1),
                var_meas=jax.vmap(jsp.linalg.block_diag)(var_meas_zy, obs_var[i_obs])
            )
            return mean_next, var_next, jnp.sum(logp), i_obs + 1

        # only z is observed
        def update_z():
            logp, mean_next, var_next = jax.vmap(forecast_update)(
                mean_state_pred=mean_pred_zy,
                var_state_pred=var_pred_zy,
                x_meas=zero_meas,
                mean_meas=mean_meas_zy,
                wgt_meas=W_zy,
                var_meas=var_meas_zy
            )
            return mean_next, var_next, jnp.sum(logp), i_obs

        mean_next_zy, var_next_zy, logp_zy, i_next = jax.lax.cond(
            t_next == obs_ind[i_obs], update_zy, update_z
        )

        # ---- marginal filter ------------------------------------------------
        mean_pred_z, var_pred_z = predict(*carry["state_filt_marg"])
        W_z, mean_meas_z, var_meas_z = linearize(
            step_keys[1], ode_time, mean_pred_z, var_pred_z
        )
        logp_z, mean_next_z, var_next_z = jax.vmap(forecast_update)(
            mean_state_pred=mean_pred_z,
            var_state_pred=var_pred_z,
            x_meas=zero_meas,
            mean_meas=mean_meas_z,
            wgt_meas=W_z,
            var_meas=var_meas_z
        )

        next_carry = {
            "state_filt_joint": (mean_next_zy, var_next_zy),
            "state_filt_marg": (mean_next_z, var_next_z),
            "logdens_joint": carry["logdens_joint"] + logp_zy,
            "logdens_marg": carry["logdens_marg"] + jnp.sum(logp_z),
            "i": i_next
        }
        return next_carry, None

    # log-density of p(Y_0 | X_0) when the first observation sits at time 0
    def logy0():
        def block_logpdf(b):
            return multivariate_normal_logpdf(
                obs_data[0, b],
                mean=obs_weight[0, b].dot(ode_init[b]) + obs_mean[b],
                cov=obs_var[0, b]
            )
        return jnp.sum(jax.vmap(block_logpdf)(jnp.arange(n_block))), 1

    def no_logy0():
        return 0.0, 0

    logdens_init, i_init = jax.lax.cond(obs_ind[0] == 0, logy0, no_logy0)

    scan_init = {
        "state_filt_joint": (ode_init, var_state_init),
        "state_filt_marg": (ode_init, var_state_init),
        "logdens_joint": logdens_init,
        "logdens_marg": 0.0,
        "i": i_init
    }

    # two keys per step: one for the joint filter, one for the marginal filter
    if key is None:
        keys = jnp.zeros((n_steps, 2))
    else:
        keys = jax.random.split(key, num=(n_steps, 2))

    scan_xs = {
        "t": jnp.arange(n_steps),
        "key": keys
    }
    final, _ = jax.lax.scan(step, scan_init, scan_xs)
    return final["logdens_joint"] - final["logdens_marg"]
# --- ODE solver --------------------------------------------------------------

# use linearizations and observations
def _solve_filter(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps,
                  interrogate, prior_weight, prior_var,
                  obs_data, obs_times, obs_weight, obs_var,
                  kalman_funs, **params):
    r"""
    Forward pass of the DALTON algorithm with Gaussian observations.

    Same arguments as :func:`dalton`, except that the prior is passed as the
    separate arrays ``prior_weight`` and ``prior_var`` and the Kalman backend
    is passed directly as ``kalman_funs`` (an object providing ``predict`` and
    ``update``).

    Returns:
        (dict): Dictionary with two entries, each a ``(mean, var)`` tuple.

        - **state_pred**: ``mean`` (ndarray(n_steps+1, n_block, n_bstate)) and ``var`` (ndarray(n_steps+1, n_block, n_bstate, n_bstate)) of the state at time t given observations from times [a...t-1] for :math:`t \in [a, b]`.
        - **state_filt**: ``mean`` (ndarray(n_steps+1, n_block, n_bstate)) and ``var`` (ndarray(n_steps+1, n_block, n_bstate, n_bstate)) of the state at time t given observations from times [a...t] for :math:`t \in [a, b]`.

        In both entries the first time slot holds the initial value ``ode_init`` with zero variance.

    """
    # Dimensions of block, measure and state variables, and of the observations
    n_block, n_bmeas, n_bstate = ode_weight.shape
    n_bobs = obs_weight.shape[2]

    # position of every observation on the solver time grid
    grid = jnp.linspace(t_min, t_max, n_steps + 1)
    obs_ind = jnp.searchsorted(grid, obs_times)

    # constant ingredients of the Kalman recursions
    zero_meas = jnp.zeros((n_block, n_bmeas))
    obs_mean = jnp.zeros((n_block, n_bobs))
    zero_state = jnp.zeros((n_block, n_bstate))
    var_state_init = jnp.zeros((n_block, n_bstate, n_bstate))

    # compute p(X_{1:n} | Z_{1:n}, Y_{0:m})
    def step(carry, xs):
        mean_filt, var_filt = carry["state_filt"]
        i_obs = carry["i"]
        t_next = xs["t"] + 1
        ode_time = t_min + (t_max-t_min)*t_next/n_steps

        # kalman predict
        mean_pred, var_pred = jax.vmap(kalman_funs.predict)(
            mean_state_past=mean_filt,
            var_state_past=var_filt,
            mean_state=zero_state,
            wgt_state=prior_weight,
            var_state=prior_var
        )

        # linearize the ODE around the prediction
        wgt_meas, mean_meas, var_meas = interrogate(
            key=xs["key"],
            ode_fun=ode_fun,
            ode_weight=ode_weight,
            t=ode_time,
            mean_state_pred=mean_pred,
            var_state_pred=var_pred,
            **params
        )
        W_meas = ode_weight + wgt_meas

        def kalman_update(x, mean, wgt, var):
            # blockwise Kalman update against the current prediction
            return jax.vmap(kalman_funs.update)(
                mean_state_pred=mean_pred,
                var_state_pred=var_pred,
                x_meas=x,
                mean_meas=mean,
                wgt_meas=wgt,
                var_meas=var
            )

        # both z and y are observed: stack the two measurement models
        def update_zy():
            mean_next, var_next = kalman_update(
                jnp.concatenate([zero_meas, obs_data[i_obs]], axis=1),
                jnp.concatenate([mean_meas, obs_mean], axis=1),
                jnp.concatenate([W_meas, obs_weight[i_obs]], axis=1),
                jax.vmap(jsp.linalg.block_diag)(var_meas, obs_var[i_obs])
            )
            return mean_next, var_next, i_obs + 1

        # only z is observed
        def update_z():
            mean_next, var_next = kalman_update(zero_meas, mean_meas, W_meas, var_meas)
            return mean_next, var_next, i_obs

        mean_next, var_next, i_next = jax.lax.cond(
            t_next == obs_ind[i_obs], update_zy, update_z
        )

        filt_next = (mean_next, var_next)
        next_carry = {
            "state_filt": filt_next,
            "i": i_next
        }
        stacked = {
            "state_filt": filt_next,
            "state_pred": (mean_pred, var_pred)
        }
        return next_carry, stacked

    # an observation at time 0 is skipped by the filter, so start the counter past it
    i_init = jax.lax.cond(obs_ind[0] == 0, lambda: 1, lambda: 0)

    # scan initial value for computing p(X_{0:n} | Y_{0:m}, Z_{1:n})
    scan_init = {
        "state_filt": (ode_init, var_state_init),
        "i": i_init
    }

    if key is None:
        keys = jnp.zeros(n_steps)
    else:
        keys = jax.random.split(key, num=n_steps)

    scan_xs = {
        "t": jnp.arange(n_steps),
        "key": keys
    }
    _, scan_out = jax.lax.scan(step, scan_init, scan_xs)

    def with_init(first, rest):
        # put the known initial value in front of the scanned time series
        return jnp.concatenate([first[None], rest])

    filt_mean, filt_var = scan_out["state_filt"]
    pred_mean, pred_var = scan_out["state_pred"]
    return {
        "state_filt": (with_init(ode_init, filt_mean),
                       with_init(var_state_init, filt_var)),
        "state_pred": (with_init(ode_init, pred_mean),
                       with_init(var_state_init, pred_var))
    }
def solve_mv(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps,
             interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var,
             kalman_type="standard", **params):
    r"""
    DALTON algorithm to compute the mean and variance of :math:`p(X_{0:N} \mid Y_{0:M}, Z_{1:N})` assuming Gaussian observations.

    Runs the forward filter :func:`_solve_filter` and then a backward
    smoothing pass with ``kalman_funs.smooth_mv``.

    Same arguments as :func:`dalton`.

    Returns:
        (tuple):
        - **mean_state_smooth** (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process :math:`X_t` at times :math:`t \in [a, b]`.
        - **var_state_smooth** (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times :math:`t \in [a, b]`.

    Raises:
        NotImplementedError: If ``kalman_type`` is neither ``"standard"`` nor ``"square-root"``.

    """
    prior_weight, prior_var = prior_pars
    n_block, n_bstate, _ = prior_weight.shape

    # standard or square-root filter
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError

    # forward pass
    filt_out = _solve_filter(
        key=key,
        ode_fun=ode_fun, ode_weight=ode_weight, ode_init=ode_init,
        t_min=t_min, t_max=t_max, n_steps=n_steps,
        interrogate=interrogate,
        prior_weight=prior_weight, prior_var=prior_var,
        obs_data=obs_data, obs_times=obs_times,
        obs_weight=obs_weight, obs_var=obs_var,
        kalman_funs=kalman_funs, **params
    )
    mean_state_pred, var_state_pred = filt_out["state_pred"]
    mean_state_filt, var_state_filt = filt_out["state_filt"]

    # backward pass
    def scan_fun(state_next, smooth_kwargs):
        mean_state_curr, var_state_curr = jax.vmap(kalman_funs.smooth_mv)(
            mean_state_next=state_next["mean"],
            var_state_next=state_next["var"],
            wgt_state=prior_weight,
            mean_state_filt=smooth_kwargs['mean_state_filt'],
            var_state_filt=smooth_kwargs['var_state_filt'],
            mean_state_pred=smooth_kwargs['mean_state_pred'],
            var_state_pred=smooth_kwargs['var_state_pred'],
            # keyword must be lower-case `var_state`, as in the sibling solvers
            var_state=prior_var
        )
        state_curr = {
            "mean": mean_state_curr,
            "var": var_state_curr
        }
        return state_curr, state_curr

    # the smoother starts from the last filtered state
    scan_init = {
        "mean": mean_state_filt[n_steps],
        "var": var_state_filt[n_steps]
    }
    # filtered moments at times 1..N-1 paired with predicted moments at 2..N
    scan_kwargs = {
        'mean_state_filt': mean_state_filt[1:n_steps],
        'var_state_filt': var_state_filt[1:n_steps],
        'mean_state_pred': mean_state_pred[2:n_steps+1],
        'var_state_pred': var_state_pred[2:n_steps+1]
    }

    # Note: initial value x0 is assumed to be known, so no need to smooth it
    _, scan_out = jax.lax.scan(scan_fun, scan_init, scan_kwargs,
                               reverse=True)

    # append initial values to front and terminal values to back
    mean_state_smooth = jnp.concatenate(
        [ode_init[None], scan_out["mean"], scan_init["mean"][None]]
    )
    var_state_smooth = jnp.concatenate(
        [jnp.zeros((n_block, n_bstate, n_bstate))[None], scan_out["var"],
         scan_init["var"][None]]
    )
    return mean_state_smooth, var_state_smooth
def solve_sim(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps,
              interrogate, prior_pars, obs_data, obs_times, obs_weight, obs_var,
              kalman_type="standard", **params):
    r"""
    DALTON algorithm to sample from :math:`p(X_{0:N} \mid Y_{0:M}, Z_{1:N})` assuming Gaussian observations.

    Runs the forward filter :func:`_solve_filter`, draws the terminal state
    from the last filtering distribution and then samples backwards in time
    using ``kalman_funs.smooth_sim``.

    Same arguments as :func:`dalton`, except that ``key`` is always split and
    therefore must be a valid PRNG key (not ``None``).

    Returns:
        (ndarray(n_steps+1, n_block, n_bstate)): Sample of the solution process :math:`X_t` at times :math:`t \in [a, b]`; the first entry is ``ode_init``.

    Raises:
        NotImplementedError: If ``kalman_type`` is neither ``"standard"`` nor ``"square-root"``.

    """
    prior_weight, prior_var = prior_pars

    # standard or square-root filter
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError

    # one key for the forward pass, n_steps keys for the backward draws;
    # kept as a single array so slices can be fed straight to `scan`
    all_keys = jax.random.split(key, num=n_steps+1)
    key = all_keys[0]
    subkeys = all_keys[1:]

    # forward pass
    filt_out = _solve_filter(
        key=key,
        ode_fun=ode_fun, ode_weight=ode_weight, ode_init=ode_init,
        t_min=t_min, t_max=t_max, n_steps=n_steps,
        interrogate=interrogate,
        prior_weight=prior_weight, prior_var=prior_var,
        obs_data=obs_data, obs_times=obs_times,
        obs_weight=obs_weight, obs_var=obs_var,
        kalman_funs=kalman_funs, **params
    )
    mean_state_pred, var_state_pred = filt_out["state_pred"]
    mean_state_filt, var_state_filt = filt_out["state_filt"]

    # backward pass
    def scan_fun(x_state_next, smooth_kwargs):
        mean_state_sim, var_state_sim = jax.vmap(kalman_funs.smooth_sim)(
            x_state_next=x_state_next,
            wgt_state=prior_weight,
            mean_state_filt=smooth_kwargs['mean_state_filt'],
            var_state_filt=smooth_kwargs['var_state_filt'],
            mean_state_pred=smooth_kwargs['mean_state_pred'],
            var_state_pred=smooth_kwargs['var_state_pred'],
            var_state=prior_var
        )
        x_state_curr = jax.random.multivariate_normal(
            smooth_kwargs['key'], mean_state_sim, var_state_sim, method='svd')
        return x_state_curr, x_state_curr

    # draw the terminal state from the last filtering distribution
    scan_init = jax.random.multivariate_normal(
        subkeys[n_steps-1],
        mean_state_filt[n_steps],
        var_state_filt[n_steps],
        method='svd')

    # filtered moments at times 1..N-1 paired with predicted moments at 2..N
    scan_kwargs = {
        'mean_state_filt': mean_state_filt[1:n_steps],
        'var_state_filt': var_state_filt[1:n_steps],
        'mean_state_pred': mean_state_pred[2:n_steps+1],
        'var_state_pred': var_state_pred[2:n_steps+1],
        'key': subkeys[:n_steps-1]
    }

    # Note: initial value x0 is assumed to be known, so no need to smooth it
    _, scan_out = jax.lax.scan(scan_fun, scan_init, scan_kwargs,
                               reverse=True)

    # append initial value to front and terminal draw to back
    x_state_smooth = jnp.concatenate(
        [ode_init[None], scan_out, scan_init[None]]
    )
    return x_state_smooth
# --- non-Gaussian loglikelihood -------------------------------------------------

def _solve_filter_nn(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps,
                     interrogate, prior_weight, prior_var,
                     obs_data, obs_times, obs_loglik_i,
                     kalman_funs, **params):
    r"""
    Forward pass of the DALTON algorithm using non-Gaussian observations.

    At every observation time the observation loglikelihood is expanded to
    second order around the predicted state mean, which yields a Gaussian
    pseudo-observation :math:`\hat Y` that is stacked with the ODE
    pseudo-measurement in the Kalman update.

    Same arguments as :func:`daltonng`, except that the prior is passed as the
    separate arrays ``prior_weight`` and ``prior_var`` and the Kalman backend
    is passed directly as ``kalman_funs`` (an object providing ``predict`` and
    ``update``).

    Returns:
        (dict): Dictionary with two entries, each a ``(mean, var)`` tuple.

        - **state_pred**: ``mean`` (ndarray(n_steps+1, n_block, n_bstate)) and ``var`` (ndarray(n_steps+1, n_block, n_bstate, n_bstate)) of the state at time t given observations from times [a...t-1] for :math:`t \in [a, b]`.
        - **state_filt**: ``mean`` (ndarray(n_steps+1, n_block, n_bstate)) and ``var`` (ndarray(n_steps+1, n_block, n_bstate, n_bstate)) of the state at time t given observations from times [a...t] for :math:`t \in [a, b]`.

        In both entries the first time slot holds the initial value ``ode_init`` with zero variance.

    """
    # Dimensions of block, state and measure variables
    n_block, n_bmeas, n_bstate = ode_weight.shape

    # insert observations on solver time grid
    sim_times = jnp.linspace(t_min, t_max, n_steps + 1)
    obs_ind = jnp.searchsorted(sim_times, obs_times)

    # arguments for kalman_filter and kalman_smooth
    x_meas = jnp.zeros((n_block, n_bmeas))
    # the pseudo-observation lives in state space, hence n_bstate
    obs_mean = jnp.zeros((n_block, n_bstate))
    mean_state = jnp.zeros((n_block, n_bstate))
    mean_state_init = ode_init
    var_state_init = jnp.zeros((n_block, n_bstate, n_bstate))

    # compute p(X_{1:n} | Z_{1:n}, \hat Y_{0:m})
    def scan_fun(carry, filter_kwargs):
        mean_state_filt, var_state_filt = carry["state_filt"]
        i = carry["i"]
        t = filter_kwargs["t"]
        key = filter_kwargs["key"]
        ode_time = t_min + (t_max-t_min)*(t+1)/n_steps

        # kalman predict
        mean_state_pred, var_state_pred = jax.vmap(kalman_funs.predict)(
            mean_state_past=mean_state_filt,
            var_state_past=var_state_filt,
            mean_state=mean_state,
            wgt_state=prior_weight,
            var_state=prior_var
        )
        # compute meas parameters
        wgt_meas, mean_meas, var_meas = interrogate(
            key=key,
            ode_fun=ode_fun,
            ode_weight=ode_weight,
            t=ode_time,
            mean_state_pred=mean_state_pred,
            var_state_pred=var_state_pred,
            **params
        )
        W_meas = ode_weight + wgt_meas

        # both z and y are observed
        def zy_update():
            # transform to yhat: gradient and Hessian of the observation
            # loglikelihood with respect to the state, at the predicted mean
            obs_grad = jax.jacrev(obs_loglik_i, argnums=1)(
                obs_data[i], mean_state_pred, i, **params
            )
            obs_hes = jax.jacfwd(jax.jacrev(obs_loglik_i, argnums=1), argnums=1)(
                obs_data[i], mean_state_pred, i, **params
            )
            # per-block pseudo-observation variance: negative pseudo-inverse
            # of the block-diagonal part of the Hessian
            obs_var = jax.vmap(
                lambda b: -jnp.linalg.pinv(obs_hes[b, :, b])
            )(jnp.arange(n_block))
            obs_weight = jnp.where(obs_var != 0, 1, 0)
            # `obs_weight` is indexed by block, so use the block index `b`
            # (not the observation counter `i`) inside the per-block map
            obs_hat = jax.vmap(
                lambda b: obs_weight[b].dot(mean_state_pred[b])
                + obs_var[b].dot(obs_grad[b])
            )(jnp.arange(n_block))

            # stack measure and observation variables
            wgt_meas_obs = jnp.concatenate([W_meas, obs_weight], axis=1)
            mean_meas_obs = jnp.concatenate([mean_meas, obs_mean], axis=1)
            var_meas_obs = jax.vmap(
                lambda b: jsp.linalg.block_diag(var_meas[b], obs_var[b])
            )(jnp.arange(n_block))
            x_meas_obs = jnp.concatenate([x_meas, obs_hat], axis=1)

            # kalman update
            mean_state_next, var_state_next = jax.vmap(kalman_funs.update)(
                mean_state_pred=mean_state_pred,
                var_state_pred=var_state_pred,
                x_meas=x_meas_obs,
                mean_meas=mean_meas_obs,
                wgt_meas=wgt_meas_obs,
                var_meas=var_meas_obs
            )
            return mean_state_next, var_state_next, i+1

        # only z is observed
        def z_update():
            # kalman update
            mean_state_next, var_state_next = jax.vmap(kalman_funs.update)(
                mean_state_pred=mean_state_pred,
                var_state_pred=var_state_pred,
                x_meas=x_meas,
                mean_meas=mean_meas,
                wgt_meas=W_meas,
                var_meas=var_meas
            )
            return mean_state_next, var_state_next, i

        mean_state_next, var_state_next, i = jax.lax.cond(
            t+1 == obs_ind[i], zy_update, z_update
        )

        # output
        carry = {
            "state_filt": (mean_state_next, var_state_next),
            "i": i
        }
        stack = {
            "state_filt": (mean_state_next, var_state_next),
            "state_pred": (mean_state_pred, var_state_pred)
        }
        return carry, stack

    # an observation at time 0 is skipped by the filter, so start the counter past it
    i = jax.lax.cond(obs_ind[0] == 0, lambda: 1, lambda: 0)

    # scan initial value for computing p(X_{0:n} | \hat Y_{0:m}, Z_{1:n})
    scan_init = {
        "state_filt": (mean_state_init, var_state_init),
        "i": i
    }

    if key is not None:
        keys = jax.random.split(key, num=n_steps)
    else:
        keys = jnp.zeros(n_steps)

    filter_kwargs = {
        "t": jnp.arange(n_steps),
        "key": keys
    }
    _, scan_out = jax.lax.scan(scan_fun, scan_init, filter_kwargs)

    # append initial values to front
    scan_out["state_filt"] = (
        jnp.concatenate([mean_state_init[None], scan_out["state_filt"][0]]),
        jnp.concatenate([var_state_init[None], scan_out["state_filt"][1]])
    )
    scan_out["state_pred"] = (
        jnp.concatenate([mean_state_init[None], scan_out["state_pred"][0]]),
        jnp.concatenate([var_state_init[None], scan_out["state_pred"][1]])
    )
    return scan_out


def _logx_yhat(mean_state_filt, var_state_filt,
               mean_state_pred, var_state_pred,
               prior_weight, prior_var, kalman_funs):
    r"""
    Compute the loglikelihood of :math:`p(X_{0:N} \mid \hat Y_{0:M}, Z_{1:N})`.

    The density is evaluated at the smoothed mean trajectory, which is
    returned alongside the log-density.

    Args:
        mean_state_filt (ndarray(n_steps+1, n_block, n_bstate)): Mean estimate for state at time n given observations from times [0...n]; denoted by :math:`\mu_{n|n}`.
        var_state_filt (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Covariance of estimate for state at time n given observations from times [0...n]; denoted by :math:`\Sigma_{n|n}`.
        mean_state_pred (ndarray(n_steps+1, n_block, n_bstate)): Mean estimate for state at time n given observations from times [0...n-1]; denoted by :math:`\mu_{n|n-1}`.
        var_state_pred (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Covariance of estimate for state at time n given observations from times [0...n-1]; denoted by :math:`\Sigma_{n|n-1}`.
        prior_weight (ndarray(n_block, n_bstate, n_bstate)): Weight matrix defining the solution prior; :math:`Q`.
        prior_var (ndarray(n_block, n_bstate, n_bstate)): Variance matrix defining the solution prior; :math:`R`.
        kalman_funs (object): An object that contains the Kalman filtering functions; ``smooth_mv`` and ``smooth_sim`` are used here.

    Returns:
        (tuple):
        - **mean_state_smooth** (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process :math:`p(X_{0:N} \mid \hat Y_{0:M}, Z_{1:N})`.
        - **logx_yhat** (float): Loglikelihood of :math:`p(X_{0:N} \mid \hat Y_{0:M}, Z_{1:N})`.

    """
    # dimensions
    n_tot, _, _ = mean_state_filt.shape
    n_steps = n_tot - 1

    # backward pass
    def scan_fun(state_next, smooth_kwargs):
        mean_state_filt = smooth_kwargs['mean_state_filt']
        var_state_filt = smooth_kwargs['var_state_filt']
        mean_state_pred = smooth_kwargs['mean_state_pred']
        var_state_pred = smooth_kwargs['var_state_pred']

        # smoothed moments at the current time
        mean_state_curr, var_state_curr = jax.vmap(kalman_funs.smooth_mv)(
            mean_state_next=state_next["mean"],
            var_state_next=state_next["var"],
            mean_state_filt=mean_state_filt,
            var_state_filt=var_state_filt,
            mean_state_pred=mean_state_pred,
            var_state_pred=var_state_pred,
            wgt_state=prior_weight,
            var_state=prior_var
        )
        # conditional distribution of the current state given the next
        # smoothed mean
        mean_state_sim, var_state_sim = jax.vmap(kalman_funs.smooth_sim)(
            x_state_next=state_next["mean"],
            mean_state_filt=mean_state_filt,
            var_state_filt=var_state_filt,
            mean_state_pred=mean_state_pred,
            var_state_pred=var_state_pred,
            wgt_state=prior_weight,
            var_state=prior_var
        )
        logx_yhat = state_next["logx_yhat"] + jnp.sum(
            jax.vmap(multivariate_normal_logpdf)(
                mean_state_curr, mean=mean_state_sim, cov=var_state_sim)
        )
        carry = {
            "mean": mean_state_curr,
            "var": var_state_curr,
            "logx_yhat": logx_yhat
        }
        # only the mean trajectory is needed downstream
        return carry, mean_state_curr

    # log-density of the last filtering distribution evaluated at its own mean
    logx_yhatN = jnp.sum(
        jax.vmap(multivariate_normal_logpdf)(
            mean_state_filt[n_steps],
            mean=mean_state_filt[n_steps],
            cov=var_state_filt[n_steps])
    )

    # initialize
    scan_init = {
        "mean": mean_state_filt[n_steps],
        "var": var_state_filt[n_steps],
        "logx_yhat": logx_yhatN
    }
    # filtered moments at times 1..N-1 paired with predicted moments at 2..N
    scan_kwargs = {
        'mean_state_filt': mean_state_filt[1:n_steps],
        'var_state_filt': var_state_filt[1:n_steps],
        'mean_state_pred': mean_state_pred[2:n_steps+1],
        'var_state_pred': var_state_pred[2:n_steps+1]
    }

    # Note: initial value x0 is assumed to be known, so no need to smooth it
    last_scan, mean_stack = jax.lax.scan(scan_fun, scan_init, scan_kwargs,
                                         reverse=True)

    # append initial values to front and terminal value to the back
    mean_state_smooth = jnp.concatenate(
        [mean_state_filt[0][None], mean_stack, scan_init["mean"][None]]
    )
    return mean_state_smooth, last_scan["logx_yhat"]


def _logx_z(uncond_mean,
            mean_state_filt, var_state_filt,
            mean_state_pred, var_state_pred,
            prior_weight, prior_var, kalman_funs):
    r"""
    Compute the loglikelihood of :math:`p(X_{0:N} \mid Z_{1:N})`.

    The density is evaluated at the trajectory ``uncond_mean``.

    Args:
        uncond_mean (ndarray(n_steps+1, n_block, n_bstate)): Unconditional mean computed from :math:`p(X_{0:N} \mid \hat Y_{0:M}, Z_{1:N})`.
        mean_state_filt (ndarray(n_steps+1, n_block, n_bstate)): Mean estimate for state at time n given observations from times [0...n]; denoted by :math:`\mu_{n|n}`.
        var_state_filt (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Covariance of estimate for state at time n given observations from times [0...n]; denoted by :math:`\Sigma_{n|n}`.
        mean_state_pred (ndarray(n_steps+1, n_block, n_bstate)): Mean estimate for state at time n given observations from times [0...n-1]; denoted by :math:`\mu_{n|n-1}`.
        var_state_pred (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Covariance of estimate for state at time n given observations from times [0...n-1]; denoted by :math:`\Sigma_{n|n-1}`.
        prior_weight (ndarray(n_block, n_bstate, n_bstate)): Weight matrix defining the solution prior; :math:`Q`.
        prior_var (ndarray(n_block, n_bstate, n_bstate)): Variance matrix defining the solution prior; :math:`R`.
        kalman_funs (object): An object that contains the Kalman filtering functions; ``smooth_sim`` is used here.

    Returns:
        (float): Loglikelihood of :math:`p(X_{0:N} \mid Z_{1:N})`.

    """
    # dimensions
    n_tot, _, _ = mean_state_filt.shape
    n_steps = n_tot - 1

    # backward pass
    def scan_fun(logx_z, smooth_kwargs):
        # conditional distribution of the current state given the next one
        mean_state_sim, var_state_sim = jax.vmap(kalman_funs.smooth_sim)(
            x_state_next=smooth_kwargs['uncond_next'],
            mean_state_filt=smooth_kwargs['mean_state_filt'],
            var_state_filt=smooth_kwargs['var_state_filt'],
            mean_state_pred=smooth_kwargs['mean_state_pred'],
            var_state_pred=smooth_kwargs['var_state_pred'],
            wgt_state=prior_weight,
            var_state=prior_var
        )
        logx_z += jnp.sum(
            jax.vmap(multivariate_normal_logpdf)(
                smooth_kwargs['uncond_curr'],
                mean=mean_state_sim, cov=var_state_sim)
        )
        # only the accumulated total is used, so nothing is stacked
        return logx_z, None

    # log-density of the last filtering distribution at the terminal state
    logx_zN = jnp.sum(
        jax.vmap(multivariate_normal_logpdf)(
            uncond_mean[n_steps],
            mean=mean_state_filt[n_steps],
            cov=var_state_filt[n_steps])
    )

    # filtered moments at times 1..N-1 paired with predicted moments and
    # trajectory values at 2..N
    scan_kwargs = {
        'mean_state_filt': mean_state_filt[1:n_steps],
        'var_state_filt': var_state_filt[1:n_steps],
        'mean_state_pred': mean_state_pred[2:n_steps+1],
        'var_state_pred': var_state_pred[2:n_steps+1],
        'uncond_next': uncond_mean[2:n_steps+1],
        'uncond_curr': uncond_mean[1:n_steps]
    }

    # Note: initial value x0 is assumed to be known, so no need to smooth it
    logx_z, _ = jax.lax.scan(scan_fun, logx_zN, scan_kwargs, reverse=True)
    return logx_z
def daltonng(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps,
             interrogate, prior_pars, obs_data, obs_times, obs_loglik_i,
             kalman_type="standard", **params):
    r"""
    Compute marginal loglikelihood of DALTON algorithm for non-Gaussian observations; :math:`p(\hat Y_{0:M} \mid Z_{1:N})`.

    The result is assembled from three log-densities, all evaluated at the
    smoothed mean trajectory of the observation-informed solver:
    :math:`\log p(Y \mid X) + \log p(X \mid Z) - \log p(X \mid \hat Y, Z)`.

    Args:
        key (PRNGKey): PRNG key; the same key is handed to both forward passes.
        ode_fun (Callable): Higher order ODE function :math:`W X_t = F(X_t, t)` taking arguments :math:`X` and :math:`t`.
        ode_weight (ndarray(n_block, n_bmeas, n_bstate)): Weight matrix defining the measure prior; :math:`W`.
        ode_init (ndarray(n_block, n_bstate)): Initial value of the state variable :math:`X_t` at time :math:`t = a`.
        t_min (float): First time point of the time interval to be evaluated; :math:`a`.
        t_max (float): Last time point of the time interval to be evaluated; :math:`b`.
        n_steps (int): Number of discretization points (:math:`N`) of the time interval that is evaluated, such that discretization timestep is :math:`dt = (b-a)/N`.
        interrogate (Callable): Function defining the interrogation method.
        prior_pars (tuple): A tuple containing the weight matrix and the variance matrix defining the solution prior; :math:`Q, R`.
        obs_data (ndarray(n_obs, n_blocks, n_bobs)): Observed data; :math:`Y_{0:M}`.
        obs_times (ndarray(n_obs)): Observation time; :math:`0, \ldots, M`.
        obs_loglik_i (Callable): Loglikelihood function for each observation; called as ``obs_loglik_i(obs_data[i], x, i, **params)``.
        kalman_type (str): Determine which type of Kalman (standard, square-root) to use.
        params (kwargs): Optional model parameters.

    Returns:
        (float): Loglikelihood of :math:`p(\hat Y_{0:M} \mid Z_{1:N})`.

    Raises:
        NotImplementedError: If ``kalman_type`` is neither ``"standard"`` nor ``"square-root"``.

    """
    n_obs = obs_data.shape[0]

    # Kalman backend: standard or square-root
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError

    prior_weight, prior_var = prior_pars

    # arguments shared by the two forward passes
    solver_args = {
        "key": key,
        "ode_fun": ode_fun,
        "ode_weight": ode_weight,
        "ode_init": ode_init,
        "t_min": t_min,
        "t_max": t_max,
        "n_steps": n_steps,
        "interrogate": interrogate,
        "prior_weight": prior_weight,
        "prior_var": prior_var,
        "kalman_funs": kalman_funs
    }
    # arguments shared by the two backward passes
    prior_args = {
        "prior_weight": prior_weight,
        "prior_var": prior_var,
        "kalman_funs": kalman_funs
    }

    # logp(X_{0:N} | \hat Y_{0:M}, Z_{1:N}): forward pass with observations
    filt_obs = _solve_filter_nn(
        obs_data=obs_data, obs_times=obs_times, obs_loglik_i=obs_loglik_i,
        **solver_args, **params
    )
    pred_mean_obs, pred_var_obs = filt_obs["state_pred"]
    filt_mean_obs, filt_var_obs = filt_obs["state_filt"]
    mean_state_smooth, logx_yhat = _logx_yhat(
        mean_state_filt=filt_mean_obs,
        var_state_filt=filt_var_obs,
        mean_state_pred=pred_mean_obs,
        var_state_pred=pred_var_obs,
        **prior_args
    )

    # logp(Y_{0:M} | X_{0:M}): observation model at the smoothed mean
    grid = jnp.linspace(t_min, t_max, n_steps+1)
    obs_ind = jnp.searchsorted(grid, obs_times)

    def obs_logdens(i):
        return obs_loglik_i(obs_data[i], mean_state_smooth[obs_ind[i]],
                            i, **params)

    logy_x = jnp.sum(jax.vmap(obs_logdens)(jnp.arange(n_obs)))

    # logp(X_{0:N} | Z_{1:N}): forward pass without observations
    filt_ode = _solve_filter_ode(**solver_args, **params)
    pred_mean_ode, pred_var_ode = filt_ode["state_pred"]
    filt_mean_ode, filt_var_ode = filt_ode["state_filt"]
    logx_z = _logx_z(
        uncond_mean=mean_state_smooth,
        mean_state_filt=filt_mean_ode,
        var_state_filt=filt_var_ode,
        mean_state_pred=pred_mean_ode,
        var_state_pred=pred_var_ode,
        **prior_args
    )

    return logy_x + logx_z - logx_yhat
# --- non-Gaussian ODE solver -------------------------------------------------
def solve_mv_nn(key, ode_fun, ode_weight, ode_init, t_min, t_max, n_steps,
                interrogate, prior_pars, obs_data, obs_times, obs_loglik_i,
                kalman_type="standard", **params):
    r"""
    DALTON algorithm to compute the mean and variance of :math:`p(X_{0:N} \mid \hat Y_{0:M}, Z_{1:N})` assuming non-Gaussian observations.

    Runs the forward filter :func:`_solve_filter_nn` and then a backward
    smoothing pass with ``kalman_funs.smooth_mv``.

    Same arguments as :func:`daltonng`.

    Returns:
        (tuple):
        - **mean_state_smooth** (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process :math:`X_t` at times :math:`t \in [a, b]`.
        - **var_state_smooth** (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times :math:`t \in [a, b]`.

    Raises:
        NotImplementedError: If ``kalman_type`` is neither ``"standard"`` nor ``"square-root"``.

    """
    prior_weight, prior_var = prior_pars
    n_block, n_bstate, _ = prior_weight.shape

    # Kalman backend: standard or square-root
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError

    # forward pass
    forward = _solve_filter_nn(
        key=key,
        ode_fun=ode_fun, ode_weight=ode_weight, ode_init=ode_init,
        t_min=t_min, t_max=t_max, n_steps=n_steps,
        interrogate=interrogate,
        prior_weight=prior_weight, prior_var=prior_var,
        obs_data=obs_data, obs_times=obs_times, obs_loglik_i=obs_loglik_i,
        kalman_funs=kalman_funs, **params
    )
    pred_mean, pred_var = forward["state_pred"]
    filt_mean, filt_var = forward["state_filt"]

    # backward pass: one smoothing step from time n+1 to time n
    def smooth_step(later, moments):
        mean_now, var_now = jax.vmap(kalman_funs.smooth_mv)(
            mean_state_next=later["mean"],
            var_state_next=later["var"],
            wgt_state=prior_weight,
            mean_state_filt=moments['mean_state_filt'],
            var_state_filt=moments['var_state_filt'],
            mean_state_pred=moments['mean_state_pred'],
            var_state_pred=moments['var_state_pred'],
            var_state=prior_var
        )
        now = {
            "mean": mean_now,
            "var": var_now
        }
        return now, now

    # the smoother starts from the last filtered state
    terminal = {
        "mean": filt_mean[n_steps],
        "var": filt_var[n_steps]
    }
    # filtered moments at times 1..N-1 paired with predicted moments at 2..N
    moments = {
        'mean_state_filt': filt_mean[1:n_steps],
        'var_state_filt': filt_var[1:n_steps],
        'mean_state_pred': pred_mean[2:n_steps+1],
        'var_state_pred': pred_var[2:n_steps+1]
    }

    # Note: initial value x0 is assumed to be known, so no need to smooth it
    _, smoothed = jax.lax.scan(smooth_step, terminal, moments, reverse=True)

    # known initial value (zero variance) in front, terminal state at the back
    mean_state_smooth = jnp.concatenate(
        [ode_init[None], smoothed["mean"], terminal["mean"][None]]
    )
    var_state_smooth = jnp.concatenate(
        [jnp.zeros((1, n_block, n_bstate, n_bstate)), smoothed["var"],
         terminal["var"][None]]
    )
    return mean_state_smooth, var_state_smooth