r"""
This module implements the Fenrir algorithm of Tronarp et al. (2022) for computing an approximation to the likelihood :math:`p(Y_{0:M} \mid Z_{1:N})`.
The forward pass model is
.. math::
X_0 = v
X_n = c_n + Q_n X_{n-1} + R_n^{1/2} \epsilon_n
Z_n = W_n X_n - f(X_n, t_n) + V_n^{1/2} \eta_n.
We assume that :math:`c_n = 0, Q_n = Q, R_n = R`, and :math:`W_n = W` for all :math:`n`. Using the Kalman filtering recursions, the above model can be simulated via the reverse pass model
.. math::
X_N \sim \operatorname{Normal}(b_N, C_N)
X_n = A_n X_{n+1} + b_n + C_n^{1/2} \epsilon_n.
Fenrir combines the observations
.. math::
Y_m = D_m X_m + \Omega^{1/2}_m \eta_m,
with the reverse pass model to condition on the data. Here :math:`\epsilon_n` and :math:`\eta_m` denote standard normal random vectors.
"""
import jax
import jax.numpy as jnp
from rodeo.kalmantv import standard
from rodeo.kalmantv import square_root
from rodeo.solve import _solve_filter
from rodeo.utils import multivariate_normal_logpdf
# --- helper functions --------------------------------------------------------
def _forecast_update(mean_state_pred, var_state_pred,
                     x_meas, mean_meas,
                     wgt_meas, var_meas,
                     kalman_funs):
    r"""
    Score a measurement under its one-step predictive distribution, then condition the state on it.

    Shapes below are per block: :func:`_backward` vmaps this function over
    the block dimension.

    Args:
        mean_state_pred (ndarray(n_bstate)): Predicted state mean given observations from times [0...n-1]; :math:`\mu_{n|n-1}`.
        var_state_pred (ndarray(n_bstate, n_bstate)): Predicted state covariance given observations from times [0...n-1]; :math:`\Sigma_{n|n-1}`.
        x_meas (ndarray(n_bmeas)): Observed measurement vector :math:`y_n`.
        mean_meas (ndarray(n_bmeas)): Offset of the measurement model.
        wgt_meas (ndarray(n_bmeas, n_bstate)): Weight matrix of the measurement model.
        var_meas (ndarray(n_bmeas, n_bmeas)): Noise variance matrix of the measurement model.
        kalman_funs (object): Provides the Kalman ``forecast`` and ``update`` functions.

    Returns:
        (tuple):
        - **logdens** (float): Log density of ``x_meas`` under the forecast (predictive) distribution of the measurement.
        - **mean_state_filt** (ndarray(n_bstate)): State mean given observations from times [0...n]; :math:`\mu_{n|n}`.
        - **var_state_filt** (ndarray(n_bstate, n_bstate)): State covariance given observations from times [0...n]; :math:`\Sigma_{n|n}`.
    """
    # Both the forecast and the update need the predicted state together
    # with the full measurement model.
    pred_and_meas = dict(
        mean_state_pred=mean_state_pred,
        var_state_pred=var_state_pred,
        mean_meas=mean_meas,
        wgt_meas=wgt_meas,
        var_meas=var_meas,
    )
    # Mean/covariance used to score the observed measurement.
    meas_mean, meas_var = kalman_funs.forecast(**pred_and_meas)
    logdens = multivariate_normal_logpdf(x_meas, mean=meas_mean, cov=meas_var)
    # Condition the predicted state on the observed measurement.
    mean_state_filt, var_state_filt = kalman_funs.update(
        x_meas=x_meas, **pred_and_meas
    )
    return logdens, mean_state_filt, var_state_filt
