# File: kmeans2.py
# Desc: an example of how to cluster objects using the K-means algorithm,
# for multi-dimensional data
# first modification: read the inputs from a file

from math import *
from random import randrange
from string import *

# get list of objects from a file
def openfile(filename=None):
    """Read a whitespace-separated table of numbers from a text file.

    Each non-blank line of the file becomes one point: the line is split on
    whitespace and every field is converted to a number.  Blank (or
    whitespace-only) lines are skipped so that they do not turn into empty,
    zero-dimensional points.

    Args:
        filename: path of the file to read.  When omitted (the default) the
            user is prompted for it on the console.

    Returns:
        A list of points, each point being a list of numbers, in file order.
        The same list is also printed to the console.

    Raises:
        IOError / OSError: if the file cannot be opened or read.
        Exception: whatever ``eval`` raises for a field that is not a valid
            expression.
    """
    # print() with a single argument behaves the same as the Python 2 print
    # statement, so these calls work on both Python 2 and Python 3.
    print("**** Clustering program: ***** ")
    if filename is None:
        try:
            read_line = raw_input  # Python 2
        except NameError:
            read_line = input  # Python 3: input() already returns a string
        filename = read_line("Get the data from file : ")

    data = []
    # The with-block closes the file even if reading or parsing fails.
    with open(filename, "r") as source:
        for line in source:
            fields = line.split()
            if not fields:
                # Skip blank lines instead of storing an empty point.
                continue
            # NOTE(security): eval() executes arbitrary expressions found in
            # the file.  It is kept so existing data files keep parsing the
            # same way; only feed this function trusted files, or switch to
            # int()/float() if the data is known to be plain numbers.
            data.append([eval(field) for field in fields])

    print("Data converted from file: ")
    print(data)

    return data
    
# run the k-means clustering algorithm on a list of points
def cluster(data, n_cluster=None, distance=None):
    """Group the points of ``data`` into clusters with the K-means algorithm.

    Distinct data points are picked at random as the initial centroids.
    Every point is then assigned to its nearest centroid and each centroid is
    moved to the mean of its points; this repeats until a full pass leaves
    every centroid unchanged.  Progress and the final grouping are printed.

    Args:
        data: list of points, each a sequence of numeric coordinates.
        n_cluster: number of clusters to build.  When omitted (the default)
            the user is prompted for it on the console.
        distance: callable taking two points and returning their distance.
            Defaults to the module-level ``dis`` function.

    Returns:
        A tuple ``(centroids, data_group)``: ``centroids[i]`` is the final
        centroid of cluster ``i`` and ``data_group[j]`` is the index of the
        cluster that ``data[j]`` belongs to.

    Raises:
        ValueError: if ``n_cluster`` is smaller than 1 or larger than the
            number of points (no set of distinct seeds exists then), or if
            the typed cluster count is not an integer.
    """
    n_point = len(data)
    if n_cluster is None:
        n_cluster = _ask_cluster_count()
    _say("n_cluster is = ", n_cluster)
    if n_cluster < 1 or n_cluster > n_point:
        raise ValueError(
            "n_cluster must be between 1 and the number of points (%d), got %r"
            % (n_point, n_cluster))
    if distance is None:
        distance = dis

    # Seed the centroids with distinct, randomly chosen data points.
    print("Making random centroids: ")
    centroids = []
    for index in _pick_seed_indices(n_point, n_cluster):
        centroids.append(data[index])
        _say("Random centroid is : ", data[index])

    # First pass: group around the seeds, then move every centroid.
    data_group = _assign_points(data, centroids, distance)
    centroids = _recompute_centroids(data, data_group, centroids, True)

    # Keep reassigning until a whole pass moves none of the centroids.
    print("while loop here: ")
    settled = False
    while not settled:
        print("   inside while: ")
        data_group = _assign_points(data, centroids, distance)
        refreshed = _recompute_centroids(data, data_group, centroids, False)
        settled = refreshed == centroids
        centroids = refreshed

    # print out the final results
    print("*********Final clustering**********")
    for i, center in enumerate(centroids):
        _say("Group of centroid[", i, "]: ", "centroid= ", center)
        for point, group in zip(data, data_group):
            if group == i:
                _say("Value: ", point, "of group: ", group)

    return centroids, data_group


