Source code for rodeo.prior.ibm

r"""
Computes the initial parameters for the process prior using the q times integrated Brownian motion (IBM)

.. math::

    x^q(t) = \sigma B(t).

It has the analytical formulas for the parameters

.. math::

    Q_{ij} = 𝟙_{i\leq j}\frac{(\Delta t)^{j-i}}{(j-i)!}, \qquad R_{ij} = \sigma^2 \frac{(\Delta t)^{2q-1-i-j}}{(2q-1-i-j)(q-i)!(q-j)!}.

"""


import jax
import jax.numpy as jnp
import jax.scipy as jsp

def _factorial(x):
    """
    JAX factorial function.

    It's actually the gamma function shifted such that `_factorial(x) = x!` for integer-valued `x`.

    Args:
        x (int): Integer.
    
    Returns:
        (int): Factorial of x.

    """
    return jnp.exp(jsp.special.gammaln(x+1.0))


[docs] def ibm_state(dt, q, sigma): """ Calculate the state weight matrix and variance matrix of q-times integrated Brownian motion. Args: dt (float): The step size between simulation points. q (int): The number of times to integrate the underlying Brownian motion. sigma (float): Parameter in the q-times integrated Brownian Motion. Returns: (tuple): - **Q** (ndarray(q+1, q+1)): The state weight matrix defined in Kalman solver. - **R** (ndarray(q+1, q+1)): The state variance matrix defined in Kalman solver. """ I, J = jnp.meshgrid(jnp.arange(q+1), jnp.arange(q+1), indexing="ij", sparse=True) mesh = J-I Q = jnp.nan_to_num(dt**mesh/_factorial(mesh), 0) mesh = (2.0*q+1.0) - I - J num = dt**mesh den = mesh * _factorial(q - I) * _factorial(q-J) R = sigma**2 * num/den return Q, R
[docs] def ibm_init(dt, n_deriv, sigma): """ Calculates the initial parameters necessary for the Kalman solver with the p-1 times integrated Brownian Motion, where the ODE is up to the p-1th derivative. Args: dt (float): The step size between simulation points. n_deriv (int): Dimension of the prior. sigma (ndarray(n_block)): Parameter in variance matrix. Returns: (tuple): - **wgt_state** (ndarray(n_block, p, p)): Weight matrix defining the solution prior; :math:`Q`. - **var_state** (ndarray(n_block, p, p)): Variance matrix defining the solution prior; :math:`R`. """ n_block = len(sigma) wgt_state_i, var_state_i = ibm_state(dt, n_deriv-1, 1) wgt_state = jnp.repeat(wgt_state_i[None,:,:], repeats=n_block, axis=0) var_state = jax.vmap(lambda b: sigma[b]**2 * var_state_i)(jnp.arange(n_block)) return wgt_state, var_state