import pickle
import numpy
import re

import MySQLdb

class Bracket(object):
    """Simulate the 2009 bracket round by round from per-team ratings.

    Each team has an offense row and a defense row (loaded from the model
    files in __init__). A game's score estimate for one team is the sum of
    the element-wise product of its offense row with the opponent's
    defense row. Team codes in the bracket file are mapped to row indices
    and display names through the ``team_code`` table of the database.

    Attributes:
        round: the round the loaded bracket starts from (1 or 2).
        bracket: first-round team codes, in bracket order (start round 1).
        round2 .. round7: team codes advancing into that round; round2 is
            loaded from file when starting from round 2, otherwise each
            roundN+1 list is filled by play_roundN.
        dim: dimensionality of the rating model (selects the model files).
        offenses, defenses: rating arrays indexed by team_id.
        db, cursor: MySQLdb connection and cursor used for team lookups.
    """

    # Bracket input file for each round the simulation can start from.
    _BRACKET_FILES = {
        1: '2009_bracket.txt',
        2: '2009_bracket_rd2.txt',
    }

    def __init__(self, start_round=1):
        """Load the bracket, the rating model and open the database.

        Args:
            start_round: 1 to load the full first-round bracket into
                ``self.bracket``; 2 to load the second-round field into
                ``self.round2``. Defaults to 1.

        Raises:
            ValueError: if ``start_round`` is not 1 or 2.
        """
        if start_round not in self._BRACKET_FILES:
            raise ValueError("unsupported start round: %r" % (start_round,))
        self.round = start_round

        # Keep every stripped line containing at least one word character;
        # blank and punctuation-only lines are skipped.
        word_re = re.compile(r'\w+')
        with open(self._BRACKET_FILES[start_round]) as f:
            stripped = (line.strip() for line in f)
            teams = [line for line in stripped if word_re.search(line)]

        if start_round == 1:
            self.bracket = teams
        else:
            self.round2 = teams

        self.dim = 1
        # NOTE(review): these files are read with numpy.load; on NumPy
        # >= 1.16.3 pickled contents require allow_pickle=True -- confirm
        # the format the model files were written in.
        self.offenses = numpy.load("models/real_%sd_offenses.pickle" % self.dim)
        self.defenses = numpy.load("models/real_%sd_defenses.pickle" % self.dim)

        # Local database with hard-coded credentials.
        self.db = MySQLdb.connect(host="localhost", user="root", passwd="",
                                  db="march_madness")
        self.cursor = self.db.cursor()

    def simulate_game(self, team1_code, team2_code):
        """Estimate both teams' scores for one game, print them and return them.

        Each estimate is the sum of the element-wise product of one team's
        offense row with the other team's defense row.

        Returns:
            (s_hat1, s_hat2): the estimates for team1 and team2.
        """
        team1_id = self.team_code_to_id(team1_code)
        team2_id = self.team_code_to_id(team2_code)

        s_hat1 = numpy.sum(self.offenses[team1_id] * self.defenses[team2_id])
        s_hat2 = numpy.sum(self.offenses[team2_id] * self.defenses[team1_id])

        team1_name = self.team_code_to_name(team1_code)
        team2_name = self.team_code_to_name(team2_code)

        print("%s %s, %s %s" % (team1_name, s_hat1, team2_name, s_hat2))
        return (s_hat1, s_hat2)

    @staticmethod
    def _ranked_team_ids(ratings, descending):
        """Return row indices >= 1 of ``ratings`` ordered by their first column.

        Row 0 is excluded rather than overwritten with a sentinel value, so
        the model arrays are never modified by a ranking query. (Presumably
        row 0 is not a real team -- verify against the team_code table.)
        Ties keep the order of a stable ascending sort, reversed when
        ``descending`` is true.
        """
        ids = sorted(range(1, len(ratings)), key=lambda i: ratings[i][0])
        if descending:
            ids.reverse()
        return ids

    def top_defenses(self, num=50):
        """Return up to ``num`` (team_name, defense) pairs, lowest value first.

        Only the first rating column is used, so the ranking is meaningful
        for the one-dimensional model.
        """
        ranked = self._ranked_team_ids(self.defenses, descending=False)
        return [(self.team_id_to_name(i), self.defenses[i][0])
                for i in ranked[:num]]

    def top_offenses(self, num=50):
        """Return up to ``num`` (team_name, offense) pairs, highest value first.

        Only the first rating column is used, so the ranking is meaningful
        for the one-dimensional model.
        """
        # NOTE: a combined measure could be ranked instead by passing
        # ``self.offenses - self.defenses`` to _ranked_team_ids.
        ranked = self._ranked_team_ids(self.offenses, descending=True)
        return [(self.team_id_to_name(i), self.offenses[i][0])
                for i in ranked[:num]]

    def _fetch_one_value(self, query, value):
        """Run a one-parameter query and return column 0 of the first row.

        ``value`` is passed to the driver as a query parameter (MySQLdb's
        ``%s`` placeholder), so it is escaped rather than spliced into the
        SQL text. If no row matches, ``fetchone()`` returns None and
        indexing it raises TypeError.
        """
        self.cursor.execute(query, (value,))
        row = self.cursor.fetchone()
        return row[0]

    def team_code_to_id(self, code):
        """Return the team_id (model row index) for a team code."""
        return self._fetch_one_value(
            "SELECT team_id FROM team_code WHERE team_code = %s", code)

    def team_code_to_name(self, code):
        """Return the display name for a team code."""
        return self._fetch_one_value(
            "SELECT team_name FROM team_code WHERE team_code = %s", code)

    def team_id_to_name(self, id):
        """Return the display name for a team_id."""
        # int() turns NumPy integer ids into plain ints the driver can bind.
        return self._fetch_one_value(
            "SELECT team_name FROM team_code WHERE team_id = %s", int(id))

    def _play_round(self, round_number, teams, num_games):
        """Play ``num_games`` games of one round and return the winners.

        Game i pairs ``teams[2*i]`` with ``teams[2*i + 1]``. The first team
        advances only with a strictly higher estimate; a tie sends the
        second team through. A blank line is printed after every
        ``max(1, num_games // 4)`` games to group the output.
        """
        print("%s ROUND %d %s" % (20 * "=", round_number, 20 * "="))
        group_size = max(1, num_games // 4)
        winners = []
        for i in range(num_games):
            team1 = teams[2 * i]
            team2 = teams[2 * i + 1]
            s_hat1, s_hat2 = self.simulate_game(team1, team2)
            winners.append(team1 if s_hat1 > s_hat2 else team2)
            if (i + 1) % group_size == 0:
                print("")
        return winners

    def play_round1(self):
        """Play the 32 games of self.bracket; winners go to self.round2."""
        self.round2 = self._play_round(1, self.bracket, 32)

    def play_round2(self):
        """Play the 16 games of self.round2; winners go to self.round3."""
        self.round3 = self._play_round(2, self.round2, 16)

    def play_round3(self):
        """Play the 8 games of self.round3; winners go to self.round4."""
        self.round4 = self._play_round(3, self.round3, 8)

    def play_round4(self):
        """Play the 4 games of self.round4; winners go to self.round5."""
        self.round5 = self._play_round(4, self.round4, 4)

    def play_round5(self):
        """Play the 2 games of self.round5; winners go to self.round6."""
        self.round6 = self._play_round(5, self.round5, 2)

    def play_round6(self):
        """Play the final game of self.round6; the winner goes to self.round7."""
        self.round7 = self._play_round(6, self.round6, 1)

if __name__ == '__main__':
    b = Bracket()

    # Set to True to simulate the whole tournament; while False, only the
    # team rankings below are printed.
    simulate_tournament = False
    if simulate_tournament:
        # The first round is played only when the bracket starts there; a
        # bracket loaded from the second round already has its round-2 field.
        if b.round <= 1:
            b.play_round1()

        b.play_round2()
        b.play_round3()
        b.play_round4()
        b.play_round5()
        b.play_round6()

    # The rankings use only the first rating column, so they are printed
    # only for the one-dimensional model.
    if b.dim == 1:
        print("Offenses")
        for row in b.top_offenses():
            print("%s (%.2f)" % (row[0], float(row[1])))
        print("")
        print("Defenses")
        for row in b.top_defenses():
            print("%s (%.2f)" % (row[0], float(row[1])))


    # To simulate an individual game, pass the two teams' Yahoo team codes
    # to the simulate_game method, e.g.:
    #b.simulate_game('mbp', 'aah')
    #b.simulate_game('nav', 'laq')
    #b.simulate_game('nav', 'mbp')
    #b.simulate_game('nav', 'aah')

    
