

November 20

Let's write the algorithm using the convention that we begin counting from 0, so we consider matrices $A_0$ through $A_{n-1}$, where $A_0$ has dimensions $s_0\times s_1$, $A_1$ has dimensions $s_1\times s_2$, and so on. The most natural approach would be a recursive function (supposing we have an array holding $s_0, \ldots, s_n$):

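/* Cost of the best grouping of A_i ... A_j, where A_k is s[k] x s[k+1]. */
int rec_matrix_product(int i, int j, int *s) {
    if (i == j) {
        return 0;   /* a single matrix needs no multiplications */
    } else {
        /* start with the split right after A_i */
        int C = rec_matrix_product(i, i, s) +
                rec_matrix_product(i+1, j, s) +
                s[i]*s[i+1]*s[j+1];
        /* then try every other split point k */
        for (int k = i+1; k < j; k++) {
            int c = rec_matrix_product(i, k, s) +
                    rec_matrix_product(k+1, j, s) +
                    s[i]*s[k+1]*s[j+1];
            if (c < C)   C = c;
        }
        return C;
    }
}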
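
As a rough check on how fast this recursion grows, the sketch below counts the calls it makes on a chain of matrices. The name `count_calls` is hypothetical and not part of the lecture code; it mirrors the call structure of `rec_matrix_product` but ignores the dimensions, since they do not affect how many calls are made.

```c
#include <assert.h>

/* Hypothetical helper: number of calls the recursive scheme makes
   to handle A_i ... A_j, counting the top-level call itself. */
long count_calls(int i, int j) {
    assert(i <= j);
    long calls = 1;                      /* this call */
    for (int k = i; k < j; k++) {        /* every split point, k = i included */
        calls += count_calls(i, k) + count_calls(k + 1, j);
    }
    return calls;
}
```

For four matrices, `count_calls(0, 3)` gives 27. Each extra matrix triples the count, so a chain of $n$ matrices triggers $3^{n-1}$ calls.
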
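If we solve the recurrence for the running time of rec_matrix_product, we find that $T(n)$ is still exponential, so we haven't made any real progress from $4^n$! The trouble is that the recursion solves the same subproblems over and over, even though there are only $n(n+1)/2$ distinct pairs $i \le j$. We need to solve this problem "bottom-up," filling in a table $C$ where $C[i][j]$ is the cost of the best grouping of $A_i$ through $A_j$. If we start along the diagonal where $i = j$, the problem is easy, since $C[i][i]$ is zero. Then we work on the diagonal where $j = i+1$, and so on: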
/* C is an n x n table of ints; s holds the n+1 dimensions s[0..n]. */
for (int i = 0; i < n; i++)
    C[i][i] = 0;
/* len = j - i: fill the table one diagonal at a time */
for (int len = 1; len < n; len++) {
    for (int i = 0; i < n - len; i++) {
        int j = i + len;
        C[i][j] = C[i][i] + C[i+1][j] + s[i]*s[i+1]*s[j+1];
        for (int k = i+1; k < j; k++) {
            if (C[i][k] + C[k+1][j] + s[i]*s[k+1]*s[j+1] < C[i][j])
                C[i][j] = C[i][k] + C[k+1][j] + s[i]*s[k+1]*s[j+1];
        }
    }
}
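
To try the table-filling loop on concrete numbers, here is one way it might be packaged as a complete function. This is a sketch under some assumptions that are not in the lecture: the name `matrix_chain_cost`, the flat heap-allocated table indexed as `C[i*n + j]`, and the use of `long` for the costs are all choices made for illustration.

```c
#include <assert.h>
#include <stdlib.h>

/* Sketch: cost of the best grouping of A_0 ... A_{n-1}, where A_i is
   s[i] x s[i+1], so s must hold n+1 entries.  The n x n table is stored
   row by row in one array (an assumption made for this example). */
long matrix_chain_cost(int n, const int *s) {
    assert(n >= 1);
    long *C = malloc((size_t)n * n * sizeof *C);
    assert(C != NULL);

    for (int i = 0; i < n; i++)
        C[i * n + i] = 0;                    /* single matrices cost nothing */

    for (int len = 1; len < n; len++) {      /* diagonal j - i = len */
        for (int i = 0; i < n - len; i++) {
            int j = i + len;
            long best = -1;                  /* -1 means "no split tried yet" */
            for (int k = i; k < j; k++) {    /* split after A_k */
                long c = C[i * n + k] + C[(k + 1) * n + j]
                       + (long)s[i] * s[k + 1] * s[j + 1];
                if (best < 0 || c < best)
                    best = c;
            }
            C[i * n + j] = best;
        }
    }

    long result = C[0 * n + (n - 1)];        /* C[0][n-1]: the whole chain */
    free(C);
    return result;
}
```

With dimensions 10, 100, 5, 50 (three matrices), computing $(A_0A_1)A_2$ costs $10\cdot100\cdot5 + 10\cdot5\cdot50 = 7500$, while $A_0(A_1A_2)$ costs $100\cdot5\cdot50 + 10\cdot100\cdot50 = 75000$, so the function returns 7500. The three nested loops make the running time proportional to $n^3$.
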

Now we know the cost of a best grouping, but we still need to recover the grouping itself. You could work backwards from the optimal cost and find two subcosts that would yield that number, but it's probably more straightforward just to keep track of the best split at each step:

/* B[i][j] = k records that the best grouping of A_i ... A_j found so far
   is (A_i ... A_k)(A_{k+1} ... A_j). */
for (int i = 0; i < n; i++) {
    C[i][i] = 0;
    B[i][i] = i;
}
for (int len = 1; len < n; len++) {
    for (int i = 0; i < n-len; i++) {
        int j = i + len;
        C[i][j] = C[i][i] + C[i+1][j] + s[i]*s[i+1]*s[j+1];
        B[i][j] = i;
        for (int k = i+1; k < j; k++) {
            if (C[i][k] + C[k+1][j] + s[i]*s[k+1]*s[j+1] < C[i][j]) {
                C[i][j] = C[i][k] + C[k+1][j] + s[i]*s[k+1]*s[j+1];
                B[i][j] = k;
            }
        }
    }
}
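
Once B is filled in, the grouping can be read off recursively: B[0][n-1] gives the last multiplication, and the two halves are handled the same way. The sketch below shows one possible way to do this. The names `append_text`, `append_grouping` and `matrix_chain_grouping`, the flat table layout, and the output format ("A0", "A1", ... with parentheses) are assumptions made for illustration, not part of the lecture.

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Append text to the string in buf, which has room for cap bytes. */
static void append_text(char *buf, size_t cap, const char *text) {
    assert(strlen(buf) + strlen(text) < cap);
    strcat(buf, text);
}

/* Append the grouping of A_i ... A_j recorded in the split table B. */
static void append_grouping(const int *B, int n, int i, int j,
                            char *buf, size_t cap) {
    if (i == j) {
        char name[16];
        snprintf(name, sizeof name, "A%d", i);
        append_text(buf, cap, name);
        return;
    }
    int k = B[i * n + j];          /* last product: (A_i..A_k)(A_k+1..A_j) */
    append_text(buf, cap, "(");
    append_grouping(B, n, i, k, buf, cap);
    append_text(buf, cap, " ");
    append_grouping(B, n, k + 1, j, buf, cap);
    append_text(buf, cap, ")");
}

/* Sketch: writes a best grouping of A_0 ... A_{n-1} into buf and
   returns its cost.  A_i is s[i] x s[i+1]; tables are stored row by row. */
long matrix_chain_grouping(int n, const int *s, char *buf, size_t cap) {
    assert(n >= 1 && cap > 0);
    long *C = malloc((size_t)n * n * sizeof *C);
    int *B = malloc((size_t)n * n * sizeof *B);
    assert(C != NULL && B != NULL);

    for (int i = 0; i < n; i++) {
        C[i * n + i] = 0;
        B[i * n + i] = i;
    }
    for (int len = 1; len < n; len++) {
        for (int i = 0; i < n - len; i++) {
            int j = i + len;
            C[i * n + j] = -1;               /* -1 means "no split tried yet" */
            for (int k = i; k < j; k++) {
                long c = C[i * n + k] + C[(k + 1) * n + j]
                       + (long)s[i] * s[k + 1] * s[j + 1];
                if (C[i * n + j] < 0 || c < C[i * n + j]) {
                    C[i * n + j] = c;
                    B[i * n + j] = k;        /* remember the best split */
                }
            }
        }
    }

    buf[0] = '\0';
    append_grouping(B, n, 0, n - 1, buf, cap);
    long cost = C[0 * n + (n - 1)];
    free(C);
    free(B);
    return cost;
}
```

For the dimensions 10, 100, 5, 50, this returns 7500 and leaves the string `((A0 A1) A2)` in the buffer. Because a split only replaces the recorded one when it is strictly cheaper, ties are resolved in favour of the smallest k, just as in the loop above.
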




Danny Heap 2002-12-16