程序猿成长之路之数据挖掘篇——聚类算法介绍

发布于:2025-06-25 ⋅ 阅读:(15) ⋅ 点赞:(0)

作为无监督学习算法的基础,学好聚类算法很关键,我之前介绍过kmeans聚类算法,现在系统的介绍一下聚类算法

1. 什么是分类

日常生活中我们会经常见到分类的情况,如家里大扫除时给物品归类,超市货架上商品分类等。分类就是先打标签后归类的行为。

2. 什么是聚类

聚类,顾名思义,是聚集不同类别的方式。和一般的分类不同,分类需要手动打标签,也就是所谓的有监督学习,聚类则无需打标签,会自动根据标签(样本属性)对样本进行区分,所以聚类属于无监督学习。举个例子,线下超市货架上的商品按类别摆放就属于分类而不属于聚类,原因是放置一个新的商品我们需要手动打上标签后摆放到货架;相反的,在线上购物平台,一个用户进行了点击浏览操作后系统自动将其划分为某一类别用户,这个操作就是聚类,因为无需对用户进行打标签就将其归类。聚类的流程如下图所示:在这里插入图片描述

3. 聚类有什么用

聚类在数据挖掘与分析中扮演着重要的角色,通过聚类分析我们可以了解到某一类簇(类别)的共有特性,也可以根据已有或者新的样本创建新的特征标签。此外,聚类在人工智能中也应用广泛,如:聚类可以根据不同用户的市场行为,将客户分成不同类型的群体,方便进行市场分析和后续的精准营销;聚类可以进行文本分析,将类似的语句进行划分,用于实现话术分类、话题发现等任务。

4. 聚类有哪几种

聚类根据不同的方式划分为以下三种:(先主要介绍前两种)

  1. 原型聚类

原型聚类是最简单而且最常用的聚类,通俗的说,原型聚类由一组初始的样本组成(也叫做初始簇),之后通过一系列的迭代划分,形成不同的簇(样本集合)。根据不同的原型表示、不同的迭代方法,可以产生不同的聚类方法。常见的原型聚类有K-means聚类。
优势:原型聚类算法原理简单,容易实现,适用于大规模数据集‌。
劣势:

  1. 需要预先指定簇的数量‌:K-means算法需要事先指定簇的数量K,这在实际应用中可能难以确定‌。
  2. ‌对初始值敏感‌:原型聚类的结果对初始质心的选择非常敏感,不同的初始值可能导致不同的聚类结果。
  3. 容易陷入局部最优‌:由于采用迭代优化方法,原型聚类可能陷入局部最优解,而不是全局最优解。
  4. 对噪声和异常点敏感‌:原型聚类对噪声和异常点较为敏感,可能会影响聚类效果。
  1. 密度聚类

密度聚类是根据数据点密度分布的无监督学习方法,它通过定义密度相连区域形成簇,能识别任意形状的簇并有效处理噪声,常见的密度聚类是DBSCAN算法。
密度聚类中常见的术语:

  1. 邻域半径(ε)‌:划定密度计算的范围。也就是规定距离不超过这个值的为邻域对象。
  2. 最小点数:核心对象邻域半径内所包含对象的最小个数,只有邻域半径内的对象数量超过了这个值才能认定当前对象为核心对象,也就是判定核心对象的阈值。
  3. 核心对象:只有邻域半径内的对象数量超过了这个值才能认定当前对象为核心对象,
  4. 密度直达:对象a的邻域中包含了对象b,那就说对象a和对象b是密度直达。
  5. 密度可达:如果对象a的邻域中不包含对象b,a和b都和某一对象c是密度直达的,那就说对象a和对象b是密度可达。
  6. 簇:密度可达的对象所组成的集合。
    优势:自动识别噪声、支持任意形状簇、无需指定簇数。‌‌
    劣势:对参数敏感,高维或大规模数据效率较低。‌‌
  1. 层次聚类

层次聚类试图在不同层次对数据集进行划分,从而形成树形的聚类结构。数据集划分可采用“自底向上”的聚合策略,也可采用“自顶向下”的分拆策略。

5. K-means聚类算法实现原理

