无监督学习

发布于:2024-05-12 ⋅ 阅读:(153) ⋅ 点赞:(0)

【实验】代码实现

In [4]:

from numpy import *

import matplotlib.pyplot as plt

"""

函数说明:加载数据集

parameters:

    fileName -文件名

return:

    dataMat -数据列表

"""

def loadDataSet(fileName):      

    dataMat = []                

    fr = open(fileName)

    for line in fr.readlines():

        curLine = line.strip().split('\t')

        fltLine = list(map(float,curLine))  #将数据转换为float型数据

        dataMat.append(fltLine)

    return dataMat

"""

函数说明:计算向量欧氏距离

parameters:

    vecA -向量A

    vecB -向量B

return:

    欧氏距离

"""

def distEclud(vecA, vecB):

    return sqrt(sum(power(vecA - vecB, 2)))  #此处也可以使用其他距离计算公式

"""

函数说明:为给定数据集构建一个包含k个随机质心的集合

parameters:

    dataSet -数据集

    k -质心个数

return:

    centroids -质心列表

"""

def randCent(dataSet, k):

    n = shape(dataSet)[1]

    centroids = mat(zeros((k,n)))                   #创建存储质心的矩阵,初始化为0

    for j in range(n):                              #随机质心必须再整个数据集的边界之内

        minJ = min(dataSet[:,j])

        rangeJ = float(max(dataSet[:,j]) - minJ)    #通过找到数据集每一维的最小和最大值

        centroids[:,j] = mat(minJ + rangeJ * random.rand(k,1)) #生成0到1之间的随机数,确保质心落在边界之内

    return centroids

"""

函数说明:K-均值算法

parameters:

    dataSet -数据集

    k -簇个数

    distMeas -距离计算函数

    createCent -创建初始质心函数

return:

    centroids -质心列表

    clusterAssment -簇分配结果矩阵

"""

def kMeans(dataSet, k, distMeas=distEclud, createCent=randCent):

    m = shape(dataSet)[0]                                #确定数据集中数据点的总数

    clusterAssment = mat(zeros((m,2)))                   #创建矩阵来存储每个点的簇分配结果

    #第一列记录簇索引值,第二列存储误差

    centroids = createCent(dataSet, k)                   #创建初始质心

    clusterChanged = True                                #标志变量,若为True,则继续迭代

    while clusterChanged:

        clusterChanged = False 

        for i in range(m):                               #遍历所有数据找到距离每个点最近的质心

            minDist = inf; minIndex = -1    

            for j in range(k):                           #遍历所有质心

                distJI = distMeas(centroids[j,:],dataSet[i,:])              #计算质心与数据点之间的距离

                if distJI < minDist:    

                    minDist = distJI; minIndex = j

            if clusterAssment[i,0] != minIndex: clusterChanged = True

            clusterAssment[i,:] = minIndex,minDist**2    #将数据点分配到距其最近的簇,并保存距离平方和

        print(centroids)    

        for cent in range(k):                            #对每一个簇

            ptsInClust = dataSet[nonzero(clusterAssment[:,0].A==cent)[0]]   #得到该簇中所有点的值

            centroids[cent,:] = mean(ptsInClust, axis=0) #计算所有点的均值并更新为质心

    return centroids, clusterAssment

"""

函数说明:绘图

parameters:

    centList -质心列表

    myNewAssments -簇列表

    dataMat -数据集

    k -簇个数

return:

    null

"""

def drawDataSet(dataMat,centList,myNewAssments,k):

    fig = plt.figure()      

    rect=[0.1,0.1,0.8,0.8]                                             #绘制矩形

    scatterMarkers=['s', 'o', '^', '8', 'p', 'd', 'v', 'h', '>', '<']  #构建标记形状的列表用于绘制散点图

    ax1=fig.add_axes(rect, label='ax1', frameon=False)

    for i in range(k):                                                 #遍历每个簇

        ptsInCurrCluster = dataMat[nonzero(myNewAssments[:,0].A==i)[0],:]

        markerStyle = scatterMarkers[i % len(scatterMarkers)]          #使用索引来选择标记形状

        ax1.scatter(ptsInCurrCluster[:,0].flatten().A[0], ptsInCurrCluster[:,1].flatten().A[0], marker=markerStyle, s=90)

    ax1.scatter(centList[:,0].flatten().A[0], centList[:,1].flatten().A[0], marker='+', s=300)    #使用"+"来标记质心

    plt.show()

    

if __name__ =='__main__':

    dataMat = mat(loadDataSet('kmeans_algo/testSet.txt'))

    centList,myNewAssments = kMeans(dataMat,4)

    print(centList)

    drawDataSet(dataMat,centList,myNewAssments,4)

[[ 3.8861344   4.08994234]

 [-1.72587401  2.38635025]

 [-4.61995805 -1.81639953]

 [ 2.09463176 -2.82863577]]

[[ 2.71358074  3.11839563]

 [-2.29801424  2.79388557]

 [-3.53973889 -2.89384326]

 [ 2.65077367 -2.79019029]]

[[ 2.6265299   3.10868015]

 [-2.46154315  2.78737555]

 [-3.53973889 -2.89384326]

 [ 2.65077367 -2.79019029]]

[[ 2.6265299   3.10868015]

 [-2.46154315  2.78737555]

 [-3.53973889 -2.89384326]

 [ 2.65077367 -2.79019029]]


网站公告

今日签到

点亮在社区的每一天
去签到