import pylab
import numpy
import pickle

import MySQLdb

SCORES_FILE = "all_scores.pickle"
TEST_SCORES_FILE = "test_scores.pickle"


class Learner(object):
    """Fit latent offense/defense vectors to observed game scores.

    Every team t has an offense vector ``offenses[t]`` and a defense vector
    ``defenses[t]`` of length ``latent_d``.  The model predicts that team i
    scores ``dot(offenses[i], defenses[j])`` against team j.  ``update``
    performs gradient ascent on ``likelihood`` (negative weighted squared
    error minus an L2 penalty) with a grow-on-success / halve-on-failure
    step size.
    """

    def __init__(self, score_tuples, latent_d=1, normalize_scores=False):
        """Build a learner from game results.

        Args:
            score_tuples: sequence of ``(team_i, score_i, team_j, score_j,
                weight)`` rows.  Team ids must be non-negative integral
                values (they are stored as floats and converted back to int
                for indexing).  Rows with only four fields are given
                weight 1.
            latent_d: length of each offense/defense vector.
            normalize_scores: if True, rescale each game's two scores so
                they sum to 140.  Games whose total is 0 are left unchanged.

        Raises:
            ValueError: if ``score_tuples`` is empty or its rows do not
                have 4 or 5 fields.
        """
        self.latent_d = latent_d
        self.learning_rate = .0001
        self.regularization_strength = .1

        self.scores = numpy.array(score_tuples).astype(float)
        if self.scores.ndim != 2 or self.scores.shape[0] == 0:
            raise ValueError("score_tuples must be a non-empty sequence of rows")
        if self.scores.shape[1] == 4:
            # No weight column supplied (e.g. fake_scores output): weight
            # every game equally.
            ones = numpy.ones((self.scores.shape[0], 1))
            self.scores = numpy.hstack([self.scores, ones])
        elif self.scores.shape[1] != 5:
            raise ValueError("each score row needs 4 or 5 fields, got %d"
                             % self.scores.shape[1])
        self.converged = False

        # Team ids appear in both the first (column 0) and second (column 2)
        # team slots; size the parameter matrices for the largest id in either.
        self.num_teams = int(max(self.scores[:, 0].max(),
                                 self.scores[:, 2].max())) + 1

        if normalize_scores:
            totals = self.scores[:, 1] + self.scores[:, 3]
            nonzero = totals != 0  # avoid dividing by a 0-0 game total
            for col in (1, 3):
                self.scores[nonzero, col] = (140.0 * self.scores[nonzero, col]
                                             / totals[nonzero])

        print("num_teams=%d latent_d=%d" % (self.num_teams, self.latent_d))
        print(self.scores)

        self.offenses = numpy.random.random((self.num_teams, self.latent_d))
        self.defenses = numpy.random.random((self.num_teams, self.latent_d))

        # Scratch buffers for a candidate step; overwritten by try_updates.
        self.new_offenses = self.offenses.copy()
        self.new_defenses = self.defenses.copy()

    @staticmethod
    def _unpack(score_tuple):
        """Split a score row into (team_i, s_i, team_j, s_j, weight) with int team ids."""
        team_i, s_i, team_j, s_j, weight = score_tuple
        # Rows live in a float array; NumPy refuses float indices.
        return int(team_i), s_i, int(team_j), s_j, weight

    def likelihood(self, offenses=None, defenses=None):
        """Return the objective maximized by ``update``.

        ``-sum(weight * squared error over both scores of every game)
        - regularization_strength * (sum(offenses**2) + sum(defenses**2))``.
        Defaults to the current ``self.offenses`` / ``self.defenses``.
        """
        if offenses is None:
            offenses = self.offenses
        if defenses is None:
            defenses = self.defenses

        sq_error = 0.0
        for score_tuple in self.scores:
            team_i, s_i, team_j, s_j, weight = self._unpack(score_tuple)
            s_hat_i = numpy.dot(offenses[team_i], defenses[team_j])
            s_hat_j = numpy.dot(offenses[team_j], defenses[team_i])
            sq_error += weight * ((s_i - s_hat_i) ** 2 + (s_j - s_hat_j) ** 2)

        l2_norm = numpy.sum(offenses ** 2) + numpy.sum(defenses ** 2)
        return -sq_error - self.regularization_strength * l2_norm

    def update(self):
        """Take one gradient-ascent step, backtracking on the step size.

        A candidate ``params + learning_rate * direction`` is proposed.  If it
        improves ``likelihood`` it is kept and the step size grows by 1.25x.
        Otherwise the step size is halved and a new candidate is tried.
        Convergence is declared when an accepted step improves the objective
        by less than 0.001, or when the step size drops below 1e-10.

        Returns:
            True while further calls may still make progress; False once
            converged.
        """
        updates_o = numpy.zeros((self.num_teams, self.latent_d))
        updates_d = numpy.zeros((self.num_teams, self.latent_d))

        # Accumulate (half of) the gradient of the negated squared-error term,
        # using the current parameters for every game.
        for score_tuple in self.scores:
            team_i, s_i, team_j, s_j, weight = self._unpack(score_tuple)
            err_i = weight * (s_i - numpy.dot(self.offenses[team_i], self.defenses[team_j]))
            err_j = weight * (s_j - numpy.dot(self.offenses[team_j], self.defenses[team_i]))

            updates_o[team_i] += self.defenses[team_j] * err_i
            updates_d[team_j] += self.offenses[team_i] * err_i
            updates_o[team_j] += self.defenses[team_i] * err_j
            updates_d[team_i] += self.offenses[team_j] * err_j

        # Parameters only change when a step is accepted (which ends the
        # loop), so the baseline objective can be computed once.
        initial_lik = self.likelihood()
        while not self.converged:
            print("  setting learning rate = %s" % self.learning_rate)
            self.try_updates(updates_o, updates_d)
            final_lik = self.likelihood(self.new_offenses, self.new_defenses)

            if final_lik > initial_lik:
                self.apply_updates(updates_o, updates_d)
                self.learning_rate *= 1.25
                if final_lik - initial_lik < .001:
                    self.converged = True
                break

            self.learning_rate *= .5
            self.undo_updates()
            if self.learning_rate < 1e-10:
                self.converged = True

        return not self.converged

    def apply_updates(self, updates_o, updates_d):
        """Commit the candidate written by ``try_updates``.

        ``updates_o`` / ``updates_d`` are unused; they are accepted only so
        the signature mirrors ``try_updates``.
        """
        self.offenses[...] = self.new_offenses
        self.defenses[...] = self.new_defenses

    def try_updates(self, updates_o, updates_d):
        """Write a candidate step into ``new_offenses`` / ``new_defenses``.

        The direction is ``updates - regularization_strength * params``.  That
        is half the gradient of ``likelihood``; the factor of 2 is absorbed
        into ``learning_rate``.  The current parameters are left untouched.
        """
        alpha = self.learning_rate
        beta = self.regularization_strength
        # The L2 penalty pulls parameters toward zero, hence the minus sign.
        self.new_offenses[...] = self.offenses + alpha * (updates_o - beta * self.offenses)
        self.new_defenses[...] = self.defenses + alpha * (updates_d - beta * self.defenses)

    def undo_updates(self):
        """Discard a rejected candidate step.

        Nothing to do: ``try_updates`` only writes the scratch buffers, and
        the next proposal overwrites them.
        """
        pass

    def print_latent_vectors(self):
        """Print each matrix under a heading, one line per team: id then vector entries."""
        for title, matrix in (("Offenses", self.offenses), ("Defenses", self.defenses)):
            print(title)
            for i, row in enumerate(matrix):
                print(" ".join([str(i)] + [str(value) for value in row]))

    def save_latent_vectors(self, prefix):
        """Dump both matrices to ``<prefix><latent_d>d_offenses.pickle`` / ``..._defenses.pickle``."""
        self.offenses.dump(prefix + ("%sd_offenses.pickle" % self.latent_d))
        self.defenses.dump(prefix + ("%sd_defenses.pickle" % self.latent_d))
    

