"""STA314 Homework 4.

Copyright and Usage Information
===============================

This file is provided solely for the personal and private use of students
taking STA314 at the University of Toronto St. George campus. All forms of
distribution of this code, whether as given or with any changes, are
expressly prohibited.
"""

from utils import *

import numpy as np
import scipy.special
import matplotlib.pyplot as plt

def compute_mean_mles(train_data, train_labels):
    '''
    Compute the mean estimate for each digit class. You may iterate over
    the possible label values (0 or 1 corresponding to digits "2" or "3"), 
    but otherwise make sure that your code is vectorized.

    The estimate for class k is the average of all rows of train_data whose
    label equals k (the maximum-likelihood estimate of a Gaussian mean).

    Arguments
        train_data: size N x 256 numpy array with the images
                    (any N x D array works; D is inferred from the data)
        train_labels: size N numpy array with corresponding labels
                      (integer or float 0/1; an N x 1 column is also accepted)
    
    Returns
        means: size 2 x 256 numpy array with the ith row corresponding
               to the mean estimate for digit class i. A class with no
               training samples keeps an all-zero row.
    '''
    num_classes = 2
    labels = np.asarray(train_labels).ravel()

    # Initialize array to store means; the feature dimension comes from the data.
    means = np.zeros((num_classes, train_data.shape[1]))

    for label in range(num_classes):
        class_rows = train_data[labels == label]
        # .mean() on an empty selection would yield NaN (with a warning);
        # leave the zero row in place instead.
        if class_rows.shape[0] > 0:
            means[label] = class_rows.mean(axis=0)
    return means

def compute_sigma_mles(train_data, train_labels, *, reg=0.01):
    '''
    Compute the covariance estimate for each digit class. You may iterate over
    the possible label values (0 or 1 corresponding to digits "2" or "3"), 
    but otherwise make sure that your code is vectorized.

    For class k with n_k samples, the estimate is
        (1 / n_k) * sum_i (x_i - mu_k)(x_i - mu_k)^T + reg * I,
    i.e. the maximum-likelihood covariance plus a small ridge. The ridge
    keeps each matrix positive definite so it can be inverted when
    evaluating the Gaussian density (pixels that never vary within a class
    would otherwise make the MLE singular). Pass reg=0 for the pure MLE.

    Arguments
        train_data: size N x 256 numpy array with the images
                    (any N x D array works; D is inferred from the data)
        train_labels: size N numpy array with corresponding labels
                      (integer or float 0/1; an N x 1 column is also accepted)
        reg: keyword-only non-negative ridge added to every diagonal entry
    
    Returns
        covariances: size 2 x 256 x 256 numpy array with the ith row corresponding
               to the covariance matrix estimate for label i. A class with
               no training samples gets reg * I.
    '''
    num_classes = 2
    dim = train_data.shape[1]
    labels = np.asarray(train_labels).ravel()

    # Initialize array to store covariances
    covariances = np.zeros((num_classes, dim, dim))

    for label in range(num_classes):
        class_rows = train_data[labels == label]
        n_k = class_rows.shape[0]
        if n_k == 0:
            continue
        centered = class_rows - class_rows.mean(axis=0)
        # Centered Gram matrix divided by n_k (MLE, not the unbiased n_k - 1).
        covariances[label] = centered.T @ centered / n_k

    # Broadcast the ridge over both class matrices.
    covariances += reg * np.eye(dim)
    return covariances

def generative_likelihood(digits, means, covariances):
    '''
    Compute the generative log-likelihood log p(x|t). You may iterate over
    the possible label values (0 or 1 corresponding to digits "2" or "3"), 
    but otherwise make sure that your code is vectorized.

    For each class t this evaluates the multivariate Gaussian log-density
        -D/2 log(2 pi) - 1/2 log|Sigma_t| - 1/2 (x - mu_t)^T Sigma_t^{-1} (x - mu_t)
    for all N rows at once. It uses the Cholesky factor Sigma_t = L L^T, so
    log|Sigma_t| = 2 sum(log diag L), and the quadratic form is the squared
    norm of L^{-1} (x - mu_t). No explicit inverse is formed.

    Arguments
        digits: size N x 256 numpy array with the images
        means: size 2 x 256 numpy array with the 2 class means
        covariances: size 2 x 256 x 256 numpy array with the 2 class covariances
    
    Returns
        likelihoods: size N x 2 numpy array with the ith row corresponding
               to logp(x^(i) | t) for t in {0, 1}

    Raises
        numpy.linalg.LinAlgError: if a class covariance is not positive definite.
    '''
    num_samples, dim = digits.shape
    num_classes = means.shape[0]
    likelihoods = np.zeros((num_samples, num_classes))

    log_norm_const = -0.5 * dim * np.log(2.0 * np.pi)
    for label in range(num_classes):
        # cholesky raises LinAlgError for non-positive-definite matrices.
        chol = np.linalg.cholesky(covariances[label])
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))
        # Whitened residuals, D x N: column i is L^{-1} (x_i - mu_t).
        whitened = np.linalg.solve(chol, (digits - means[label]).T)
        mahalanobis = np.sum(whitened ** 2, axis=0)
        likelihoods[:, label] = log_norm_const - 0.5 * log_det - 0.5 * mahalanobis
    return likelihoods