# --- loglikelihood -----------------------------------------------------------
def _backward(mean_state_filt, var_state_filt,
              mean_state_pred, var_state_pred,
              prior_weight, prior_var,
              t_min, t_max, n_steps,
              obs_data, obs_times,
              obs_weight, obs_var,
              kalman_funs):
    r"""
    Compute the backward Markov chain parameters and filter the observations backward in time.

    Starting from the forward-filter estimate at :math:`t_N`, this walks the
    solver grid from :math:`n = N` down to :math:`n = 0`. At each step it
    computes the reverse-pass parameters :math:`(A_n, b_n, C_n)` from the
    forward filter. It then predicts the backward state at :math:`t_n`. If an
    observation falls on :math:`t_n`, it conditions on that observation and
    accumulates its log predictive density.

    Each observation time is mapped to the nearest point of the grid
    ``linspace(t_min, t_max, n_steps + 1)``. Times at or beyond
    :math:`t_N` are assimilated at :math:`t_N`. ``obs_times`` is expected to
    be increasing, with no two observations on the same grid point. If two
    observations do share a grid point, that observation and every earlier
    one are never assimilated.

    Args:
        mean_state_filt (ndarray(n_steps+1, n_block, n_bstate)): Forward-filter mean at time n given observations from times [0...n]; :math:`\mu_{n|n}`.
        var_state_filt (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Forward-filter covariance at time n given observations from times [0...n]; :math:`\Sigma_{n|n}`.
        mean_state_pred (ndarray(n_steps+1, n_block, n_bstate)): Forward-predicted mean at time n given observations from times [0...n-1]; :math:`\mu_{n|n-1}`.
        var_state_pred (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Forward-predicted covariance at time n given observations from times [0...n-1]; :math:`\Sigma_{n|n-1}`.
        prior_weight (ndarray(n_block, n_bstate, n_bstate)): Weight matrix defining the solution prior; :math:`Q`.
        prior_var (ndarray(n_block, n_bstate, n_bstate)): Variance matrix defining the solution prior; :math:`R`.
        t_min (float): First time point of the solver grid; :math:`t_0`.
        t_max (float): Last time point of the solver grid; :math:`t_N`.
        n_steps (int): Number of time steps :math:`N`; the grid has :math:`N+1` points spaced :math:`(t_N - t_0)/N` apart.
        obs_data (ndarray(n_obs, n_block, n_bobs)): Observed data; :math:`y_{0:M}`.
        obs_times (ndarray(n_obs)): Observation times (increasing).
        obs_weight (ndarray(n_obs, n_block, n_bobs, n_bstate)): Weight matrix in the observation model; :math:`D_{0:M}`.
        obs_var (ndarray(n_obs, n_block, n_bobs, n_bobs)): Variance matrix in the observation model; :math:`\Omega_{0:M}`.
        kalman_funs (object): Provides the Kalman functions ``smooth_cond``, ``predict``, ``forecast`` and ``update``.

    Returns:
        (tuple):
        - **logdens** (float): The log density :math:`\log p(y_{0:M} \mid Z_{1:N})`.
        - **state_par** (dict): Backward-filter quantities consumed by :func:`_smooth_mv`.
          ``"state_pred"`` and ``"state_filt"`` are ``(mean, var)`` pairs with leading dimension ``n_steps+1``.
          Index ``n`` refers to time :math:`t_n`, and ``"state_pred"`` at index ``N`` is the forward-filter estimate at :math:`t_N`.
          ``"wgt_state"`` and ``"var_state"`` stack :math:`A_n` and :math:`C_n` for :math:`n = 0, \ldots, N-1`.
    """
    # get dimensions
    n_obs, n_block, n_bobs, _ = obs_weight.shape
    # Map every observation to the NEAREST solver grid point. A plain
    # searchsorted(sim_times, obs_times) returns the first grid point >= the
    # observation time. When linspace rounds a grid value one ulp below an
    # on-grid observation time, that observation would be pushed to the next
    # grid point. Shifting the query by half a step makes the lookup
    # insensitive to such round-off.
    sim_times = jnp.linspace(t_min, t_max, n_steps + 1)
    dt = (t_max - t_min) / n_steps
    obs_ind = jnp.searchsorted(sim_times, obs_times - 0.5 * dt)
    # offset of obs is assumed to be 0
    obs_mean = jnp.zeros((n_block, n_bobs))

    def forecast_update(mean_state_pred, var_state_pred,
                        x_meas, mean_meas, wgt_meas, var_meas):
        # _forecast_update with kalman_funs bound, so it can be vmapped over blocks
        return _forecast_update(mean_state_pred, var_state_pred,
                                x_meas, mean_meas,
                                wgt_meas, var_meas,
                                kalman_funs)

    def _assimilate(mean_state, var_state, i):
        # Condition every block on observation i. Returns the updated state,
        # the summed log density over blocks, and the index of the next
        # (earlier) observation.
        logp, mean_state_next, var_state_next = jax.vmap(forecast_update)(
            mean_state_pred=mean_state,
            var_state_pred=var_state,
            x_meas=obs_data[i],
            mean_meas=obs_mean,
            wgt_meas=obs_weight[i],
            var_meas=obs_var[i]
        )
        return mean_state_next, var_state_next, jnp.sum(logp), i - 1

    def scan_fun(carry, forward_states):
        # backward-filter estimate at t_{n+1}
        bmean_state_filt, bvar_state_filt = carry["state_filt"]
        # forward-filter estimates: filtered at t_n, predicted at t_{n+1}
        mean_state_filt, var_state_filt = forward_states["state_filt"]
        mean_state_pred, var_state_pred = forward_states["state_pred"]
        logdens = carry["logdens"]
        i = carry["i"]  # index of the latest observation not yet assimilated
        t = forward_states["t"]  # grid index n
        # reverse-pass (Markov) parameters A_n, b_n, C_n
        wgt_state_back, mean_state_back, var_state_back = jax.vmap(kalman_funs.smooth_cond)(
            mean_state_filt=mean_state_filt,
            var_state_filt=var_state_filt,
            mean_state_pred=mean_state_pred,
            var_state_pred=var_state_pred,
            wgt_state=prior_weight,
            var_state=prior_var
        )
        # backward predict: t_{n+1} -> t_n
        bmean_state_pred, bvar_state_pred = jax.vmap(kalman_funs.predict)(
            mean_state_past=bmean_state_filt,
            var_state_past=bvar_state_filt,
            mean_state=mean_state_back,
            wgt_state=wgt_state_back,
            var_state=var_state_back
        )

        def _obs():
            return _assimilate(bmean_state_pred, bvar_state_pred, i)

        def _no_obs():
            return bmean_state_pred, bvar_state_pred, 0.0, i

        # i < 0 means every observation has already been assimilated; test it
        # explicitly rather than relying on obs_ind[-1] wrapping around.
        has_obs = jnp.logical_and(i >= 0, obs_ind[i] == t)
        bmean_state_filt, bvar_state_filt, logp, i = jax.lax.cond(
            has_obs, _obs, _no_obs)
        logdens += logp
        # output
        carry = {
            "state_filt": (bmean_state_filt, bvar_state_filt),
            "logdens": logdens,
            "i": i
        }
        stack = {
            "state_pred": (bmean_state_pred, bvar_state_pred),
            "state_filt": (bmean_state_filt, bvar_state_filt),
            "wgt_state": wgt_state_back,
            "var_state": var_state_back
        }
        return carry, stack

    # terminal point: the backward filter starts from the forward filter at t_N
    mean_state_term = mean_state_filt[n_steps]
    var_state_term = var_state_filt[n_steps]
    logdens = 0.0
    i = n_obs - 1

    def _term_obs():
        return _assimilate(mean_state_term, var_state_term, i)

    def _term_no_obs():
        # no need to update
        return mean_state_term, var_state_term, 0.0, i

    # observations at or beyond t_N are assimilated here
    bmean_state_filt, bvar_state_filt, logp, i = jax.lax.cond(
        obs_ind[i] >= n_steps, _term_obs, _term_no_obs)
    logdens += logp
    # scan from n = N-1 down to n = 0
    scan_init = {
        "state_filt": (bmean_state_filt, bvar_state_filt),
        "logdens": logdens,
        "i": i
    }
    forward_states_init = {
        "state_pred": (mean_state_pred[1:n_steps+1], var_state_pred[1:n_steps+1]),
        "state_filt": (mean_state_filt[:n_steps], var_state_filt[:n_steps]),
        "t": jnp.arange(n_steps)
    }
    scan_out, scan_out2 = jax.lax.scan(
        scan_fun, scan_init, forward_states_init, reverse=True)
    # append terminal values (time t_N) to the back
    mean_scan_pred, var_scan_pred = scan_out2["state_pred"]
    mean_scan_filt, var_scan_filt = scan_out2["state_filt"]
    mean_state_pred = jnp.concatenate(
        [mean_scan_pred, mean_state_term[None]]
    )
    var_state_pred = jnp.concatenate(
        [var_scan_pred, var_state_term[None]]
    )
    mean_state_filt = jnp.concatenate(
        [mean_scan_filt, bmean_state_filt[None]]
    )
    var_state_filt = jnp.concatenate(
        [var_scan_filt, bvar_state_filt[None]]
    )
    # repack
    state_par = {
        "state_pred": (mean_state_pred, var_state_pred),
        "state_filt": (mean_state_filt, var_state_filt),
        "wgt_state": scan_out2["wgt_state"],
        "var_state": scan_out2["var_state"]
    }
    return scan_out["logdens"], state_par