def scores_from_file(path=None):
    """Load cached score rows previously pickled by ``scores_from_db``.

    Args:
        path: pickle file to read; defaults to ``SCORES_FILE``.

    Returns:
        The unpickled object.
    """
    if path is None:
        path = SCORES_FILE
    # Pickles are binary data: text mode fails on Python 3 and can corrupt
    # reads on Windows.  NOTE: pickle.load can execute arbitrary code -- only
    # load files this script produced itself.
    with open(path, "rb") as f:
        return pickle.load(f)


def scores_from_db(save_to_file=True, host="localhost", user="root", passwd="",
                   db_name="march_madness"):
    """Fetch one row per game from MySQL, optionally caching it to SCORES_FILE.

    Each row is ``(home team_id, home_score, away team_id, away_score,
    weight)``.

    Args:
        save_to_file: if True, pickle the rows to ``SCORES_FILE``.
        host, user, passwd, db_name: MySQL connection settings.

    Returns:
        The tuple of rows returned by ``cursor.fetchall()``.
    """
    # NOTE(review): both branches of this IF yield 1, so every game currently
    # gets weight 1; presumably a placeholder for weighting games after
    # 2009-03-17 differently -- confirm before relying on weights.
    weight_query = "IF(DATEDIFF(gr.date_played, DATE('2009-03-17')) > 0, 1, 1)"
    query = """
    SELECT tc1.team_id, gr.home_score, tc2.team_id, gr.away_score, %s
    FROM game_result gr, team_code tc1, team_code tc2
    WHERE gr.home_code = tc1.team_code AND gr.away_code = tc2.team_code;
    """ % (weight_query)

    db = MySQLdb.connect(host=host, user=user, passwd=passwd, db=db_name)
    try:
        cursor = db.cursor()
        try:
            cursor.execute(query)
            scores = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Release the connection even if the query fails.
        db.close()

    if save_to_file:
        with open(SCORES_FILE, "wb") as f:  # pickles are binary data
            pickle.dump(scores, f)

    return scores