def conditional_likelihood(digits, means, covariances):
    '''
    Compute the conditional (posterior) log-likelihood log p(t|x). Make sure
    that your code is vectorized. Do not iterate over the label values
    explicitly in Python.

    By Bayes' rule, log p(t|x) = log p(x|t) + log p(t) - log sum_t' p(x|t') p(t').
    No class prior is passed in, so equal priors p(t) are assumed. They then
    cancel, leaving log p(x|t) - logsumexp_t' log p(x|t'). logsumexp keeps the
    normalization numerically stable for very negative log-densities.

    Arguments
        digits: size N x 256 numpy array with the images
        means: size 2 x 256 numpy array with the 2 class means
        covariances: size 2 x 256 x 256 numpy array with the 2 class covariances

    Returns
        likelihoods: size N x 2 numpy array with the ith row corresponding
               to logp(t | x^(i)) for t in {0, 1}
    '''
    log_generative = generative_likelihood(digits, means, covariances)
    # keepdims so the N x 1 normalizer broadcasts across the class columns.
    log_evidence = scipy.special.logsumexp(log_generative, axis=1, keepdims=True)
    return log_generative - log_evidence

def classify_data(digits, means, covariances):
    '''
    Classify new points by taking the most likely posterior class. 
    Make sure that your code is vectorized. Do not iterate over 
    the label values explicitly in Python.

    Arguments
        digits: size N x 256 numpy array with the images
        means: size 2 x 256 numpy array with the 2 class means
        covariances: size 2 x 256 x 256 numpy array with the 2 class covariances
    
    Returns
        pred: size N numpy array with the ith element corresponding
               to argmax_t log p(t | x^(i))
    '''
    # Compute and return the most likely class (column index == label value).
    cond_likelihood = conditional_likelihood(digits, means, covariances)
    return np.argmax(cond_likelihood, axis=1)

def avg_conditional_likelihood(digits, labels, means, covariances):
    '''
    Compute the average conditional likelihood over the true class labels

        AVG( log p(t^(i) | x^(i)) )

    i.e. the average log likelihood that the model assigns to the correct class label.

    Arguments
        digits: size N x 256 numpy array with the images
        labels: size N numpy array with the labels (values are cast to int,
                so float 0./1. labels work)
        means: size 2 x 256 numpy array with the 2 class means
        covariances: size 2 x 256 x 256 numpy array with the 2 class covariances
    
    Returns
        average conditional log-likelihood (a Python float).

    Raises
        ValueError: if digits and labels differ in length, or are empty.
    '''
    label_idx = np.asarray(labels).ravel().astype(int)
    # Explicit checks rather than assert, which is stripped under `python -O`.
    if len(digits) != len(label_idx):
        raise ValueError(
            'digits and labels must have the same length, got %d and %d'
            % (len(digits), len(label_idx)))
    if len(label_idx) == 0:
        raise ValueError('cannot average over an empty sample')

    cond_likelihood = conditional_likelihood(digits, means, covariances)

    # Pick log p(t^(i) | x^(i)) at each row's true label in a single gather.
    correct_class_llh = cond_likelihood[np.arange(len(label_idx)), label_idx]
    return float(np.mean(correct_class_llh))


def main():
    '''
    Fit the class-conditional Gaussian model on the training split, then
    print the average conditional log-likelihood and the posterior
    classification accuracy on both the training and the test split.
    '''
    train_x, train_y = load_train()
    test_x, test_y = load_test()
    # The helpers below are given 1-D label arrays.
    train_y = train_y.flatten()
    test_y = test_y.flatten()

    # Parameters are estimated from the training split only.
    mu = compute_mean_mles(train_x, train_y)
    sigma = compute_sigma_mles(train_x, train_y)

    splits = ((train_x, train_y), (test_x, test_y))

    train_llh, test_llh = [
        avg_conditional_likelihood(x, y, mu, sigma) for x, y in splits
    ]
    print('Train average conditional log-likelihood: ', train_llh)
    print('Test average conditional log-likelihood: ', test_llh)

    train_acc, test_acc = [
        np.mean(y.astype(int) == classify_data(x, mu, sigma)) for x, y in splits
    ]
    print('Train posterior accuracy: ', train_acc)
    print('Test posterior accuracy: ', test_acc)

if __name__ == '__main__':
    main()