算法步骤如下:

  1. 随机选择k个数据点作为初始的簇中心。(注意,正因为随机选择,所以可能导致不同的初始簇会有不同的聚类结果)
  2. 对于每个数据点,计算其到每个簇中心的距离,将其划分到距离最近的簇中。
  3. 对于每个簇,重新计算其簇中心,即将簇内所有数据点的坐标取平均值。
  4. 重复步骤2和步骤3,直到簇中心不再发生变化或达到预设的迭代次数。
  5. 最终得到k个簇,每个簇内的数据点距离尽可能接近,不同簇间的数据点距离尽可能远。

Java版本代码如下:

package kmeans;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;

/**
 * kmeans聚类工具类
 * @author zygswo
 *
 */
public class KmeansUtils<T> {
	private int initKNodeNb; //kmeans初始几何中心数量
	
	private List<T> trainData; //kmeans训练数据
	
	private DistanceType distanceType;
	
	/**
	 * kmeans构造方法(默认为欧式距离公式)
	 * @param initKNodeNb kmeans初始几何中心数量
	 * @param trainData	训练数据
	 */
	public KmeansUtils(List<T> trainData, int initKNodeNb) {
		this.initKNodeNb = initKNodeNb;
		this.trainData = trainData;
		this.distanceType = DistanceType.EUCLID;
	}
	
	/**
	 * kmeans构造方法(默认为欧式距离公式)
	 * @param initKNodeNb kmeans初始几何中心数量
	 * @param trainData	训练数据
	 * @param distanceType 距离公式
	 */
	public KmeansUtils(List<T> trainData, int initKNodeNb, DistanceType distanceType) {
		this.initKNodeNb = initKNodeNb;
		this.trainData = trainData;
		this.distanceType = distanceType;
	}
	
	/**
	 * kmeans模型训练
	 */
	public void fit(){
		//计算距离
		List<Map<String,Double>> initKNodeDistanceVal = Collections.synchronizedList(
				new ArrayList<>()
		);
		//初始化几何列表
		List<List<T>> resList = Collections.synchronizedList(
				new ArrayList<>()
		);
		if (this.trainData == null || this.trainData.isEmpty()) {
			throw new IllegalArgumentException("训练集为空");
		}
		if (this.initKNodeNb <=0) {
			throw new IllegalArgumentException("几何中心数量小于0");
		}
		if (this.initKNodeNb > this.trainData.size()) {
			throw new IllegalArgumentException("几何中心数量超过数组数量");
		}
		if (this.distanceType == null) {
			throw new IllegalArgumentException("距离类型为空");
		}
		//1.获取前initKNodeNb个数据放入initKNodeList列表中
		//初始化的几何中心,需要选择差异较大的
//		this.trainData.sort((T item1,T item2)-> {
//			return (int)(calcDiff(item1,this.trainData.get(0)) - calcDiff(item2,this.trainData.get(0)));
//		});
		this.trainData = sort(this.trainData);
		int step = this.trainData.size() / initKNodeNb;
		//选择从小到大的initKNodeNb个元素作为初始几何
		for (int i = 0; i < this.trainData.size() && resList.size() < initKNodeNb; i+=step) {
			List<T> temp = Collections.synchronizedList(
					new ArrayList<>()
			);
			temp.add(this.trainData.get(i));
			resList.add(temp); //多个几何列表设置初始结点
		}
//		System.out.println(this.trainData);
//		System.out.println(resList.toString());
		//2.计算所有变量到不同的几何中心距离,如果稳定了(几何中心固定了),就退出循环
		while(true) {
			boolean balanced = true; //是否已经平衡
			for (T item: this.trainData) {
				double distance, minDistance = Double.MAX_VALUE; //求最小距离
				int preIndex = 0,afterIndex = 0; //preIndex-原位置
				initKNodeDistanceVal.clear();
				//计算几何中心
				for (int i = 0; i < initKNodeNb; i++) {
					if (resList.get(i).size() > 0)
						initKNodeDistanceVal.add(calc(resList.get(i))); //计算初始结点距离
				}
				//计算原来的位置
				for (int i = 0; i < initKNodeNb; i++) {
					if(resList.get(i).contains(item)) {
						preIndex = i;
						break;
					}
				}
//				System.out.println("item = " + item.toString());
				//计算不同变量到不同的几何中心距离
				for (int i = 0; i < initKNodeNb; i++) {
					if (resList.get(i).size() > 0 && i < initKNodeDistanceVal.size()) {
						distance = calcDistance(item, initKNodeDistanceVal.get(i));
//						System.out.println("distance = " + distance);
//						System.out.println("minDistance = " + minDistance);
						if (distance < minDistance) {
							minDistance = distance;
							afterIndex = i;
						}
					}					
				}
//				System.out.println("preIndex = " + preIndex);
//				System.out.println("afterIndex = " + afterIndex);
				//位置替换,如果替换就还没结束
				if (preIndex != afterIndex) {
					resList.get(preIndex).remove(item);
					resList.get(afterIndex).add(item);
					balanced = false;
				} 
				//如果preIndex == afterIndex == 0
				if (preIndex == afterIndex) {
					//如果新增就还没结束
					if (!resList.get(preIndex).contains(item)) {
						resList.get(preIndex).add(item);
						balanced = false;
					}
				}
			}
			if (balanced){
				break;
			}
		}
		//打印结果
		for (List<T> list : resList) {
			System.out.println(list.toString());
		}
	}
	