def fenrir(key, ode_fun, ode_weight, ode_init,
           t_min, t_max, n_steps,
           interrogate,
           prior_pars,
           obs_data, obs_times, obs_weight, obs_var,
           kalman_type="standard", **params):
    r"""
    Fenrir algorithm to compute the approximate loglikelihood of :math:`p(Y_{0:M} \mid Z_{1:N})`.

    Runs the probabilistic ODE solver's forward filter. It then runs a
    backward pass that combines the reverse-time model with the
    observations.

    Args:
        key (PRNGKey): PRNG key.
        ode_fun (Callable): Higher order ODE function :math:`W X_t = F(X_t, t)` taking arguments :math:`X` and :math:`t`.
        ode_weight (ndarray(n_block, n_bmeas, n_bstate)): Weight matrix defining the measure prior; :math:`W`.
        ode_init (ndarray(n_block, n_bstate)): Initial value of the state variable :math:`X_t` at time :math:`t = a`.
        t_min (float): First time point of the time interval to be evaluated; :math:`a`.
        t_max (float): Last time point of the time interval to be evaluated; :math:`b`.
        n_steps (int): Number of discretization points (:math:`N`) of the time interval that is evaluated, such that discretization timestep is :math:`dt = (b-a)/N`.
        interrogate (Callable): Function defining the interrogation method.
        prior_pars (tuple): A tuple containing the weight matrix and the variance matrix defining the solution prior; :math:`Q, R`.
        obs_data (ndarray(n_obs, n_blocks, n_bobs)): Observed data; :math:`y_{0:M}`.
        obs_times (ndarray(n_obs)): Observation times. They are expected to be increasing and to coincide with points of the solver time grid.
        obs_weight (ndarray(n_obs, n_blocks, n_bobs, n_bstate)): Weight matrix in the observation model; :math:`D_{0:M}`.
        obs_var (ndarray(n_obs, n_blocks, n_bobs, n_bobs)): Variance matrix in the observation model; :math:`\Omega_{0:M}`.
        kalman_type (str): Which Kalman filter implementation to use: ``"standard"`` or ``"square-root"``.
        params (kwargs): Optional model parameters, forwarded to the forward solver.

    Returns:
        (float) : The loglikelihood of :math:`p(y_{0:M} \mid Z_{1:N})`.

    Raises:
        NotImplementedError: If ``kalman_type`` is not one of the supported values.
    """
    # standard or square-root filter
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError(
            f"Unsupported kalman_type {kalman_type!r}; "
            "expected 'standard' or 'square-root'."
        )
    # prior variables
    prior_weight, prior_var = prior_pars
    # forward pass
    filt_out = _solve_filter(
        key=key,
        ode_fun=ode_fun, ode_weight=ode_weight, ode_init=ode_init,
        t_min=t_min, t_max=t_max, n_steps=n_steps,
        interrogate=interrogate,
        prior_weight=prior_weight, prior_var=prior_var,
        kalman_funs=kalman_funs, **params
    )
    mean_state_pred, var_state_pred = filt_out["state_pred"]
    mean_state_filt, var_state_filt = filt_out["state_filt"]
    # backward pass; only the log density is needed here
    logdens, _ = _backward(
        mean_state_filt=mean_state_filt,
        var_state_filt=var_state_filt,
        mean_state_pred=mean_state_pred,
        var_state_pred=var_state_pred,
        prior_weight=prior_weight,
        prior_var=prior_var,
        t_min=t_min, t_max=t_max, n_steps=n_steps,
        obs_data=obs_data, obs_times=obs_times,
        obs_weight=obs_weight, obs_var=obs_var,
        kalman_funs=kalman_funs
    )
    return logdens