def _say(*parts):
    """Print ``parts`` separated by single spaces, like ``print a, b, c``."""
    # A single pre-joined argument prints identically on Python 2 and 3.
    print(" ".join(str(part) for part in parts))


def _ask_cluster_count():
    """Prompt the user for the number of clusters and return it as an int."""
    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input  # Python 3: input() already returns a string
    # int() parses the reply without evaluating it as code.
    return int(read_line("How many clusters to group? "))


def _pick_seed_indices(n_point, n_cluster):
    """Return ``n_cluster`` distinct random indices in ``range(n_point)``."""
    # The caller guarantees n_cluster <= n_point, so this loop terminates.
    chosen = []
    while len(chosen) < n_cluster:
        candidate = randrange(0, n_point)
        if candidate not in chosen:
            chosen.append(candidate)
    return chosen


def _assign_points(data, centroids, distance):
    """Return, for every point, the index of its nearest centroid."""
    data_group = []
    for i, point in enumerate(data):
        nearest = 0
        min_dis = distance(point, centroids[0])
        for j in range(1, len(centroids)):
            dist = distance(point, centroids[j])
            # "<=": on a tie the later centroid wins.
            if dist <= min_dis:
                min_dis = dist
                nearest = j
        data_group.append(nearest)
        _say("data_group[", i, "]: ", nearest)
    return data_group


def _recompute_centroids(data, data_group, centroids, report_unchanged):
    """Return new centroids: the coordinate-wise mean of each cluster.

    A cluster without any point keeps its current centroid.  The new value of
    a centroid is printed when it changed, or always when
    ``report_unchanged`` is true.
    """
    updated = []
    for i, current in enumerate(centroids):
        _say("Group of centroid[", i, "]: ", "centroid= ", current)
        # Collect the members in a fresh list so that points of another
        # cluster can never leak into this cluster's mean.
        members = []
        for point, group in zip(data, data_group):
            if group == i:
                members.append(point)
                _say("value: ", point, " of group: ", group)
        if members:
            # float() avoids truncating integer division on Python 2.
            size = float(len(members))
            moved = [sum(axis) / size for axis in zip(*members)]
        else:
            moved = current
        if report_unchanged or moved != current:
            _say("New centroid value : ", moved)
        updated.append(moved)
    return updated
    
# calculate the distance between 2 points in multi-dimensional space
# given 2 points: each point = 1D list of coordinates
def dis(a, b):
    """Return the Euclidean distance between the points ``a`` and ``b``.

    Both points are indexable sequences of numbers.  The number of
    dimensions is taken from ``a``, so ``b`` must have at least as many
    coordinates as ``a``; extra coordinates of ``b`` are ignored.
    """
    # Square of the gap along every axis of ``a``; Euclidean metric only,
    # a Minkowski metric would use a different exponent here.
    squared_gaps = [(abs(a[axis] - b[axis])) ** 2 for axis in range(len(a))]
    return sqrt(sum(squared_gaps))

# find the centroid of a group of points
# given a list of all points: 2D list
# currently unused: cluster() does its own centroid computation
def centroid(lst):
    """Return the centroid (coordinate-wise mean) of a group of points.

    Args:
        lst: non-empty list of points, each an indexable sequence of
            numbers.  The number of coordinates is taken from the first
            point.

    Returns:
        A list of floats holding the mean of every coordinate.

    Raises:
        ValueError: if ``lst`` is empty (the mean of no points is undefined).
    """
    if not lst:
        raise ValueError("cannot compute the centroid of an empty list of points")

    # Divide by a float so integer coordinates are not truncated by
    # Python 2 integer division.
    count = float(len(lst))
    return [sum(point[axis] for point in lst) / count
            for axis in range(len(lst[0]))]
    
def main():
    """Run the program: load the points from a file, then cluster them."""
    points = openfile()
    # cluster() reports its results by printing; whatever it returns is
    # not used here.
    cluster(points)
    
# Start the interactive program.  The call is not guarded by a
# __name__ == "__main__" check, so it also runs when the module is imported.
main()
    
        
    
        