	/**
	 * 排序
	 * @param trainData
	 */
	private List<T> sort(List<T> list) {
		List<T> res = new ArrayList<>();
		Map<Double,List<T>> map = new ConcurrentHashMap<>();
		//计算距离
		for(T item:list) {
			double distance = calcDiff(item,list.get(0));
			if (!map.containsKey(distance)) {
				List<T> arr = new ArrayList<>();
				arr.add(item);
				map.put(distance, arr);
			} else {
				List<T> arr = map.get(distance);
				arr.add(item);
				map.put(distance, arr);
			}
		}
		//按照距离从小到大排列
		SortedMap<Double,List<T>> sortedMap = new TreeMap<>(map);
//		System.out.println(sortedMap.toString());
		for (Double key: sortedMap.keySet()) {
			res.addAll(sortedMap.get(key));
		}
		return res;
	}

	/**
	 * 计算距离
	 * @param item1 item1
	 * @param item2 item2
	 * @return
	 */
	private double calcDiff(T item1, T item2) {
		List<T> list = Collections.synchronizedList(new ArrayList<>());
		list.add(item2);
		Map<String, Double> map = calc(list);
		double dist = calcDistance(item1, map);
//		System.out.println(item1.toString() + "=>" +item2.toString()+"dist = " + dist);
		return dist;
	}
/**
	 * 计算距离
	 * @param item 当前对象
	 * @param map 几何中心
	 * @return
	 */
	private double calcDistance(T item, Map<String, Double> map) {
		double distance = 0.0;//距离
		int level = 0;//根据距离公式判断距离计算等级
		Class<?> cls = item.getClass();
		Field[] fs = cls.getDeclaredFields();
		for (Field f : fs) {
			double dist1 = 0.0, dist2 = 0.0;
			f.setAccessible(true);
			//获取需要计算的参数
			Elem el = f.getAnnotation(Elem.class);
			if (el == null) {
				continue;
			}
			try {
				switch(el.type()) {
				case BASIC: break;
				case XUSHU:
					//获取数组
					String[] arr = el.list();
					if (arr == null) {
						throw new IllegalArgumentException("序数属性需配置属性集合数组");
					}
					//数组排序
					Arrays.sort(arr);
					List<String> list = Arrays.asList(arr);
					//计算差距步长
					Double diffStep = 1 / (list.size() * 1.0);
					//获取当前对象序数属性的值
					Object value = f.get(item);
					dist1 = list.indexOf(value) * diffStep;
					break;
				case NUMBER: 
					//获取当前对象数值属性的值
					Object value1 = f.get(item); 
					//数据转换
					Double intVal = Double.parseDouble(String.valueOf(value1));
					dist1 = intVal;
					break;
				case ERYUAN:
					//获取数组
					String[] arr1 = el.list();
					if (arr1 == null) {
						arr1 = new String[]{"0","1"};
					} else {
						//数组排序
						Arrays.sort(arr1);
					}
					//转列表
					List<String> list1 = Arrays.asList(arr1);
					//计算差距步长
					Double diffStep1 = 1 / (list1.size() * 1.0);
					Object value2 = f.get(item);
					int ind = list1.indexOf(value2);
					dist1 = ind * diffStep1;
					break;
				}
				//获取当前几何中心属性的值
				dist2 = map.get(f.getName());
				//计算距离
				switch(distanceType) {
					case EUCLID: level = 2; break;
					case MANHATTAN: level = 1;break;
					case QIEBIXUEFU: level = 100;break;
				}
				distance += Math.pow(Math.abs(dist1 - dist2),level);
			} catch(Exception ex) {
				throw new RuntimeException(ex.getMessage());
			}
			distance = Math.pow(distance, 1/(level * 1.0));
		}	
		return distance;
	}