# --- ODE solver --------------------------------------------------------------
def _smooth_mv(state_par, kalman_funs):
r"""
Smoothing pass of the Fenrir algorithm used to compute solution posterior.
Args:
state_par (dict): Dictionary containing the weight, mean and variance matrices of the predicted and updated steps of the Kalman filter.
kalman_funs (object): An object that contains the Kalman filtering functions: predict, update and smooth.
Returns:
(tuple):
- **mean_state_smooth** (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process :math:`X_t` at times :math:`t \in [a, b]`.
- **var_state_smooth** (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times :math:`t \in [a, b]`.
"""
mean_state_pred, var_state_pred = state_par["state_pred"]
mean_state_filt, var_state_filt = state_par["state_filt"]
wgt_state = state_par["wgt_state"]
var_state = state_par["var_state"]
n_tot = mean_state_pred.shape[0]
# smooth pass
# lax.scan setup
def scan_fun(state_next, smooth_kwargs):
mean_state_filt = smooth_kwargs['mean_state_filt']
var_state_filt = smooth_kwargs['var_state_filt']
mean_state_pred = smooth_kwargs['mean_state_pred']
var_state_pred = smooth_kwargs['var_state_pred']
wgt_state = smooth_kwargs['wgt_state']
var_state = smooth_kwargs['var_state']
mean_state_curr, var_state_curr = jax.vmap(kalman_funs.smooth_mv)(
mean_state_next=state_next["mean"],
var_state_next=state_next["var"],
wgt_state=wgt_state,
mean_state_filt=mean_state_filt,
var_state_filt=var_state_filt,
mean_state_pred=mean_state_pred,
var_state_pred=var_state_pred,
var_state=var_state
)
state_curr = {
"mean": mean_state_curr,
"var": var_state_curr
}
return state_curr, state_curr
# initialize
scan_init = {
"mean": mean_state_filt[1],
"var": var_state_filt[1]
}
# scan arguments
scan_kwargs = {
'mean_state_filt': mean_state_filt[2:],
'var_state_filt': var_state_filt[2:],
'mean_state_pred': mean_state_pred[1:n_tot-1],
'var_state_pred': var_state_pred[1:n_tot-1],
'wgt_state': wgt_state[1:n_tot],
'var_state': var_state[1:n_tot]
}
# Note: initial value x0 is assumed to be known, so no need to smooth it
_, scan_out = jax.lax.scan(scan_fun, scan_init, scan_kwargs)
# append initial values to front and back
mean_state_smooth = jnp.concatenate(
[mean_state_filt[0:2], scan_out["mean"]]
)
var_state_smooth = jnp.concatenate(
[var_state_filt[0:2], scan_out["var"]]
)
return mean_state_smooth, var_state_smooth
def solve_mv(key, ode_fun, ode_weight, ode_init,
             t_min, t_max, n_steps,
             interrogate,
             prior_pars,
             obs_data, obs_times, obs_weight, obs_var,
             kalman_type="standard", **params):
    r"""
    Fenrir algorithm to compute the mean and variance of :math:`p(X_{0:N} \mid Z_{1:N}, Y_{0:M})`. Same arguments as :func:`fenrir`.

    Runs the forward filter, then the backward pass over the observations,
    then a forward-in-time smoothing pass over the backward filter's output.

    Returns:
        (tuple):
        - **mean_state_smooth** (ndarray(n_steps+1, n_block, n_bstate)): Posterior mean of the solution process :math:`X_t` at times :math:`t \in [a, b]`.
        - **var_state_smooth** (ndarray(n_steps+1, n_block, n_bstate, n_bstate)): Posterior variance of the solution process at times :math:`t \in [a, b]`.

    Raises:
        NotImplementedError: If ``kalman_type`` is not ``"standard"`` or ``"square-root"``.
    """
    # standard or square-root filter
    if kalman_type == "standard":
        kalman_funs = standard
    elif kalman_type == "square-root":
        kalman_funs = square_root
    else:
        raise NotImplementedError(
            f"Unsupported kalman_type {kalman_type!r}; "
            "expected 'standard' or 'square-root'."
        )
    prior_weight, prior_var = prior_pars
    # forward pass
    filt_out = _solve_filter(
        key=key,
        ode_fun=ode_fun, ode_weight=ode_weight, ode_init=ode_init,
        t_min=t_min, t_max=t_max, n_steps=n_steps,
        interrogate=interrogate,
        prior_weight=prior_weight, prior_var=prior_var,
        kalman_funs=kalman_funs, **params
    )
    mean_state_pred, var_state_pred = filt_out["state_pred"]
    mean_state_filt, var_state_filt = filt_out["state_filt"]
    # backward pass; only the backward-filter quantities are needed here
    _, state_par = _backward(
        mean_state_filt=mean_state_filt,
        var_state_filt=var_state_filt,
        mean_state_pred=mean_state_pred,
        var_state_pred=var_state_pred,
        prior_weight=prior_weight,
        prior_var=prior_var,
        t_min=t_min, t_max=t_max, n_steps=n_steps,
        obs_data=obs_data, obs_times=obs_times,
        obs_weight=obs_weight, obs_var=obs_var,
        kalman_funs=kalman_funs
    )
    # smoothing pass forward in time
    mean_state_smooth, var_state_smooth = _smooth_mv(state_par, kalman_funs)
    return mean_state_smooth, var_state_smooth