from jax import numpy as np
from jax import jvp, grad
from matplotlib import pyplot as plt
import numpy as onp
import scipy.integrate, scipy.optimize


###### Some useful functions

def gaussian_pdf(theta, x):
    """Log-density of a univariate Gaussian evaluated at x.

    Despite the function name, the value returned is the logarithm of
    the density, not the density itself.

    theta=(mu, sigma), where sigma is the standard deviation (it is
    squared below to obtain the variance)."""
    mu, sigma = theta
    variance = sigma**2
    # Log of the normalising constant 1 / sqrt(2 * pi * variance).
    log_normaliser = -0.5 * np.log(2*np.pi * variance)
    # Quadratic term of the exponent.
    exponent = -0.5 * (x - mu)**2 / variance
    return log_normaliser + exponent

def gaussian_kl(theta1, theta2):
    """KL divergence KL(p1 || p2) between two univariate Gaussians.

    theta1=(mu1, sigma1) parameterises p1 and theta2=(mu2, sigma2)
    parameterises p2; both sigmas are standard deviations."""
    mu1, sigma1 = theta1
    mu2, sigma2 = theta2
    var1, var2 = sigma1**2, sigma2**2
    log_ratio_term = 0.5 * np.log(var2 / var1)
    quadratic_term = 0.5 * (var1 + (mu1 - mu2)**2) / var2
    return log_ratio_term + quadratic_term + -0.5

def gaussian_sample(theta):
    """Draw one sample from a univariate Gaussian.

    theta=(mu, sigma), where sigma is the standard deviation: it is
    passed as the scale of onp.random.normal. The draw uses NumPy's
    global random state."""
    mean, std = theta
    draw = onp.random.normal(loc=mean, scale=std)
    return draw

def std2moments(theta):
    """Map standard parameters (mu, sigma) to moments form (mu, s).

    The second coordinate is s = -0.5 * (mu**2 + sigma**2)."""
    mu, sigma = theta
    second_moment = mu**2 + sigma**2
    return np.array([mu, -0.5 * second_moment])

def moments2std(xi):
    """Map moments form (mu, s) back to standard parameters (mu, sigma).

    Inverts s = -0.5 * (mu**2 + sigma**2), so sigma is the square root
    of -2 * s - mu**2."""
    mu, s = xi
    variance = -2 * s - mu**2
    return np.array([mu, np.sqrt(variance)])

def std2info(theta):
    """Map standard parameters (mu, sigma) to information form (h, lmbda).

    lmbda = 1 / sigma**2 is the precision and h = lmbda * mu."""
    mu, sigma = theta
    precision = 1 / sigma**2
    return np.array([precision * mu, precision])
    
def info2std(eta):
    """Map information form (h, lmbda) back to standard parameters (mu, sigma).

    sigma = sqrt(1 / lmbda) and mu = h / lmbda."""
    h, lmbda = eta
    mean = h / lmbda
    std = np.sqrt(1 / lmbda)
    return np.array([mean, std])

def hvp(J, w, v):
    """Hessian-vector product of the scalar cost J.

    J is the cost function, w is the point at which the Hessian is
    evaluated, and v is the direction it is multiplied with. The
    product is obtained by pushing v through the gradient of J with a
    forward-mode JVP, so the Hessian is never formed explicitly."""
    gradient_fn = grad(J)
    _, hessian_times_v = jvp(gradient_fn, (w,), (v,))
    return hessian_times_v

def fisher_mvp(theta, dtheta):
    """Product of the Fisher information matrix at theta with dtheta.

    The Fisher matrix is taken as the Hessian of th -> KL(theta || th)
    evaluated at th = theta, in the (mu, sigma) coordinates used by
    gaussian_kl."""
    def kl_from_theta(other):
        return gaussian_kl(theta, other)

    return hvp(kl_from_theta, theta, dtheta)

def interp(x, xp, fp):
    """Piecewise-linear interpolation of the samples (xp, fp) at x.

    xp holds the sample locations (searched with searchsorted, so it
    is expected to be sorted) and fp the sample values. Because the
    segment index is clipped to the valid range, an x outside
    [xp[0], xp[-1]] is linearly extrapolated from the nearest end
    segment."""
    upper = np.clip(np.searchsorted(xp, x, side='right'), 1, len(xp) - 1)
    lower = upper - 1
    x_lo, x_hi = xp[lower], xp[upper]
    # Weight each endpoint value by the distance to the opposite endpoint.
    weighted_sum = fp[lower] * (x_hi - x) + fp[upper] * (x - x_lo)
    return weighted_sum / (x_hi - x_lo)

def linear_path(theta0, theta1):
    """Build the straight-line path from theta0 (t=0) to theta1 (t=1).

    Returns a function of t that evaluates (1-t) * theta0 + t * theta1."""
    def path(t):
        return (1-t) * theta0 + t * theta1

    return path

def geometric_averages_path(theta0, theta1):
    """Build the geometric averages path from theta0 (t=0) to theta1 (t=1).

    The endpoints are interpolated linearly in information form
    (std2info) and the result is mapped back to standard form
    (info2std). Returns a function of t."""
    def path(t):
        blended = (1-t) * std2info(theta0) + t * std2info(theta1)
        return info2std(blended)

    return path