	/**
	 * 计算几何中心坐标
	 * @param kNodeList
	 * @return 几何中心坐标map
	 */
	private Map<String, Double> calc(List<T> kNodeList) {
		if (kNodeList == null || kNodeList.size() <= 0) {
			throw new IllegalArgumentException("几何中心列表数组为空");
		}
		//反射获取参数,形成数值数组
		Map<String, Double> result = new ConcurrentHashMap<>();
		T item = kNodeList.get(0);
		Class<?> cls = item.getClass();
		Field[] fs = cls.getDeclaredFields();
		for (Field f: fs) {
			//获取需要计算的参数
			Elem el = f.getAnnotation(Elem.class);
			if (el == null) {
				continue;
			}
			//将数据转换成数值
			Double dist = 0.0;
			switch(el.type()) {
				case BASIC: break;
				case XUSHU: 
					//获取数组
					String[] arr = el.list();
					if (arr == null) {
						throw new IllegalArgumentException("序数属性需配置属性集合数组");
					}
					//数组排序
					Arrays.sort(arr);
					//转列表
					List<String> list = Arrays.asList(arr);
					//计算差距步长
					Double diffStep = 1 / (list.size() * 1.0);
					for (T kNode : kNodeList) {
						try {
							//获取当前对象序数属性的值
							Object value = f.get(kNode);
							int ind = list.indexOf(value);
							//求和
							dist += ind * diffStep;
						} catch (IllegalArgumentException e) {
							// TODO Auto-generated catch block
							e.printStackTrace();
						} catch (IllegalAccessException e) {
							// TODO Auto-generated catch block
							e.printStackTrace();
						}
					}
					break;
				case NUMBER: 
					for (T kNode : kNodeList) {
						try {
							//获取当前对象数值属性的值
							Object value = f.get(kNode);
							//数据转换
							Double intVal = Double.parseDouble(String.valueOf(value));
							dist += intVal;
						} catch (IllegalArgumentException e) {
							// TODO Auto-generated catch block
							e.printStackTrace();
						} catch (IllegalAccessException e) {
							// TODO Auto-generated catch block
							e.printStackTrace();
						}
					}
					break;
				case ERYUAN:
					//获取数组
					String[] arr1 = el.list();
					if (arr1 == null) {
						arr1 = new String[]{"0","1"};
					} else {
						//数组排序
						Arrays.sort(arr1);
					}
					//转列表
					List<String> list1 = Arrays.asList(arr1);
					//计算差距步长
					Double diffStep1 = 1 / (list1.size() * 1.0);
					for (T kNode : kNodeList) {
						try {
							//获取当前对象二元属性的值
							Object value = f.get(kNode);
							int ind = list1.indexOf(value);
							//求和
							dist += ind * diffStep1;
						} catch (IllegalArgumentException e) {
							// TODO Auto-generated catch block
							e.printStackTrace();
						} catch (IllegalAccessException e) {
							// TODO Auto-generated catch block
							e.printStackTrace();
						}
					}
					break;
			}
			dist /= (kNodeList.size() * 1.0); //求平均值
			result.put(f.getName(), dist);
		}
		return result;
	}
	
