from jax import jvp, numpy as np
import scipy.integrate

# use double precision
from jax.config import config
config.update('jax_enable_x64', True)



def grassmannian_length(x_fn, theta_min, theta_max):
    """Compute the Grassmannian length of the curve represented by the
    function x_fn, for theta in the interval [theta_min, theta_max].

    The Grassmannian length is the length of the path traced by the unit
    tangent direction v_hat(theta) = v / ||v||, where v = dx/dtheta.
    Equivalently it is the total curvature of the curve:

        L_G = integral ||d v_hat / d theta|| d theta
            = integral ||a_perp|| / ||v|| d theta,

    where a = d^2 x / d theta^2 and a_perp is the component of a orthogonal
    to v.  Unlike Euclidean length it does not change when the curve is
    rescaled; e.g. a circle of any radius has Grassmannian length 2*pi.

    Args:
        x_fn: Function mapping a scalar theta to a point on the curve.  The
            output may have any shape (it is treated as a flattened vector).
            It must be differentiable with ``jax.jvp`` and traceable by
            ``jax.jit`` (no Python control flow on the value of theta).
        theta_min: Lower end of the parameter interval.
        theta_max: Upper end of the parameter interval.

    Returns:
        The Grassmannian length as a Python float.

    Note:
        The curve must be regular on the interval (dx/dtheta != 0).  Where
        the velocity vanishes the tangent direction is undefined and the
        integrand evaluates to NaN.
    """
    from jax import jit  # local import: the module header only brings in jvp

    def velocity(theta):
        # Forward-mode derivative dx/dtheta of the curve at theta.
        _, v = jvp(x_fn, (theta,), (np.ones_like(theta),))
        return v

    @jit
    def turning_rate(theta):
        # Rate at which the unit tangent rotates: ||a_perp|| / ||v||.
        theta = np.asarray(theta)
        # Differentiating `velocity` once more yields (v, a) in one pass.
        v, a = jvp(velocity, (theta,), (np.ones_like(theta),))
        v_sq = np.sum(v * v)
        # Drop the part of the acceleration along v (it only changes speed);
        # the remaining orthogonal part is what turns the tangent direction.
        a_perp = a - (np.sum(a * v) / v_sq) * v
        return np.sqrt(np.sum(a_perp * a_perp) / v_sq)

    def integrand(theta):
        # scipy's quad expects a plain Python float, not a JAX array.
        return float(turning_rate(theta))

    length, _abserr = scipy.integrate.quad(integrand, theta_min, theta_max)
    return length




#### Code below is used for testing your implementation. Please do not modify.

def archimedean_spiral():
    """Return the Archimedean spiral curve as a function of t.

    The returned function maps t to the 2-D point t * (cos t, sin t).
    """
    def spiral_fn(t):
        return t * np.array([np.cos(t), np.sin(t)])
    return spiral_fn

def log_spiral(b):
    """Return a logarithmic spiral curve as a function of t.

    The returned function maps t to the 2-D point
    exp(b*t) * (cos t, sin t); b controls how fast the radius grows with t.
    """
    def spiral_fn(t):
        return np.exp(b*t) * np.array([np.cos(t), np.sin(t)])
    return spiral_fn

def helix(rad):
    """Return a helix curve of radius `rad` as a function of t.

    The returned function maps t to the 3-D point
    (t, rad*cos t, rad*sin t), i.e. the helix advances along the first
    coordinate while circling in the remaining two.
    """
    def helix_fn(t):
        return np.array([t, rad*np.cos(t), rad*np.sin(t)])
    return helix_fn

def compute_lengths():
    """Print the Grassmannian length of four test curves.

    Each curve is evaluated over theta in [0, 4*pi]: an Archimedean spiral,
    a logarithmic spiral with b=0.5, and helices of radius 0.5 and 2.
    """
    print('val1:', grassmannian_length(archimedean_spiral(), 0, 4*np.pi))
    print('val2:', grassmannian_length(log_spiral(0.5), 0, 4*np.pi))
    print('val3:', grassmannian_length(helix(0.5), 0, 4*np.pi))
    print('val4:', grassmannian_length(helix(2), 0, 4*np.pi))

