Java 实现局部敏感Hash
局部敏感哈希(Locality-Sensitive Hashing, LSH)是一种哈希技术,它使相似的数据以较高的概率被映射到相同的哈希桶中,从而能够在大数据集中快速找到相似项。
import java.util.*;
/**
 * Locality-Sensitive Hashing (LSH) index based on the banding technique.
 *
 * <p>Each item is described by a signature of {@code numHashFunctions} hash values. The
 * signature is cut into {@code numBands} consecutive bands of
 * {@code numHashFunctions / numBands} values each, and the item is filed under the values of
 * every band. A query reports an item as a candidate when at least one band of the query
 * signature is exactly equal to the same band of the item's signature.
 *
 * <p>If {@code numHashFunctions} is not a multiple of {@code numBands}, the trailing hash values
 * that do not fill a whole band are ignored.
 */
public class LSH {
    /**
     * Number of hash values every signature must contain.
     */
    private final int numHashFunctions;
    /**
     * Number of bands each signature is split into.
     */
    private final int numBands;
    /**
     * Number of consecutive hash values that make up one band.
     */
    private final int rowsPerBand;
    /**
     * One lookup table per band, mapping a band's hash values to the items filed under them.
     */
    private final List<Map<List<String>, Set<String>>> buckets;

    /**
     * Creates an empty index.
     *
     * @param numHashFunctions number of hash values per signature.
     * @param numBands number of bands the signature is split into.
     * @throws IllegalArgumentException if {@code numBands} is negative.
     */
    public LSH(int numHashFunctions, int numBands) {
        this.numHashFunctions = numHashFunctions;
        this.numBands = numBands;
        // Guard the division so that a zero-band index stays constructible (it simply never
        // stores or returns anything).
        this.rowsPerBand = numBands > 0 ? numHashFunctions / numBands : 0;
        this.buckets = new ArrayList<>(numBands);
        for (int i = 0; i < numBands; i++) {
            buckets.add(new HashMap<>());
        }
    }

    /**
     * Files an item under every band of its signature.
     *
     * @param item the item to index.
     * @param hashes the item's signature; must contain exactly {@code numHashFunctions} values.
     * @throws IllegalArgumentException if the signature length does not match the number of
     *     hash functions.
     * @throws NullPointerException if {@code hashes} is {@code null}.
     */
    public void add(String item, String[] hashes) {
        checkSignature(hashes);
        for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
            Map<List<String>, Set<String>> band = buckets.get(bandIndex);
            List<String> key = bandKey(hashes, bandIndex);
            Set<String> members = band.get(key);
            if (members == null) {
                members = new HashSet<>();
                band.put(key, members);
            }
            members.add(item);
        }
    }

    /**
     * Finds the items that share at least one complete band with the query signature.
     *
     * @param hashes the query signature; must contain exactly {@code numHashFunctions} values.
     * @return a new set holding the candidate items; empty when nothing matches.
     * @throws IllegalArgumentException if the signature length does not match the number of
     *     hash functions.
     * @throws NullPointerException if {@code hashes} is {@code null}.
     */
    public Set<String> query(String[] hashes) {
        checkSignature(hashes);
        Set<String> results = new HashSet<>();
        for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
            // Exact key lookup: only items whose whole band equals the query's band qualify.
            Set<String> members = buckets.get(bandIndex).get(bandKey(hashes, bandIndex));
            if (members != null) {
                results.addAll(members);
            }
        }
        return results;
    }

    /**
     * Returns the number of hash values every signature must contain.
     *
     * @return the number of hash functions.
     */
    public int getNumHashFunctions() {
        return numHashFunctions;
    }

    /** Rejects signatures whose length differs from the configured number of hash functions. */
    private void checkSignature(String[] hashes) {
        Objects.requireNonNull(hashes, "hashes");
        if (hashes.length != numHashFunctions) {
            throw new IllegalArgumentException("Number of hashes must match the number of hash functions");
        }
    }

    /**
     * Copies the hash values of one band into a list usable as a map key. Keeping the values as
     * separate list elements (instead of joining them into one string) means values that
     * themselves contain a delimiter character cannot make two different bands look equal.
     */
    private List<String> bandKey(String[] hashes, int bandIndex) {
        int from = bandIndex * rowsPerBand;
        return new ArrayList<>(Arrays.asList(hashes).subList(from, from + rowsPerBand));
    }
}
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
/**
 * Computes SHA-256 digests of strings and renders them as lowercase hexadecimal text.
 *
 * <p>One {@link MessageDigest} instance is created in the constructor and reused by every call
 * to {@link #hash(String)}; because that object is stateful, a single instance of this class
 * should not be used from several threads at once without external synchronization.
 */
public class SimpleHashFunction {
    /** Lookup table for converting a 4-bit value to its lowercase hex digit. */
    private static final char[] HEX_DIGITS = "0123456789abcdef".toCharArray();

    private final MessageDigest digest;

    /**
     * Creates a hash function backed by SHA-256.
     *
     * @throws NoSuchAlgorithmException if the runtime provides no SHA-256 implementation.
     */
    public SimpleHashFunction() throws NoSuchAlgorithmException {
        this.digest = MessageDigest.getInstance("SHA-256");
    }

    /**
     * Hashes the given text.
     *
     * @param input the text to hash; it is encoded as UTF-8 before hashing so the result does
     *     not depend on the platform's default charset.
     * @return the digest as a lowercase hexadecimal string, two characters per digest byte.
     * @throws NullPointerException if {@code input} is {@code null}.
     */
    public String hash(String input) {
        byte[] hashBytes = digest.digest(input.getBytes(StandardCharsets.UTF_8));
        char[] hex = new char[hashBytes.length * 2];
        for (int i = 0; i < hashBytes.length; i++) {
            // Mask to an unsigned value so negative bytes do not sign-extend.
            int value = hashBytes[i] & 0xff;
            hex[2 * i] = HEX_DIGITS[value >>> 4];
            hex[2 * i + 1] = HEX_DIGITS[value & 0x0f];
        }
        return new String(hex);
    }
}
package com.collmall.shortlink;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
/**
 * Small demonstration: indexes a handful of words in an {@code LSH} instance configured with
 * 4 hash values and 2 bands, then queries it with the signature of one of those words and
 * prints the candidates.
 */
public class Main {
    public static void main(String[] args) throws NoSuchAlgorithmException {
        LSH index = new LSH(4, 2);
        SimpleHashFunction hasher = new SimpleHashFunction();

        List<String> words = Arrays.asList("apple", "apples", "banana", "bananas", "grape", "grapes");
        for (String word : words) {
            index.add(word, signatureOf(hasher, word, index.getNumHashFunctions()));
        }

        String probe = "apple";
        Set<String> matches = index.query(signatureOf(hasher, probe, index.getNumHashFunctions()));
        System.out.println("Results: " + matches);
    }

    /**
     * Builds a signature of {@code length} hashes by hashing the text with the position index
     * appended, so each position yields a different hash value.
     */
    private static String[] signatureOf(SimpleHashFunction hasher, String text, int length) {
        String[] signature = new String[length];
        int position = 0;
        while (position < length) {
            signature[position] = hasher.hash(text + position);
            position++;
        }
        return signature;
    }
}