	public static void main(String[] args) {
		List<Student> trainData = new ArrayList<>();
		trainData.add(new Student("zyl",28,"男"));
		trainData.add(new Student("sjl",28,"女"));
		trainData.add(new Student("xxx",27,"男"));
		trainData.add(new Student("stc",30,"男"));
		trainData.add(new Student("wxq",30,"女"));
		trainData.add(new Student("zzz",27,"男"));
		trainData.add(new Student("sss",27,"女"));
		trainData.add(new Student("mmm",20,"男"));
		trainData.add(new Student("qqq",20,"女"));
		trainData.add(new Student("666",30,"男"));
		trainData.add(new Student("nnn",20,"男"));
		trainData.add(new Student("lll",25,"男"));
		trainData.add(new Student("ppp",25,"女"));
		trainData.add(new Student("aaa",19,"男"));
		trainData.add(new Student("ccc",19,"女"));
		KmeansUtils<Student> utils = new KmeansUtils<>(trainData, 3);
		utils.fit();
	}
}

student类

package kmeans;

import java.util.List;

public class Student{
	@Override
	public String toString() {
		return "Student [name=" + name + ", age=" + age + ", gender=" + gender + ", myHobby=" + myHobby
				+ ", myDream=" + myDream + "]";
	}
	public List<MyHobby> getMyHobby() {
		return myHobby;
	}
	public Student setMyHobby(List<MyHobby> myHobby) {
		this.myHobby = myHobby;
		return this;
	}
	public String getName() {
		return name;
	}
	public Student setName(String name) {
		this.name = name;
		return this;
	}
	public int getAge() {
		return age;
	}
	public Student setAge(int age) {
		this.age = age;
		return this;
	}
	public String getGender() {
		return gender;
	}
	public Student setGender(String gender) {
		this.gender = gender;
		return this;
	}
	String name;
	
	@Elem(type = ElemType.NUMBER)
	int age;
	
	@Elem(type = ElemType.XUSHU,list={"男","女"})
	String gender;
	
	@Elem()
	List<MyHobby> myHobby;
	
	@Elem()
	List<String> myDream;
	
	public Student(String name, int age, String gender) {
		super();
		this.name = name;
		this.age = age;
		this.gender = gender;
	}
	
	public Student(String name, int age, String gender,List<MyHobby> myHobby) {
		this(name,age,gender);
		this.myHobby = myHobby;
	}
	
	public Student(String name, int age, String gender,List<MyHobby> myHobby, List<String> myDreams) {
		this(name,age,gender);
		this.myHobby = myHobby;
		this.myDream = myDreams;
	}
}

distanceType类

public enum DistanceType {
	EUCLID("欧几里得距离"),
	MANHATTAN("曼哈顿距离"),
	QIEBIXUEFU("切比雪夫距离");
	
	private String name;
	
	private DistanceType(String name) {
		this.setName(name);
	}

	public String getName() {
		return name;
	}

	public void setName(String name) {
		this.name = name;
	}
}

elem注解


import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Elem {
	ElemType type() default ElemType.BASIC; //属性类型
	String[] list() default {}; //选择项
}

elemType枚举类

/**
 * 元素属性类型(标称属性、序数属性、数值属性、二元属性)
 * @author zygswo
 *
 */
public enum ElemType {
	BASIC("标称属性"),
	XUSHU("序数属性"),
	NUMBER("数值属性"),
	ERYUAN("二元属性");
	
	private String name;
	
	private ElemType(String name) {
		this.setName(name);
	}

	public String getName() {
		return name;
	}

	public void setName(String name) {
		this.name = name;
	}
}

6. DBSCAN聚类算法实现原理

算法步骤如下:

  1. 遍历样本集,获取核心对象集,记为Ω(邻域范围内的对象数量超过最小点数的对象),每一个核心对象作为初始簇。
  2. 随机选择一个核心对象,记为α,插入队列。
  3. 从队列中获取第一个核心对象,记为β。
  4. 获取β的邻域对象集,遍历当前邻域对象集,依次不重复的加入队列,并从样本集中剔除领域对象;如果没有领域对象就继续。
  5. 重复步骤3和4,直到队列为空。
  6. 从核心对象集Ω中剔除当前核心对象α,设置α的最终簇为原队列中的所有对象。
  7. 如果核心对象集Ω为空就结束,否则重复步骤1至6。
  8. 得到k个最终簇,返回结果。

Java版本代码如下:

package dbscancluster;

import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.LinkedBlockingQueue;

