""" 
Copyright (c) 2010, Roland Memisevic and Josh Susskind 
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.







Cuda-based implementation of a gated Boltzmann machine (GBM). 

See the bottom of this file for an example of how to use this module.

A Gated Boltzmann Machine is an unsupervised feature learning method.  Unlike other feature learning and sparse coding models, a GBM learns features that encode relations between observations rather than observations themselves. 

This module contains two GBM-classes: 
GbmfactoredBinBin implements a GBM with binary hidden units and binary observable units. 
GbmfactoredBinGauss implements a GBM with binary hidden and Gaussian observable units. 
(They are both derived from the base class "Gbmfactored"). 

After instantiating a model, use the method train() to train it (on pairs of numpy arrays, with examples stacked row-wise). Use freeenergy() to compute the unnormalized probability that the model assigns to new data. There are various other methods, see code for details. 

DEPENDENCIES:
This module uses the freely available Python "cudamat" package to access GPUs. If no CUDA-capable GPU is available, one can use the (also freely available) npmat package instead.


For details on the model see:  
Susskind, J., Memisevic, R., Hinton, G., Pollefeys, M., 2011
"Modeling the joint density of two images under a variety of transformations"
Computer Vision and Pattern Recognition (CVPR) 2011. 

"""


import numpy
import numpy.random
import cPickle

import cudamat as cm
#CUDAMAT WILL USUALLY BE INITIALIZED OUTSIDE, USING THE FOLLOWING CODE:
#cm.cuda_set_device(<DEVICE>)
#cm.init()
#cm.CUDAMatrix.init_random(<SEED>)

# Small constant added to the filter norms in normalize_visfilters() and
# normalize_mapfilters() before their reciprocal is taken, so that an
# all-zero filter does not cause a division by zero.
SMALL = 0.0000001


def logsumexp(x, dim=-1):
    """Compute log(sum(exp(x))) along one axis in a numerically stable way.

    The maximum along the reduced axis is subtracted before exponentiating
    and added back afterwards, so that large entries do not overflow.

    Args:
        x: numpy array.
        dim: axis to reduce over; negative values count from the last axis.
            Ignored when x has fewer than two dimensions, in which case all
            entries are reduced.

    Returns:
        A scalar when x has fewer than two dimensions, otherwise an array
        with the shape of x with axis ``dim`` removed.
    """
    if len(x.shape) < 2:
        xmax = x.max()
        return xmax + numpy.log(numpy.sum(numpy.exp(x - xmax)))
    # Reduce directly along `dim` instead of transposing it to the end; the
    # result is the same, and it avoids building an axis permutation from
    # range() objects (which breaks on Python 3 and for dim < -1).
    xmax = x.max(axis=dim)
    shifted = x - numpy.expand_dims(xmax, dim)
    return xmax + numpy.log(numpy.sum(numpy.exp(shifted), axis=dim))