def moment_averages_path(theta0, theta1):
    """Build the moment averages path from theta0 (t=0) to theta1 (t=1).

    The endpoints are interpolated linearly in moments form
    (std2moments) and the result is mapped back to standard form
    (moments2std). Returns a function of t."""
    def path(t):
        blended = (1-t) * std2moments(theta0) + t * std2moments(theta1)
        return moments2std(blended)

    return path

def spline_path(theta0, theta1, knots):
    """Build the linear spline path through a given set of knots.

    knots is an N x 2 matrix of interior control points. The path
    starts at theta0, visits the knots in row order and ends at theta1,
    with the N + 2 control points placed at equally spaced times in
    [0, 1]. Returns a function of t giving the two interpolated
    coordinates."""
    num_knots = knots.shape[0]
    times = np.linspace(0, 1, num_knots + 2)

    def control_values(axis):
        # Endpoint, interior knots, endpoint along one coordinate.
        start = np.array([theta0[axis]])
        end = np.array([theta1[axis]])
        return np.concatenate([start, knots[:, axis], end])

    first_coords = control_values(0)
    second_coords = control_values(1)

    def spline_fn(t):
        return np.array([interp(t, times, first_coords),
                         interp(t, times, second_coords)])

    return spline_fn


###### Your code here

def path_energy(path_fn):
    """Compute the energy of a path (not yet implemented).

    path_fn is a function of t on [0, 1] returning the parameters
    theta(t), such as the functions built by linear_path or
    spline_path. compute_path_energies formats the return value with
    '{:.3f}', so an implementation must return a scalar number.

    NOTE(review): presumably the energy is the integral over [0, 1] of
    dtheta/dt . F(theta) . dtheta/dt, with the Fisher product supplied
    by fisher_mvp -- confirm the intended definition (including any
    constant factor) before implementing.

    Raises NotImplementedError so that callers fail here, with a clear
    message, instead of later on an unexpected None."""
    raise NotImplementedError('path_energy has not been implemented yet')

def optimize_spline_path(theta0, theta1, N):
    """Optimize the knots of a linear spline path (not yet implemented).

    theta0 and theta1 are the path endpoints and N is the number of
    interior knots. Callers pass the return value to spline_path as
    its knots argument, so an implementation must return an N x 2
    matrix.

    NOTE(review): presumably the knots should minimise path_energy of
    the resulting spline path -- confirm the intended objective before
    implementing.

    Raises NotImplementedError so that callers fail here, with a clear
    message, instead of later on an unexpected None."""
    raise NotImplementedError('optimize_spline_path has not been implemented yet')


###### Generating the outputs

def compute_path_energies(include_spline=False):
    """Print the energy of each path between two fixed Gaussians.

    The endpoints are (mu, sigma) = (0, 1) and (5, 2). The linear,
    geometric averages and moment averages paths are always reported;
    when include_spline is true, a spline path with 3 optimized knots
    is reported as well."""
    THETA0 = np.array([0., 1.])
    THETA1 = np.array([5., 2.])

    reports = [
        ('Linear: {:.3f}', linear_path),
        ('Geometric: {:.3f}', geometric_averages_path),
        ('Moments: {:.3f}', moment_averages_path),
    ]
    for template, make_path in reports:
        energy = path_energy(make_path(THETA0, THETA1))
        print(template.format(energy))

    if not include_spline:
        return

    num_knots = 3
    knots = optimize_spline_path(THETA0, THETA1, num_knots)
    spath = spline_path(THETA0, THETA1, knots)
    print('Spline: {:.3f}'.format(path_energy(spath)))

def plot_path(path, color):
    """Plot a path on the current matplotlib axes.

    path is a function of t on [0, 1] returning a pair of coordinates;
    color is the matplotlib format string for the curve. The curve is
    drawn from 101 equally spaced values of t, and 11 equally spaced
    values are overlaid as black 'x' markers."""
    def sample(num_points):
        return np.array([path(t) for t in np.linspace(0, 1, num_points)])

    curve = sample(101)
    plt.plot(curve[:,0], curve[:,1], color)

    markers = sample(11)
    plt.plot(markers[:,0], markers[:,1], 'kx')

def plot_paths(include_spline=False):
    """Plot the paths between two fixed Gaussians in a new figure.

    The endpoints are (mu, sigma) = (0, 1) and (5, 2). The linear path
    is drawn in cyan, the geometric averages path in blue and the
    moment averages path in green; when include_spline is true, a
    spline path with NUM_KNOTS optimized knots is added in red."""
    THETA0 = np.array([0., 1.])
    THETA1 = np.array([5., 2.])
    NUM_KNOTS = 3

    plt.figure()
    styled_paths = [
        (linear_path, 'c'),
        (geometric_averages_path, 'b'),
        (moment_averages_path, 'g'),
    ]
    for make_path, color in styled_paths:
        plot_path(make_path(THETA0, THETA1), color)

    if not include_spline:
        return

    knots = optimize_spline_path(THETA0, THETA1, NUM_KNOTS)
    plot_path(spline_path(THETA0, THETA1, knots), 'r')