/**
 * 密度聚类算法
 * @author zygswo
 *
 */
public class DbScanUtils {
	class Item {
		@Override
		public String toString() {
			return "Item [density=" + density + ", rate=" + rate + ", sweetRate=" + sweetRate + "]";
		}
		public double getDensity() {
			return density;
		}
		public void setDensity(double density) {
			this.density = density;
		}
		public double getRate() {
			return rate;
		}
		public void setRate(double rate) {
			this.rate = rate;
		}
		public double getSweetRate() {
			return sweetRate;
		}
		public void setSweetRate(double sweetRate) {
			this.sweetRate = sweetRate;
		}
		double density;
		double rate;
		double sweetRate;
		public Item() {
			super();
		}
		public Item(double density, double rate) {
			super();
			this.density = density;
			this.rate = rate;
		}	
		public Item(double density, double rate,double sweetRate) {
			this(density,rate);
			this.sweetRate = sweetRate;
		}	
	}
	/**
	 * main
	 * @param args
	 */
	public static void main(String[] args) {
		System.out.println("----------------------------------------");
		System.out.println("------------- DbScan 密度聚类  -------------");
		System.out.println("----------------------------------------");
		DbScanUtils density = new DbScanUtils();
		Item item1 = density.new Item(6.77,0.55,0.33);
		Item item2 = density.new Item(6.57,0.45,0.12);
		Item item3 = density.new Item(6.76,0.55,0.30);
		Item item4 = density.new Item(6.58,0.45,0.14);
		Item item5 = density.new Item(4.28,0.99);
		Item item6 = density.new Item(7.28,0.48);
		Item item7 = density.new Item(6.70,0.52,0.35);
		Item item8 = density.new Item(4.32,0.96);
		List<Item> items = new ArrayList<>();
		items.add(item1);
		items.add(item2);
		items.add(item3);
		items.add(item4);
		items.add(item5);
		items.add(item6);
		items.add(item7);
		items.add(item8);
		System.out.println("---------- start ----------");
		long startTime = System.currentTimeMillis();
		List<List<Item>> result = getDensityCluster(items, 0.1, 1);
		for (List<Item> res:result) {
			System.out.println(res);
		}
		System.out.println("---------- end ----------");
		System.out.println("---------- 总耗时: " + (System.currentTimeMillis() - startTime) + "----------");
	}
	/**
	 * 获取密度聚类
	 * @param items 样本集
	 * @param distThreashold 邻域内的对象相剧最远的阈值
	 * @param sizeThreashold 核心对象所需的领域中的最少对象数量阈值
	 * @return 聚类
	 */
	public static <T> List<List<T>> getDensityCluster(List<T> items,double distThreashold,int sizeThreashold) {
		List<List<T>> result = new ArrayList<>();
		Queue<T> densityQueue = new LinkedBlockingQueue<>();
		//设置临时聚类集,初始化为样本集
		List<T> tempItemList = new ArrayList<>();
		for (T item:items) {
			tempItemList.add(item);
		}
		//获取核心对象
		List<T> coreItemList = new ArrayList<>();
		for (T item:tempItemList) {
			List<T> adjacentItemList = getAdjacent(item,tempItemList,distThreashold);
			if (adjacentItemList.size() >= sizeThreashold) {
				coreItemList.add(item);
			}
		}
		if (coreItemList.isEmpty()) {
			return result;
		}
		//判断核心对象列表是否为空
		while (!coreItemList.isEmpty()) {
			List<T> tempClusterList = new ArrayList<>(); //临时簇
			//随机抽取一个核心对象
			int i = new Random().nextInt(coreItemList.size());
			//放入队列中
			densityQueue.add(coreItemList.get(i));
			//判断队列是否为空
			while(!densityQueue.isEmpty()) {
				//获取队列中第一个对象(并从队列中删除)
				T tempCoreItem = densityQueue.poll();
				//查找当前对象的所有领域并放入队列中
				List<T> adjacentItemList = getAdjacent(tempCoreItem,tempItemList,distThreashold);
				for (T adjacentItem:adjacentItemList) {
					//查找当前对象的所有邻域并放入队列中
					if (!densityQueue.contains(adjacentItem)) {
						densityQueue.add(adjacentItem);
					}
					//邻域对象放入临时簇里
					if (!tempClusterList.contains(adjacentItem)) {
						tempClusterList.add(adjacentItem);
					}
					//从临时聚类集中删除当前对象的邻域
					if (tempItemList.contains(adjacentItem)) {
						tempItemList.remove(adjacentItem);
					}
				}
			}
			//添加簇
			if (!tempClusterList.isEmpty()) {
				result.add(tempClusterList);
			}
			//清除当前核心对象
			coreItemList.remove(i);
		}
		return result;
	}
	/**
	 * 获取邻域数组
	 * @param item 目标对象
	 * @param tempItemList 所有对象列表
	 * @param distThreashold 邻域内的对象相剧最远的阈值
	 * @return
	 */
	private static <T> List<T> getAdjacent(T item, List<T> tempItemList,double distThreashold) {
		List<T> result = new ArrayList<>();
		for (T tempItem:tempItemList) {
			//计算距离
			double dist = DiffUtils.calculDiff(item, tempItem);
			if (dist <= distThreashold) {
				result.add(tempItem);
			}
		}
		return result;
	}
}

