Source code for rodeo.utils

r"""
Util functions for rodeo.

"""
import jax
import jax.numpy as jnp
import jax.scipy as jsp


def add_sqrt(sqrt_A, sqrt_B):
    r"""
    Combine square roots of two matrices into a square root of their sum.

    If :math:`A = S_A S_A^T` and :math:`B = S_B S_B^T`, then stacking
    :math:`S_A^T` on top of :math:`S_B^T` and taking the triangular factor
    :math:`R` of its QR decomposition gives :math:`R^T R = A + B`.

    Args:
        sqrt_A (ndarray(n_dim, n_dim)): The square root of matrix A.
        sqrt_B (ndarray(n_dim, n_dim)): The square root of matrix B.

    Returns:
        (ndarray(n_dim, n_dim)): The square root of the sum of matrices A and B.

    """
    # Rows of the stacked matrix are the columns of the two input factors.
    stacked = jnp.vstack((sqrt_A.T, sqrt_B.T))
    # Only the triangular factor is needed; the orthogonal factor is discarded.
    _, upper = jnp.linalg.qr(stacked)
    return upper.T
def mvncond(mu, Sigma, icond):
    r"""
    Calculates A, b, and V such that :math:`y[!icond] | y[icond] \sim \operatorname{Normal}(A y[icond] + b, V)`.

    Writing :math:`y_1 = y[!icond]` and :math:`y_2 = y[icond]`, the results are
    :math:`A = \Sigma_{12} \Sigma_{22}^{-1}`, :math:`b = \mu_1 - A \mu_2` and
    :math:`V = \Sigma_{11} - A \Sigma_{21}`.

    Args:
        mu (ndarray(2*n_dim)): Mean of y.
        Sigma (ndarray(2*n_dim, 2*n_dim)): Covariance of y.
        icond (ndarray(2*n_dim)): Mask selecting the terms being conditioned on.
            It is negated with ``~`` and used as an index mask, so it is
            expected to be a boolean array.

    Returns:
        (tuple):
        - **A** (ndarray(n_dim, n_dim)): For :math:`y \sim \operatorname{Normal}(\mu, \Sigma)`
          such that :math:`y[!icond] | y[icond] \sim \operatorname{Normal}(A y[icond] + b, V)`
          Calculate A.
        - **b** (ndarray(n_dim)): For :math:`y \sim \operatorname{Normal}(\mu, \Sigma)`
          such that :math:`y[!icond] | y[icond] \sim \operatorname{Normal}(A y[icond] + b, V)`
          Calculate b.
        - **V** (ndarray(n_dim, n_dim)): For :math:`y \sim \operatorname{Normal}(\mu, \Sigma)`
          such that :math:`y[!icond] | y[icond] \sim \operatorname{Normal}(A y[icond] + b, V)`
          Calculate V.

    """
    # Integer positions of the free entries (y1) and the conditioned entries (y2).
    ficond = jnp.nonzero(~icond)[0]
    ticond = jnp.nonzero(icond)[0]

    Sigma11 = Sigma[jnp.ix_(ficond, ficond)]
    Sigma12 = Sigma[jnp.ix_(ficond, ticond)]
    Sigma21 = Sigma[jnp.ix_(ticond, ficond)]
    Sigma22 = Sigma[jnp.ix_(ticond, ticond)]

    # A = Sigma12 * Sigma22^{-1}  <=>  Sigma22^T * A^T = Sigma12^T.
    # Solving this system directly avoids forming the explicit inverse of
    # Sigma22 (i.e. solving against an identity matrix and multiplying).
    A = solve_var(Sigma22.T, Sigma12.T).T
    b = mu[~icond] - jnp.dot(A, mu[icond])  # mu1 - A * mu2
    V = Sigma11 - jnp.dot(A, Sigma21)  # Sigma11 - A * Sigma21
    return A, b, V
def multivariate_normal_logpdf(x, mean, cov):
    r"""
    Evaluate the multivariate normal log-density via an eigendecomposition.

    Directions whose eigenvalue is flagged as zero are left out of both the
    quadratic/log-determinant term and the :math:`\log(2\pi)` normalising term.

    Args:
        x (ndarray(p)): Observations.
        mean (ndarray(p)): Mean of the distribution.
        cov (ndarray(p, p)): Symmetric positive (semi)definite covariance matrix of the distribution.

    Returns:
        (float): The logpdf of the multivariate normal.

    """
    eigvals, eigvecs = jnp.linalg.eigh(cov)
    # Squared coordinates of the centred observation in the eigenbasis.
    centred = x - mean
    proj_sq = jnp.dot(eigvecs.T, centred) ** 2
    # Flag eigenvalues that are not numerically zero.  Because the comparison
    # target is 0, the rtol term contributes nothing and the default atol of
    # isclose decides.
    is_nonzero = ~jnp.isclose(eigvals, 0, rtol=1e-300)
    # Substitute 1 for the flagged-zero eigenvalues so the division and log
    # below cannot produce nan there; those entries are masked out afterwards.
    safe_vals = jnp.where(is_nonzero, eigvals, 1.)
    per_direction = proj_sq / safe_vals + jnp.log(safe_vals)
    masked_total = jnp.sum(jnp.where(is_nonzero, per_direction, 0.))
    n_nonzero = jnp.sum(is_nonzero)
    return -.5 * masked_total - n_nonzero * .5 * jnp.log(2 * jnp.pi)
def first_order_pad(ode_fun, n_vars, n_deriv):
    r"""
    Build the W matrix and an initial-state helper for a first order ODE.

    Args:
        ode_fun (Callable): ODE function, called as ``ode_fun(x0, t, **params)``
            with ``x0`` given as a column of shape ``(n_vars, 1)``.
        n_vars (int): Number of variables.
        n_deriv (int): Number of upper derivatives to use. The helper pads with
            ``n_deriv - 2`` zero columns, so this is expected to be at least 2.

    Returns:
        (tuple):
        - **W** (ndarray(n_var, 1, n_deriv)): W matrix defining the left hand side of the ODE;
          all zeros except for ones at index 1 of the last axis.
        - **ode_init** (Callable): A function that helps setting up the initial state space mean.

    """
    # Select the first derivative (index 1 on the last axis) for every variable.
    W = jnp.zeros((n_vars, 1, n_deriv)).at[:, :, 1].set(1.0)

    def ode_init(x0, t, **params):
        # Columns: the value, the derivative given by the ODE, then zero padding.
        value_col = x0[:, None]
        deriv_col = ode_fun(value_col, t, **params)
        zero_pad = jnp.zeros((n_vars, n_deriv - 2))
        return jnp.hstack([value_col, deriv_col, zero_pad])

    return W, ode_init
def solve_var(V, B):
    r"""
    Solve the linear system :math:`V X = B` for :math:`X`, i.e. compute
    :math:`X = V^{-1}B`, where :math:`V` is a variance matrix.

    Args:
        V (ndarray(n_dim1, n_dim1)): Variance matrix :math:`V`.
        B (ndarray(n_dim1, n_dim2)): Matrix :math:`B`.

    Returns:
        (ndarray(n_dim1, n_dim2)): Matrix :math:`X = V^{-1}B`.

    """
    # A general linear solve is used; it does not exploit the symmetry of V.
    X = jnp.linalg.solve(V, B)
    return X