K-means 的算法步骤为:

  1. 选择初始化的 \(k\)个样本作为初始聚类中心: \(a=a_1,a_2, a_3,...,a_k\)
  2. 针对数据集中每个样本\(x_i\),计算它到 \(k\) 个聚类中心的距离并将其分到距离最小的聚类中心所对应的类中;
  3. 针对每个聚类类别\(a_j\),重新计算它的聚类中心,\(a_j=\frac{1}{|c_i|} \sum_{x\in {c_i}}x\), 即属于该类的所有样本的质心
  4. 重复上面 2、3 两步操作,直到达到某个中止条件(迭代次数、最小误差变化等)。

C++代码实现

struct cluster {
    vector<double> center;
    vector<int> samples;
};

double cal_distance(vector<double> p1, vector<double> p2) {
    double dis = 0;
    for (size_t i = 0; i < p1.size(); i++)
    {
        dis += pow((p1[i] - p2[i]), 2);
    }
    return pow(dis, 0.5);
}
vector<cluster> kmeans(vector<vector<double>>& trainX, int k, int max_epoch) {
    int rows = trainX.size();
    vector<cluster> clusters(k);
    for (size_t i = 0; i < k; i++)
    {
        int c = rand() % rows;
        clusters[i].center = trainX[c];
    }
    for (size_t i = 0; i < max_epoch; i++)
    {
        for (size_t j = 0; j < k; j++)
        {
            clusters[i].samples.clear();
        }
        for (size_t j = 0; j < rows; j++)
        {
            int c = 0;
            double min_dis = cal_distance(trainX[j], clusters[c].center);
            for (size_t m = 1; m < k; m++)
            {
                double dis = cal_distance(trainX[j], clusters[m].center);
                c = dis < min_dis ? m : c;
                min_dis = min(min_dis, dis);
            }
            clusters[c].samples.push_back(j);
        }
        for (size_t m = 0; m < k; m++)
        {
            vector<double> val(trainX[0].size(), 0);
            for (size_t n = 0; n < clusters[m].samples.size(); n++)
            {
                int sample_index = clusters[m].samples[n];
                for (int p = 0; p < trainX[0].size(); p++) {
                    val[p] += trainX[sample_index][p];
                    if (n == clusters[m].samples.size() - 1)
                        clusters[m].center[p] = val[p] / clusters[m].samples.size();
                }
            }
        }
    }
    return clusters;
}

python代码实现

import math
import numpy as np
import random
import matplotlib.pyplot as plt
  
def distance(point1, point2):  # 计算距离(欧几里得距离)

    return np.sqrt(np.sum((point1 - point2) ** 2))

def k_means(data, k, max_iter=100, early_stop=False):
    centers = {}  # 初始聚类中心
    # 初始化,随机选k个样本作为初始聚类中心。 random.sample(): 随机不重复抽取k个值
    n_data = data.shape[0]  # 样本个数
    for idx, i in enumerate(random.sample(range(n_data), k)):
        # idx取值范围[0, k-1],代表第几个聚类中心;  data[i]为随机选取的样本作为聚类中心
        centers[idx] = data[i]
    clusters = {}  # 聚类结果,聚类中心的索引idx -> [样本集合]
    
    # 开始迭代
    for i in range(max_iter):  # 迭代次数
        print("开始第{}次迭代".format(i + 1))
        for j in range(k):  # 初始化为空列表
            clusters[j] = []
        for sample in data:  # 遍历每个样本
            distances = []  # 计算该样本到每个聚类中心的距离 (只会有k个元素)
            for c in centers:  # 遍历每个聚类中心
                # 添加该样本点到聚类中心的距离
                distances.append(distance(sample, centers[c]))
            idx = np.argmin(distances)  # 最小距离的索引
            clusters[idx].append(sample)  # 将该样本添加到第idx个聚类中心
        pre_centers = centers.copy()  # 记录之前的聚类中心点
        for c in clusters.keys():
            # 重新计算中心点(计算该聚类中心的所有样本的均值)
            centers[c] = np.mean(clusters[c], axis=0)
        if early_stop:
            is_convergent = True
            for c in centers:
                if distance(pre_centers[c], centers[c]) > 1e-8:  # 中心点是否变化
                    is_convergent = False
                    break
            if is_convergent == True:
                # 如果新旧聚类中心不变,则迭代停止
                break
    return centers, clusters
  
def predict(p_data, centers):  # 预测新样本点所在的类
    # 计算p_data 到每个聚类中心的距离,然后返回距离最小所在的聚类。
    distances = [distance(p_data, centers[c]) for c in centers]
    return np.argmin(distances)


if __name__ == "__main__":
    x = np.random.randint(0, 10, (200, 2))
    centers, cluster = k_means(x, 3)
    for key, val in centers.items():
        plt.scatter(val[0], val[1], marker="*", s=300)
    colors = ["r", "b", "g"]
    for c in cluster:
        for point in cluster[c]:
            plt.scatter(point[0], point[1], c = colors[c])
    print("done")