def scores_from_csv():
    """Load score rows from the 2011 season CSV export (not implemented).

    Raises:
        NotImplementedError: always; the CSV parser is still a TODO.  Raising
            is better than silently returning None to a caller expecting rows.
    """
    filename = "2011_season_results.csv"
    # TODO: parse `filename` into (team_i, score_i, team_j, score_j, weight) rows.
    raise NotImplementedError("scores_from_csv is not implemented (would read %s)"
                              % filename)
    
    

def fake_scores(num_teams=100, num_games=30, latent_dimension=10):
    """Simulate a season from random ground-truth latent vectors.

    Each team gets offense and defense vectors of ``5 * numpy.random.rand``.
    In each of ``num_games`` rounds the teams are randomly paired; with an
    odd count, one team sits out.  Each side scores
    ``sum(own offense * opponent defense)``.

    Returns:
        ``(scores, o, d)``.  ``scores`` is a list of
        ``(team1, score1, team2, score2, weight)`` rows with weight 1.0,
        matching the 5-field rows ``Learner`` unpacks.  ``o`` and ``d`` are
        the true offense/defense vectors indexed by team id.
    """
    o = []
    d = []
    # Draw offense then defense per team, in the same order as before, so
    # seeded runs keep reproducing the same data.
    for _ in range(num_teams):
        o.append(5 * numpy.random.rand(latent_dimension))
        d.append(5 * numpy.random.rand(latent_dimension))

    scores = []
    for _ in range(num_games):
        matchups = numpy.random.permutation(num_teams)

        # Pair matchups[0] vs [1], [2] vs [3], ...  Floor division keeps this
        # an int on Python 3 and drops a leftover odd team.
        for k in range(num_teams // 2):
            team1 = matchups[2 * k]
            team2 = matchups[2 * k + 1]

            score1 = numpy.sum(o[team1] * d[team2])
            score2 = numpy.sum(o[team2] * d[team1])

            scores.append((team1, score1, team2, score2, 1.0))

    return (scores, o, d)


def plot_scores(scores):
    """Scatter-plot both scores of every game and display the figure.

    Field 1 of each row goes on the x axis and field 3 on the y axis; points
    are drawn as blue crosses.
    """
    first_scores = [row[1] for row in scores]
    second_scores = [row[3] for row in scores]

    pylab.plot(first_scores, second_scores, 'bx')
    pylab.show()


def plot_latent_vectors(U):
    """Placeholder for plotting latent vectors; currently does nothing.

    Args:
        U: ignored.

    Returns:
        None.
    """
    return None


if __name__ == "__main__":
    import os

    DATASET = 'real'

    if DATASET == 'fake':
        (scores, true_o, true_d) = fake_scores()
    elif DATASET == 'real':
        # Switch to scores_from_file() to reuse the cached pickle instead of
        # querying MySQL.
        scores = scores_from_db()
    else:
        raise ValueError("unknown DATASET %r" % DATASET)

    # Call plot_scores(scores) here to eyeball the raw data.

    learner = Learner(scores, latent_d=1)

    # update() returns False once the step-size search has converged.
    while learner.update():
        print("L= %s" % learner.likelihood())

    learner.print_latent_vectors()

    # save_latent_vectors writes into models/; create it if missing so the
    # whole training run isn't lost to an IOError at the very end.
    model_dir = "models"
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    learner.save_latent_vectors(os.path.join(model_dir, "%s_" % DATASET))
    