距离计算类

package dbscancluster;

import java.lang.reflect.Field;

public class DiffUtils {
	
	/**
	 * 通过反射计算欧几里得距离
	 * @param obj1 对象1
	 * @param obj2 对象2
	 * @return 欧几里得距离
	 */
	public static <T> double calculDiff(T obj1,T obj2) {
		if (obj1 == null || obj2 == null) {
			throw new IllegalArgumentException("参数为空");
		}
		Class<?> cls = obj1.getClass();
		double total = 0;
		while(!cls.getSimpleName().equalsIgnoreCase("Object")) {
			Field[] field = cls.getDeclaredFields();
			for (Field f:field) {
				try {
					Object fVal = f.get(obj1);
					if (fVal instanceof Double) {
						double obj1Val = f.getDouble(obj1);
						double obj2Val = f.getDouble(obj2);
						total += EuclidDistance(obj1Val, obj2Val, 1.0);
					} else if (fVal instanceof Float) {
						Float obj1Val = f.getFloat(obj1);
						Float obj2Val = f.getFloat(obj2);
						total += EuclidDistance(obj1Val, obj2Val, 1.0);
					} else if (fVal instanceof Integer) {
						int obj1Val = f.getInt(obj1);
						int obj2Val = f.getInt(obj2);
						total += EuclidDistance(obj1Val, obj2Val, 1.0);
					} else if (fVal instanceof Short) {
						Short obj1Val = f.getShort(obj1);
						Short obj2Val = f.getShort(obj2);
						total += EuclidDistance(obj1Val, obj2Val, 1.0);
					} else if (fVal instanceof Long) {
						long obj1Val = f.getLong(obj1);
						long obj2Val = f.getLong(obj2);
						total += EuclidDistance(obj1Val, obj2Val, 1.0);
					}
				} catch (IllegalArgumentException | IllegalAccessException e) {
					// TODO Auto-generated catch block
					e.printStackTrace();
				}
			}
			cls = cls.getSuperclass();
		}
		//求平方根
		if (total >= 0) {
			total = Math.sqrt(total);
			total = Double.parseDouble(String.format("%.3f", total));
		} else {
			throw new IllegalArgumentException("参数计算异常");
		}
		return total;
	}
	
	/**
	 * 欧几里得距离公式
	 * @param x0
	 * @param x1
	 */
	private static double EuclidDistance(int x0, int x1,double weight){
		return Math.pow(Math.abs(x0-x1), 2) * weight;
	}
	
	/**
	 * 欧几里得距离公式
	 * @param x0
	 * @param x1
	 */
	private static double EuclidDistance(long x0, long x1,double weight){
		return Math.pow(Math.abs(x0-x1), 2) * weight;
	}
	
	/**
	 * 欧几里得距离公式
	 * @param x0
	 * @param x1
	 */
	private static double EuclidDistance(double x0, double x1,double weight){
		return Math.pow(Math.abs(x0-x1), 2) * weight;
	}
}

———————————— (未完待续)————————————


网站公告

今日签到

点亮在社区的每一天
去签到