'''An assortment of sorting algorithms.'''

import random

'''Classic O(n^2) sorting algorithms.'''

def bubble_sort(items):
    '''Sort items in place, in ascending order, using bubble sort.

    Each pass walks the unsorted prefix and swaps adjacent items that are
    out of order, which carries the largest item of the prefix to its end.
    The prefix therefore shrinks by one after every pass, and the sort stops
    early as soon as a pass performs no swap.  Returns None.
    '''
    for end in range(len(items) - 1, 0, -1):
        swapped = False
        for j in range(end):
            if items[j] > items[j + 1]:
                # Found an inversion; swap items[j] with items[j + 1]
                items[j], items[j + 1] = items[j + 1], items[j]
                swapped = True
        if not swapped:
            # No inversions left in items[:end + 1]; everything is sorted.
            return
                
def insertion_sort(items):
    '''Sort items in place, in ascending order, using insertion sort.

    items[:boundary] is kept sorted; the item at position boundary is then
    walked leftwards by adjacent swaps until its left neighbour is not
    larger than it.
    '''
    for boundary in range(1, len(items)):
        for pos in range(boundary, 0, -1):
            if items[pos] < items[pos - 1]:
                # Still out of order: exchange with the left neighbour.
                items[pos - 1], items[pos] = items[pos], items[pos - 1]
            else:
                # The item has reached its place in the sorted prefix.
                break
            
def selection_sort(items):
    '''Sort items in place, in ascending order, using selection sort.

    For each position i, the smallest item of items[i:] is located and then
    moved into items[i] with a single swap, so at most len(items) - 1 swaps
    are performed in total.  Returns None.
    '''
    n = len(items)
    for i in range(n - 1):
        # Find the index of the smallest thing in items[i:].
        smallest = i
        for j in range(i + 1, n):
            if items[j] < items[smallest]:
                smallest = j
        # Put it in items[i], unless it is already there.
        if smallest != i:
            items[i], items[smallest] = items[smallest], items[i]

def swap_down(heap, i, n):
    '''Sift the item at position i down within the min-heap heap[:n].

    The children of position p are taken to be at 2 * p + 1 and 2 * p + 2.
    The item is repeatedly exchanged with its smallest child until neither
    child inside heap[:n] is smaller than it.
    '''
    pos = i
    while True:
        smallest = pos
        # Look at the left child, then the right child, keeping track of the
        # position holding the smallest value seen so far.
        for child in (2 * pos + 1, 2 * pos + 2):
            if child < n and heap[smallest] > heap[child]:
                smallest = child
        if smallest == pos:
            # Heap property holds at pos; nothing more to do.
            return
        heap[pos], heap[smallest] = heap[smallest], heap[pos]
        pos = smallest
        
'''O(n log n) sorting algorithms.'''
        
def heap_sort(items):
    '''Sort items in place, in ascending order, using heap sort.

    The list is first arranged into a min-heap.  The minimum is then
    repeatedly exchanged with the last item of the shrinking heap, which
    leaves the list in descending order; a final reversal makes it
    ascending.
    '''
    size = len(items)
    # Heapify: sift every position down, starting from the back.
    for root in reversed(range(size)):
        swap_down(items, root, size)
    # Extract the minimum into the slot just past the end of the heap, then
    # repair the (one shorter) heap from the root.
    for last in range(size - 1, -1, -1):
        items[0], items[last] = items[last], items[0]
        swap_down(items, 0, last)
    items.reverse()
    
def merge(list_a, list_b, list_c):
    '''Merge sorted lists list_a and list_b into the pre-sized list_c.

    The items are written to list_c[0], list_c[1], ... in ascending order.
    When the heads of both lists compare equal, the one from list_a is
    written first.  Requires that
    len(list_a) + len(list_b) == len(list_c).
    '''
    len_a, len_b = len(list_a), len(list_b)
    i, j = 0, 0
    for k in range(len_a + len_b):
        # Take from list_a when list_b is used up, or when list_a still has
        # items and its head is no larger than the head of list_b.
        if j >= len_b or (i < len_a and list_a[i] <= list_b[j]):
            list_c[k] = list_a[i]
            i += 1
        else:
            list_c[k] = list_b[j]
            j += 1
        
def merge_sort(items):
    '''Sort items in place, in ascending order, using merge sort.

    The list is split into two halves, each half is copied and sorted
    recursively, and the sorted copies are merged back into items.
    Returns None.
    '''
    n = len(items)
    if n < 2:
        # Already sorted.
        return
    # Floor division keeps the split point an integer; plain / yields a
    # float under true division, which is not a valid slice index.
    middle = n // 2
    left = items[:middle]
    right = items[middle:]
    # Recursively sort left and right halves.
    merge_sort(left)
    merge_sort(right)
    # Now merge the two sorted halves back into items.
    merge(left, right, items)
            
            
def partition(items, a, b, pivot):
    '''Rearrange items[a:b+1] in place around pivot and return the split
    index j, such that
        1) Everything in items[a:j] is < pivot
        2) Everything in items[j:b+1] is >= pivot
    If the range is empty (b < a), nothing is moved and a is returned.'''
    boundary = a
    for scan in range(a, b + 1):
        if items[scan] < pivot:
            # items[scan] belongs in the "< pivot" portion: exchange it with
            # the first item of the other portion and grow the boundary.
            items[scan], items[boundary] = items[boundary], items[scan]
            boundary += 1
    return boundary
    
def quick_sort_rec(items, a, b):
    '''Perform quick sort, in place, on the sublist items[a:b+1].

    items[b] is used as the pivot.  After each partition only the smaller
    side is sorted by a recursive call; the larger side is handled by the
    enclosing loop.  Each recursive call therefore covers at most half of
    the current range, which keeps the recursion depth logarithmic in the
    length of the sublist even for already-sorted or all-equal input.
    Returns None.
    '''
    # A range of length at most 1 (a >= b) is already sorted.
    while a < b:
        # Partition items[a:b] using pivot items[b].
        i = partition(items, a, b - 1, items[b])
        # Move items[b] (the pivot) to position i; this is where it should
        # remain forever.
        items[b], items[i] = items[i], items[b]
        # Sort items[a:i] and items[i+1:b+1]: recurse into the smaller of
        # the two and continue the loop on the larger one.
        if i - a < b - i:
            quick_sort_rec(items, a, i - 1)
            a = i + 1
        else:
            quick_sort_rec(items, i + 1, b)
            b = i - 1
    
            
def quick_sort(items):
    '''Sort items in place using quick sort.'''
    # The recursive helper works on an inclusive index range; hand it the
    # bounds that cover the whole list.
    first, last = 0, len(items) - 1
    quick_sort_rec(items, first, last)
            
def test_sort(sort_algorithm):
    '''Test using sort_algorithm on some random permutations.'''
    items = []
    sort_algorithm(items)
    assert items == []

    items = [0]
    sort_algorithm(items)
    assert items == [0]

    for i in range(10):
        items = range(1000)
        random.shuffle(items)
        sort_algorithm(items)
        assert items == range(1000)    
    
if __name__ == '__main__':
    # Run the self-test against every sorting algorithm, in the order in
    # which they are defined.
    for algorithm in (bubble_sort, insertion_sort, selection_sort,
                      heap_sort, merge_sort, quick_sort):
        test_sort(algorithm)