class Gbmfactored(object):
    """Base class for factored gated Boltzmann machines running on cudamat.

    Holds the factor weights (Wxf, Wyf, Wzf), the biases (Wx, Wy, Wz), the
    corresponding increments (*_inc) and the per-batch GPU buffers used while
    sampling.  The base class does not define infer_obs(), freeenergy() or
    freeenergy_metric(); subclasses have to provide them.
    """

    def __init__(self, numin, numout, numfac, nummap, batchsize, initrange=0.001, stepsize=0.0001, momentum=0.9, normfilters=True, normvis=0.01, normmap=0.01):
        """Allocate parameters, increments and work buffers on the GPU.

        Args:
            numin: dimensionality of the inputs x.
            numout: dimensionality of the outputs y.
            numfac: number of factors.
            nummap: number of mapping (hidden) units.
            batchsize: number of cases per batch; fixes the row count of all
                per-batch buffers.
            initrange: scale of the Gaussian random initial factor weights.
            stepsize: learning rate used by update_weights().
            momentum: factor applied to the increments by apply_momentum().
            normfilters: if true, filters are rescaled to a common norm after
                every update; otherwise the norms are only tracked.
            normvis: initial value of the running visible-filter norm.
            normmap: initial value of the running mapping-filter norm.
        """
        self.numin = numin
        self.numout = numout
        self.numfac = numfac
        self.nummap = nummap
        self.batchsize = batchsize
        self.normfilters = normfilters
        self.initrange = initrange

        # OPTIMIZATION SPECIFIC SETTINGS
        self.stepsize = stepsize
        self.momentum = momentum

        # PARAMS AND INCS
        self.Wxf = cm.CUDAMatrix(self.initrange*numpy.random.randn(numin, numfac))
        self.Wyf = cm.CUDAMatrix(self.initrange*numpy.random.randn(numout, numfac))
        self.Wzf = cm.CUDAMatrix(self.initrange*numpy.random.randn(nummap, numfac))
        self.Wz = self._zeros(1, nummap)
        self.Wx = self._zeros(1, numin)
        self.Wy = self._zeros(1, numout)
        self.Wxf_inc = self._zeros(numin, numfac)
        self.Wyf_inc = self._zeros(numout, numfac)
        self.Wzf_inc = self._zeros(nummap, numfac)
        self.Wz_inc = self._zeros(1, nummap)
        self.Wx_inc = self._zeros(1, numin)
        self.Wy_inc = self._zeros(1, numout)

        # SCRATCH BUFFERS FOR THE FILTER NORMS
        # (the random draws are kept so that the numpy random stream is
        # consumed exactly as before; the values are overwritten right away)
        self._Wxf = cm.CUDAMatrix(self.initrange*numpy.random.randn(numin, numfac))
        self._Wyf = cm.CUDAMatrix(self.initrange*numpy.random.randn(numout, numfac))
        self._Wzf = cm.CUDAMatrix(self.initrange*numpy.random.randn(nummap, numfac))
        self._nWxf = self._zeros(1, numfac)
        self._nWyf = self._zeros(1, numfac)
        self._nWzf = self._zeros(1, numfac)
        self.normvis = normvis  # self.initrange*sqrt(numout)
        self.normmap = normmap  # self.initrange*sqrt(nummap) #1.0
        self.normalize_visfilters()
        self.normalize_mapfilters()

        # GIBBS CACHE
        self.z_probs = self._zeros(batchsize, nummap)
        self.x_probs = self._zeros(batchsize, numin)
        self.y_probs = self._zeros(batchsize, numout)
        self.actsh = self._zeros(batchsize, numfac)
        self.actsx = self._zeros(batchsize, numfac)
        self.actsy = self._zeros(batchsize, numfac)
        self.actsxy = self._zeros(batchsize, numfac)
        self.actsxh = self._zeros(batchsize, numfac)
        self.actsyh = self._zeros(batchsize, numfac)
        self.z_rand = self._zeros(batchsize, nummap)
        self.in_rand = self._zeros(batchsize, numin)
        self.out_rand = self._zeros(batchsize, numout)

    @staticmethod
    def _zeros(numrows, numcols):
        """Return a new zero-filled single-precision GPU matrix."""
        return cm.CUDAMatrix(cm.reformat(numpy.zeros((numrows, numcols), 'single')))

    def pos_phase(self, X, Y, samplehid=1):
        """
        Computes the factor activities of the data batch (X, Y) and infers
        the mapping units from them (positive statistics).
        """
        cm.dot(X, self.Wxf, self.actsx)
        cm.dot(Y, self.Wyf, self.actsy)
        self.infer_hids(X, Y, samplehid=samplehid)

    def pos_update(self, X, Y):
        """
        Adds the positive statistics of the current state to the increments.
        """
        self.Wzf_inc.add_dot(self.z_probs.T, self.actsxy)
        self.Wxf_inc.add_dot(X.T, self.actsyh)
        self.Wyf_inc.add_dot(Y.T, self.actsxh)
        self.Wz_inc.add_sums(self.z_probs, axis=0)
        self.Wx_inc.add_sums(X, axis=0)
        self.Wy_inc.add_sums(Y, axis=0)

    def neg_phase(self, samplevis=0):  # , method='cd'):
        """
        Performs 3-way alternating Gibbs sampling to obtain negative statistics.

        One of x and y (chosen at random) is reconstructed first, then the
        other one, and finally the mapping units are inferred again from the
        two reconstructions (without sampling).
        """
        # if method == 'pcd':
        #     self.pos_phase(self.x_probs, self.y_probs)

        # first decide randomly whether to reconstruct x or y first
        # (numpy.random.rand has to be *called*; comparing the function
        # object itself never yields a random decision)
        reconx1 = int(numpy.random.rand() > .5)
        recony1 = 1 - reconx1
        reconx2 = 1 - reconx1
        recony2 = 1 - recony1
        # then do alternating gibbs for reconstruction and again infer hids
        self.infer_obs(reconx=reconx1, recony=recony1, samplevis=samplevis)
        self.infer_obs(reconx=reconx2, recony=recony2, samplevis=samplevis)
        self.infer_hids(self.x_probs, self.y_probs, samplehid=0)

    def neg_update(self):
        """
        Subtracts the negative statistics of the current state from the
        increments.
        """
        self.Wzf_inc.subtract_dot(self.z_probs.T, self.actsxy)
        self.Wxf_inc.subtract_dot(self.x_probs.T, self.actsyh)
        self.Wyf_inc.subtract_dot(self.y_probs.T, self.actsxh)
        self.Wz_inc.add_sums(self.z_probs, axis=0, mult=-1.0)
        self.Wx_inc.add_sums(self.x_probs, axis=0, mult=-1.0)
        self.Wy_inc.add_sums(self.y_probs, axis=0, mult=-1.0)

    def get_hidprobs(self, X, Y):
        """
        Returns the mapping unit probabilities for the rows of the numpy
        arrays X and Y.  The data is processed in chunks of self.batchsize
        rows; an empty input yields an empty (0, nummap) array.
        """
        numcases = X.shape[0]
        if numcases == 0:
            return numpy.zeros((0, self.nummap), 'single')
        hiddens = []
        for start in range(0, numcases, self.batchsize):
            stop = start + self.batchsize
            hiddens.append(self.get_hidprobs_batch(X[start:stop], Y[start:stop]))
        return numpy.concatenate(hiddens, 0)

    def get_hidprobs_batch(self, X, Y):
        """
        Returns the mapping unit probabilities for one batch of at most
        self.batchsize rows.  Shorter batches are padded with zero rows to fit
        the GPU buffers; the padding is stripped from the result.
        """
        numcases = X.shape[0]
        if numcases < self.batchsize:
            X = numpy.concatenate((X, numpy.zeros((self.batchsize-X.shape[0], self.numin), "single")), 0)
            Y = numpy.concatenate((Y, numpy.zeros((self.batchsize-Y.shape[0], self.numout), "single")), 0)
        X_ = cm.CUDAMatrix(X)
        Y_ = cm.CUDAMatrix(Y)
        cm.dot(X_, self.Wxf, self.actsx)
        cm.dot(Y_, self.Wyf, self.actsy)
        self.infer_hids(X_, Y_, samplehid=False)
        return self.z_probs.asarray()[:numcases, :].copy()

    def infer_hids(self, X, Y, samplehid=1):
        """
        Computes filter activities and mapping units given a set
        of corresponding inputs and outputs.

        Expects self.actsx and self.actsy to hold the factor activities of
        the inputs and outputs already (X and Y themselves are not read).
        With samplehid == 1 the probabilities in self.z_probs are replaced by
        samples before the hidden factor activities are computed.
        """
        self.actsxy.assign(self.actsx)
        self.actsxy.mult(self.actsy)
        cm.dot(self.actsxy, self.Wzf.T, self.z_probs)
        self.z_probs.add_row_vec(self.Wz)
        self.z_probs.apply_sigmoid()
        if samplehid == 1:
            self.z_rand.fill_with_rand()
            self.z_probs.greater_than(self.z_rand)
        cm.dot(self.z_probs, self.Wzf, self.actsh)
        self.actsxh.assign(self.actsx)
        self.actsxh.mult(self.actsh)
        self.actsyh.assign(self.actsy)
        self.actsyh.mult(self.actsh)

    def update_weights(self, batchsize, updatevis=1, updatehids=1, weightcost=0.0):
        """
        Applies the accumulated increments to the parameters.

        The increments are scaled by stepsize/batchsize; the mapping weights
        and all biases use a ten times smaller step.  A weight decay term
        (-weightcost * W) is added to the factor weight increments first.
        """
        batchsize = float(batchsize)

        # add weightcost updates
        self.Wzf_inc.add_mult(self.Wzf, -weightcost)
        self.Wxf_inc.add_mult(self.Wxf, -weightcost)
        self.Wyf_inc.add_mult(self.Wyf, -weightcost)

        # update weights
        step = self.stepsize/batchsize
        if updatehids:
            self.Wzf.add_mult(self.Wzf_inc, 0.1*step)
            self.Wz.add_mult(self.Wz_inc, 0.1*step)
        if updatevis:
            self.Wxf.add_mult(self.Wxf_inc, step)
            self.Wyf.add_mult(self.Wyf_inc, step)
            self.Wx.add_mult(self.Wx_inc, 0.1*step)
            self.Wy.add_mult(self.Wy_inc, 0.1*step)

    def get_mfrecon(self, X, Y):
        """
        Returns the per-case squared errors (err_x, err_y) of reconstructing
        X and Y without sampling hidden or visible units.
        """
        self.pos_phase(X, Y, samplehid=0)
        self.infer_obs(samplevis=False)
        return self.get_reconsse(X, Y)

    def get_reconsse(self, X=None, Y=None):
        """
        Returns the per-case sums of squared differences between X and
        self.x_probs and between Y and self.y_probs.  An argument left as
        None yields a zero for the corresponding error.
        """
        err_x = numpy.array(0.0)
        err_y = numpy.array(0.0)
        if X is not None:
            err_x = numpy.sum((X.asarray()-self.x_probs.asarray())**2, 1)
        if Y is not None:
            err_y = numpy.sum((Y.asarray()-self.y_probs.asarray())**2, 1)
        return err_x, err_y

    def normalize_visfilters(self):
        """
        normalize filters to have equal norms based on running average.
        if normfilters is false, then this will just update the variables that
        track the weight norms without actually doing normalization.
        """
        self._Wxf.assign(self.Wxf)
        self._Wxf.mult(self._Wxf)
        self._Wyf.assign(self.Wyf)
        self._Wyf.mult(self._Wyf)
        # SMALL keeps the reciprocal below finite for all-zero filters
        self._nWxf = cm.sqrt(self._Wxf.sum(0)).add(SMALL)
        self._nWyf = cm.sqrt(self._Wyf.sum(0)).add(SMALL)
        normvis = (self._nWxf.asarray().mean() + self._nWyf.asarray().mean())/2.0
        if self.normfilters:
            self.normvis = .95 * self.normvis + .05 * float(normvis)
            self.Wxf.mult_by_row(self._nWxf.reciprocal().mult(self.normvis))
            self.Wyf.mult_by_row(self._nWyf.reciprocal().mult(self.normvis))
        else:
            self.normvis = normvis

    def normalize_mapfilters(self):
        """
        normalize filters to have equal norms based on running average.
        if normfilters is false, then this will just update the variables that
        track the weight norms without actually doing normalization.
        """
        self._Wzf.assign(self.Wzf)
        self._Wzf.mult(self._Wzf)
        # SMALL keeps the reciprocal below finite for all-zero filters
        self._nWzf = cm.sqrt(self._Wzf.sum(0)).add(SMALL)
        normmap = self._nWzf.asarray().mean()
        if self.normfilters:
            self.normmap = .95 * self.normmap + .05 * float(normmap)
            self.Wzf.mult_by_row(self._nWzf.reciprocal().mult(self.normmap))
        else:
            self.normmap = normmap

    def apply_momentum(self):
        """
        Scales all increments by the momentum factor.
        """
        for inc in (self.Wzf_inc, self.Wxf_inc, self.Wyf_inc,
                    self.Wz_inc, self.Wx_inc, self.Wy_inc):
            inc.mult(self.momentum)

    def train(self, x, y, numsteps=1, weightcost=0.0, verbose=False):
        """
        Performs numsteps contrastive-divergence style updates on one batch.

        Args:
            x, y: batch of inputs and outputs (cases stacked row-wise), either
                numpy arrays or matrices that already live on the GPU.
            numsteps: number of update steps on this batch.
            weightcost: weight decay factor passed to update_weights().
            verbose: if true, print reconstruction error and filter norms
                after every step.
        """
        if isinstance(x, numpy.ndarray):
            x = cm.CUDAMatrix(x)
        if isinstance(y, numpy.ndarray):
            y = cm.CUDAMatrix(y)
        batchsize = float(y.shape[0])
        for step in range(numsteps):
            self.apply_momentum()

            # positive phase
            self.pos_phase(x, y, samplehid=1)
            self.pos_update(x, y)

            # negative phase
            self.neg_phase(samplevis=0)  # , method='cd')
            self.neg_update()

            # get current cost (only needed for printing, and it requires
            # copying data and reconstructions to the host)
            if verbose:
                errx, erry = self.get_reconsse(x, y)
                err = (sum(errx)+sum(erry))/2.0

            # update weights
            self.update_weights(batchsize, updatevis=True, updatehids=True, weightcost=weightcost)

            # normalize filters
            self.normalize_visfilters()
            self.normalize_mapfilters()

            # print performance
            if verbose:
                print("step %d: errtrain=%.3f, n_vis=%.3f, n_map=%.3f" % (step, err, self.normvis, self.normmap))

    def save(self, filename):
        """
        Pickles all parameters, increments and running norms to filename.
        """
        allparams = dict(Wxf=self.Wxf.asarray(),
                         Wyf=self.Wyf.asarray(),
                         Wzf=self.Wzf.asarray(),
                         Wz=self.Wz.asarray(),
                         Wx=self.Wx.asarray(),
                         Wy=self.Wy.asarray(),
                         Wxf_inc=self.Wxf_inc.asarray(),
                         Wyf_inc=self.Wyf_inc.asarray(),
                         Wzf_inc=self.Wzf_inc.asarray(),
                         Wz_inc=self.Wz_inc.asarray(),
                         Wx_inc=self.Wx_inc.asarray(),
                         Wy_inc=self.Wy_inc.asarray(),
                         normvis=self.normvis,
                         normmap=self.normmap)
        # the file mode is left unchanged so that existing files stay
        # readable; the handle is now closed deterministically
        with open(filename, 'w') as outfile:
            cPickle.dump(allparams, outfile)

    def load(self, filename):
        """
        Restores parameters, increments and running norms written by save().
        The stored arrays have to match the shapes of this model.

        NOTE: unpickling can execute arbitrary code; only load files from
        trusted sources.
        """
        with open(filename, 'r') as infile:
            allparams = cPickle.load(infile)
        for name in ('Wxf', 'Wyf', 'Wzf', 'Wz', 'Wx', 'Wy',
                     'Wxf_inc', 'Wyf_inc', 'Wzf_inc',
                     'Wz_inc', 'Wx_inc', 'Wy_inc'):
            getattr(self, name).assign(cm.CUDAMatrix(allparams[name]))
        self.normvis = allparams['normvis']
        self.normmap = allparams['normmap']

    def freeenergy_metric_normalized(self, X, Y):
        """ Normalized free energy metric. Assumes that methods
            freeenergy and freeenergy_metric are defined (in the subclass).

            Entry (i, j) combines the free energies of the pairs (x_i, y_j)
            and (y_j, x_i) and subtracts those of (x_i, x_i) and (y_j, y_j).
        """
        return self.freeenergy_metric(X, Y)\
                + self.freeenergy_metric(Y, X).T\
                - self.freeenergy(X, X)[:, numpy.newaxis]\
                - self.freeenergy(Y, Y)[numpy.newaxis, :]


