r"""
We implement some ready-to-use interrogation methods from our paper: the interrogation methods of Chkrebtii et al (2016), Schober et al (2019) and Kramer et al (2021).
We also implement a variant of Schober et al (2019) that instead uses the variance of Chkrebtii et al (2016) (``interrogate_rodeo``); the analogous variant of Kramer et al (2021) (``interrogate_rodeo2``) is currently commented out.
"""
import jax
import jax.numpy as jnp
def interrogate_chkrebtii(key, ode_fun, ode_weight, t,
                          mean_state_pred, var_state_pred,
                          kalman_type, **params):
    r"""
    Interrogate method of Chkrebtii et al (2016); DOI: 10.1214/16-BA1017.

    The ODE function is evaluated at a random draw from the predicted state
    distribution, :math:`X_t \sim N(\mu_{t|t-1}, \Sigma_{t|t-1})`, with one
    independent draw per block. The returned interrogation weight is zero.

    Same arguments and returns as :func:`~ode.interrogate_rodeo`, with one
    extra argument:

    Args:
        kalman_type (str): ``"standard"`` if ``var_state_pred`` holds the
            covariance :math:`\Sigma_{t|t-1}`; the returned ``var_meas`` is
            then :math:`W \Sigma_{t|t-1} W'`. ``"square-root"`` if
            ``var_state_pred`` holds a factor :math:`L` with
            :math:`\Sigma_{t|t-1} = L L'`; the returned ``var_meas`` is then
            the factor :math:`W L`, of shape ``(n_block, n_bmeas, n_bstate)``.

    Raises:
        NotImplementedError: If ``kalman_type`` is neither ``"standard"`` nor
            ``"square-root"``.
    """
    n_block, n_bstate = mean_state_pred.shape
    # Split into n_block + 1 keys and discard the first: one subkey per block.
    subkeys = jax.random.split(key, num=n_block + 1)[1:]
    if kalman_type == "standard":
        var_meas = jax.vmap(lambda wm, vsp:
                            jnp.atleast_2d(jnp.linalg.multi_dot([wm, vsp, wm.T])))(
            ode_weight, var_state_pred
        )
        x_state = jax.vmap(jax.random.multivariate_normal)(
            subkeys, mean_state_pred, var_state_pred
        )
    elif kalman_type == "square-root":
        # (W L)(W L)' = W Sigma W', so W L is a square-root factor of the
        # measurement variance.
        var_meas = jax.vmap(lambda wm, vsp: jnp.atleast_2d(wm @ vsp))(
            ode_weight, var_state_pred
        )
        random_norm = jax.vmap(
            lambda k: jax.random.normal(k, (n_bstate,))
        )(subkeys)
        # X = mu + L z with z ~ N(0, I) has covariance L L' = Sigma. This must
        # use the *state* factor L (n_bstate x n_bstate); the measurement
        # factor W L would give a length-n_bmeas vector, which does not match mu.
        x_state = jax.vmap(lambda mu, chol, z: mu + chol @ z)(
            mean_state_pred, var_state_pred, random_norm
        )
    else:
        raise NotImplementedError(
            f"Unknown kalman_type {kalman_type!r}; "
            "expected 'standard' or 'square-root'."
        )
    mean_meas = -ode_fun(x_state, t, **params)
    return jnp.zeros(ode_weight.shape), mean_meas, var_meas
def interrogate_schober(key, ode_fun, ode_weight, t,
                        mean_state_pred, var_state_pred,
                        **params):
    r"""
    Interrogate method of Schober et al (2019); DOI: https://doi.org/10.1007/s11222-017-9798-7.

    The ODE function is evaluated at the predicted mean :math:`\mu_{t|t-1}`.
    Both the interrogation weight and the interrogation variance are zero.
    ``key`` and ``var_state_pred`` are accepted for interface compatibility
    but are not used.

    Same arguments and returns as :func:`~ode.interrogate_rodeo`.
    """
    n_block, n_bmeas, _ = ode_weight.shape
    zero_var = jnp.zeros((n_block, n_bmeas, n_bmeas))
    offset = -ode_fun(mean_state_pred, t, **params)
    zero_weight = jnp.zeros(ode_weight.shape)
    return zero_weight, offset, zero_var