class GbmfactoredBinGauss(Gbmfactored):
    """Factored GBM with binary hidden units and Gaussian observable units."""

    def infer_obs(self, reconx=1, recony=1, samplevis=0):
        """
        Reconstructs the outputs (if recony == 1) and/or the inputs (if
        reconx == 1) from the current hidden factor activities.

        The reconstructions are linear (factor activities times the transposed
        filters plus bias) and are stored in self.y_probs / self.x_probs; with
        samplevis == 1, noise from fill_with_randn() is added.  The factor
        activities self.actsy / self.actsx are refreshed from the
        reconstructions.  y is always processed before x.
        """
        # reconstruct output images
        if recony == 1:
            self.actsxh.assign(self.actsh)
            self.actsxh.mult(self.actsx)
            cm.dot(self.actsxh, self.Wyf.T, self.y_probs)
            self.y_probs.add_row_vec(self.Wy)
            if samplevis == 1:
                self.out_rand.fill_with_randn()
                self.y_probs.add(self.out_rand)
            cm.dot(self.y_probs, self.Wyf, self.actsy)

        # reconstruct input images
        if reconx == 1:
            self.actsyh.assign(self.actsh)
            self.actsyh.mult(self.actsy)
            cm.dot(self.actsyh, self.Wxf.T, self.x_probs)
            self.x_probs.add_row_vec(self.Wx)
            if samplevis == 1:
                self.in_rand.fill_with_randn()
                self.x_probs.add(self.in_rand)
            cm.dot(self.x_probs, self.Wxf, self.actsx)

    def freeenergy(self, X, Y):
        """
        Compute free energy for the vectors in (corresponding) columns of X and Y.

        X and Y are numpy arrays with one case per column; the result is a
        one-dimensional array with one value per case.
        """
        numin, numcases = X.shape
        nummap, numfac = self.Wzf.shape
        gX = cm.CUDAMatrix(X.T)
        gY = cm.CUDAMatrix(Y.T)
        factorsXY = cm.dot(gX, self.Wxf)
        factorsY = cm.dot(gY, self.Wyf)
        factorsXY.mult(factorsY)
        # total input to every mapping unit, arranged as (nummap, numcases)
        mapinput = cm.dot(factorsXY, self.Wzf.T).add_row_vec(self.Wz).asarray().T
        # log(1 + exp(input)): logsumexp over the pair (0, input)
        stacked = numpy.concatenate(
            (numpy.zeros((nummap, numcases, 1), 'single'),
             mapinput[:, :, numpy.newaxis]), 2)
        F = numpy.sum(logsumexp(stacked, 2), 0)
        # quadratic terms of the Gaussian observables
        F -= 0.5 * ((gX.asarray()-self.Wx.asarray())**2).sum(1)
        F -= 0.5 * ((gY.asarray()-self.Wy.asarray())**2).sum(1)
        return -F

    def freeenergy_metric(self, X, Y):
        """
        Compute matrix of free energy values between _every_ element (column) of
        X and every element of Y.

        Returns a numpy array with one row per column of X and one column per
        column of Y.
        """
        # keep the input dimension of X: the row buffer TX below has to match
        # the rows of gX even when inputs and outputs differ in size
        numin, numcasesX = X.shape
        numcasesY = Y.shape[1]
        onein = (numcasesX == 1)
        if onein:  # deal with cudamat bug on one-dimensional inputs
            X = numpy.concatenate((X, X), 1)
            numcasesX = 2
        nummap, numfac = self.Wzf.shape
        factorsXY = cm.empty((numcasesY, numfac))
        gX = cm.CUDAMatrix(X.T)
        gY = cm.CUDAMatrix(Y.T)
        factorsX = cm.dot(gX, self.Wxf)
        factorsY = cm.dot(gY, self.Wyf)
        FXslice = cm.CUDAMatrix(numpy.zeros((1, numfac), 'single'))
        TX = cm.CUDAMatrix(numpy.zeros((1, numin), 'single'))
        D = cm.CUDAMatrix(numpy.zeros((numcasesY, numcasesX), 'single'))
        M = cm.CUDAMatrix(numpy.zeros((numcasesY, nummap), 'single'))
        Mz = cm.CUDAMatrix(numpy.zeros((numcasesY, nummap), 'single'))
        Mgreater0 = cm.CUDAMatrix(numpy.zeros((numcasesY, nummap), 'single'))

        # quadratic term of y, -0.5 * sum((Wy - y)**2): it does not depend on
        # x, so it is computed once instead of once per column of X
        gY_ = cm.CUDAMatrix(Y.T)
        gY_.mult(-1.0)
        gY_.add_row_vec(self.Wy)
        gY_.mult(gY_)
        gY_.mult(-0.5)
        quadY = gY_.sum(1)

        for i in range(numcasesX):
            # input to the mapping units for x_i paired with every y
            factorsXY.assign(factorsY)
            factorsX.get_row_slice(i, i+1, target=FXslice)
            factorsXY.mult_by_row(FXslice)
            cm.dot(factorsXY, self.Wzf.T, target=M)
            M.add_row_vec(self.Wz)
            # stable log(1 + exp(M)) = max(M, 0) + log(exp(min(M, 0)) + exp(-max(M, 0)))
            M.greater_than(0.0, target=Mgreater0)
            Mgreater0.mult(M)
            M.subtract(Mgreater0)
            cm.exp(M)
            Mz.assign(0.0)
            Mz.subtract(Mgreater0)
            cm.exp(Mz)
            M.add(Mz)
            cm.log(M)
            M.add(Mgreater0)
            D.set_col_slice(i, i+1, M.sum(1))
            # quadratic term of x_i
            gX.get_row_slice(i, i+1, target=TX)
            TX.subtract(self.Wx)
            TX.mult(TX)
            TX.mult(-0.5)
            D.get_col_slice(i, i+1).add_row_vec(TX.sum(1))
            D.get_col_slice(i, i+1).add(quadY)
        if not onein:
            return -D.asarray().T
        else:
            # drop the duplicated case again
            return -D.asarray().T[0, :][numpy.newaxis, :]


class GbmfactoredBinBin(Gbmfactored):
    """Factored GBM with binary hidden units and binary observable units."""

    def infer_obs(self, reconx=1, recony=1, samplevis=0):
        """
        Reconstructs the outputs (if recony == 1) and/or the inputs (if
        reconx == 1) from the current hidden factor activities.

        The reconstructions are sigmoid probabilities stored in self.y_probs /
        self.x_probs; with samplevis == 1 they are replaced by samples.  The
        factor activities self.actsy / self.actsx are refreshed from the
        reconstructions.  y is always processed before x.
        """
        # reconstruct output images
        if recony == 1:
            self.actsxh.assign(self.actsh)
            self.actsxh.mult(self.actsx)
            cm.dot(self.actsxh, self.Wyf.T, self.y_probs)
            self.y_probs.add_row_vec(self.Wy)
            self.y_probs.apply_sigmoid()
            if samplevis == 1:
                self.out_rand.fill_with_rand()
                self.y_probs.greater_than(self.out_rand)
            cm.dot(self.y_probs, self.Wyf, self.actsy)

        # reconstruct input images
        if reconx == 1:
            self.actsyh.assign(self.actsh)
            self.actsyh.mult(self.actsy)
            cm.dot(self.actsyh, self.Wxf.T, self.x_probs)
            self.x_probs.add_row_vec(self.Wx)
            self.x_probs.apply_sigmoid()
            if samplevis == 1:
                self.in_rand.fill_with_rand()
                self.x_probs.greater_than(self.in_rand)
            cm.dot(self.x_probs, self.Wxf, self.actsx)

    def freeenergy(self, X, Y):
        """
        Compute free energy for the vectors in (corresponding) columns of X and Y.

        X and Y are numpy arrays with one case per column; the result is a
        one-dimensional array with one value per case.
        """
        numin, numcases = X.shape
        nummap, numfac = self.Wzf.shape
        gX = cm.CUDAMatrix(X.T)
        gY = cm.CUDAMatrix(Y.T)
        factorsXY = cm.dot(gX, self.Wxf)
        factorsY = cm.dot(gY, self.Wyf)
        factorsXY.mult(factorsY)
        # total input to every mapping unit, arranged as (nummap, numcases)
        mapinput = cm.dot(factorsXY, self.Wzf.T).add_row_vec(self.Wz).asarray().T
        # log(1 + exp(input)): logsumexp over the pair (0, input)
        stacked = numpy.concatenate(
            (numpy.zeros((nummap, numcases, 1), 'single'),
             mapinput[:, :, numpy.newaxis]), 2)
        F = numpy.sum(logsumexp(stacked, 2), 0)
        # bias terms of the binary observables
        F += cm.dot(gX, self.Wx.T).asarray().flatten()
        F += cm.dot(gY, self.Wy.T).asarray().flatten()
        return -F

    def freeenergy_metric(self, X, Y):
        """
        Compute matrix of free energy values between _every_ element (column) of
        X and every element of Y.

        Returns a numpy array with one row per column of X and one column per
        column of Y.
        """
        numin, numcasesX = X.shape
        numcasesY = Y.shape[1]
        onein = (numcasesX == 1)
        if onein:  # deal with cudamat bug on one-dimensional inputs
            X = numpy.concatenate((X, X), 1)
            numcasesX = 2
        nummap, numfac = self.Wzf.shape
        factorsXY = cm.empty((numcasesY, numfac))
        gX = cm.CUDAMatrix(X.T)
        gY = cm.CUDAMatrix(Y.T)
        factorsX = cm.dot(gX, self.Wxf)
        factorsY = cm.dot(gY, self.Wyf)
        FXslice = cm.CUDAMatrix(numpy.zeros((1, numfac), 'single'))
        Xslice = cm.CUDAMatrix(numpy.zeros((1, numin), 'single'))
        TX = cm.CUDAMatrix(numpy.zeros((1, 1), 'single'))
        D = cm.CUDAMatrix(numpy.zeros((numcasesY, numcasesX), 'single'))
        M = cm.CUDAMatrix(numpy.zeros((numcasesY, nummap), 'single'))
        Mz = cm.CUDAMatrix(numpy.zeros((numcasesY, nummap), 'single'))
        Mgreater0 = cm.CUDAMatrix(numpy.zeros((numcasesY, nummap), 'single'))

        # bias term of y: it does not depend on x, so it is computed once
        # instead of once per column of X
        biasY = cm.dot(gY, self.Wy.T)

        for i in range(numcasesX):
            # input to the mapping units for x_i paired with every y
            factorsXY.assign(factorsY)
            factorsX.get_row_slice(i, i+1, target=FXslice)
            factorsXY.mult_by_row(FXslice)
            cm.dot(factorsXY, self.Wzf.T, target=M)
            M.add_row_vec(self.Wz)
            # stable log(1 + exp(M)) = max(M, 0) + log(exp(min(M, 0)) + exp(-max(M, 0)))
            M.greater_than(0.0, target=Mgreater0)
            Mgreater0.mult(M)
            M.subtract(Mgreater0)
            cm.exp(M)
            Mz.assign(0.0)
            Mz.subtract(Mgreater0)
            cm.exp(Mz)
            M.add(Mz)
            cm.log(M)
            M.add(Mgreater0)
            D.set_col_slice(i, i+1, M.sum(1))
            # bias term of x_i (row copied into a preallocated buffer)
            gX.get_row_slice(i, i+1, target=Xslice)
            cm.dot(Xslice, self.Wx.T, target=TX)
            D.get_col_slice(i, i+1).add_row_vec(TX)
            D.get_col_slice(i, i+1).add(biasY)
        if not onein:
            return -D.asarray().T
        else:
            # drop the duplicated case again
            return -D.asarray().T[0, :][numpy.newaxis, :]


if __name__ == '__main__':
    # Usage example: train a binary/Gaussian GBM for ten steps on one batch
    # of random data.  Requires a CUDA device and cudamat.

    import numpy
    import numpy.random
    import cudamat as cm
    cm.cuda_set_device(0)
    cm.init()
    cm.CUDAMatrix.init_random(1)

    numin = 256
    numout = 256
    numfac = 256
    nummap = 100
    batchsize = 100

    model = GbmfactoredBinGauss(numin, numout, numfac, nummap, batchsize, normfilters=True)
    # the outputs have numout columns (the value happens to equal numin here)
    model.train(numpy.random.randn(batchsize, numin), numpy.random.randn(batchsize, numout), 10, 0.0)