def interrogate_kramer(key, ode_fun, ode_weight, t,
                       mean_state_pred, var_state_pred,
                       **params):
    r"""
    First order interrogate method of Kramer et al (2021); DOI: https://doi.org/10.48550/arXiv.2110.11812.

    The ODE function is linearized about the predicted mean,
    :math:`f(X_t) \approx f(\mu_{t|t-1}) + J (X_t - \mu_{t|t-1})`, keeping only
    the diagonal Jacobian blocks :math:`J_b` (off block diagonals are assumed
    to be zero). For each block the interrogation weight is :math:`-J_b`, the
    offset is :math:`-f(\mu_{t|t-1})_b + J_b \mu_{b,t|t-1}`, and the variance
    is zero. ``key`` and ``var_state_pred`` are not used.

    Same arguments and returns as :func:`~ode.interrogate_rodeo`.
    """
    n_block, n_bmeas, _ = ode_weight.shape
    neg_fun = -ode_fun(mean_state_pred, t, **params)
    # Jacobian w.r.t. the whole state; assuming ode_fun returns shape
    # (n_block, n_bmeas), this has shape (n_block, n_bmeas, n_block, n_bstate).
    full_jac = jax.jacfwd(ode_fun)(mean_state_pred, t, **params)
    blocks = jnp.arange(n_block)
    # Pair output block b with input block b. Because the two index arrays
    # are separated by a slice, the block axis comes first in the result:
    # shape (n_block, n_bmeas, n_bstate).
    jac_diag = full_jac[blocks, :, blocks]
    offset = neg_fun + jax.vmap(jnp.dot)(jac_diag, mean_state_pred)
    zero_var = jnp.zeros((n_block, n_bmeas, n_bmeas))
    return -jac_diag, offset, zero_var
def interrogate_rodeo(key, ode_fun, ode_weight, t,
                      mean_state_pred, var_state_pred,
                      **params):
    r"""
    Rodeo interrogation method.

    The ODE function is evaluated at the predicted mean (as in Schober et al
    (2019)), and the interrogation variance is :math:`W \Sigma_{t|t-1} W'` for
    each block (as in Chkrebtii et al (2016)). The interrogation weight is zero.

    Args:
        key (PRNGKey): Jax PRNG key (not used by this method).
        ode_fun (Callable): Higher order ODE function in :math:`W X_t = f(X_t, t, \theta)`, called as ``ode_fun(X, t, **params)``.
        ode_weight (ndarray(n_block, n_bmeas, n_bstate)): Weight matrix :math:`W`.
        t (float): Time point.
        mean_state_pred (ndarray(n_block, n_bstate)): Mean estimate for state at time t given observations from times [a...t-1]; denoted by :math:`\mu_{t|t-1}`.
        var_state_pred (ndarray(n_block, n_bstate, n_bstate)): Covariance of estimate for state at time t given observations from times [a...t-1]; denoted by :math:`\Sigma_{t|t-1}`.
        params : Optional model parameters passed through to ``ode_fun``.

    Returns:
        (tuple):
        - **wgt_meas** (ndarray(n_block, n_bmeas, n_bstate)): Interrogation weight matrix.
        - **mean_meas** (ndarray(n_block, n_bmeas)): Interrogation offset.
        - **var_meas** (ndarray(n_block, n_bmeas, n_bmeas)): Interrogation variance.
    """
    def _block_var(weight, var):
        # W_b Sigma_b W_b' for a single block.
        return jnp.atleast_2d(jnp.linalg.multi_dot([weight, var, weight.T]))

    var_meas = jax.vmap(_block_var)(ode_weight, var_state_pred)
    mean_meas = -ode_fun(mean_state_pred, t, **params)
    wgt_meas = jnp.zeros(ode_weight.shape)
    return wgt_meas, mean_meas, var_meas
# def interrogate_rodeo2(key, ode_fun, ode_weight, t,
# mean_state_pred, var_state_pred,
# **params):
# r"""
# First order interrogate method of Kramer et al (2021); DOI: https://doi.org/10.48550/arXiv.2110.11812.
# Assumes off block diagonals are zero.
# Same arguments and returns as :func:`~ode.interrogate_rodeo`.
# """
# n_block, n_bmeas, n_bstate = ode_weight.shape
# fun_meas = -ode_fun(mean_state_pred, t, **params)
# jac = jax.jacfwd(ode_fun)(mean_state_pred, t, **params)
# # need to get the diagonal of jac
# jac = jax.vmap(lambda b:
# jac[b, :, b])(jnp.arange(n_block))
# wgt_meas = -jac
# mean_meas = jax.vmap(lambda b:
# fun_meas[b] + jac[b].dot(mean_state_pred[b]))(jnp.arange(n_block))
# var_meas = jax.vmap(lambda wm, vsp:
# jnp.atleast_2d(jnp.linalg.multi_dot([wm, vsp, wm.T])))(
# wgt_meas, var_state_pred
# )
# return wgt_meas, mean_meas, var_meas