Java 实现局部敏感Hash

发布于:2024-12-20 ⋅ 阅读:(11) ⋅ 点赞:(0)

Java 实现局部敏感Hash

局部敏感哈希(Locality-Sensitive Hashing, LSH)是一种哈希技术,它使相似的数据以较高的概率被映射到相同的哈希桶中,而不相似的数据则很少发生碰撞,从而可以在大数据集中快速找到相似项。



import java.util.*;

/**
 * Banded Locality-Sensitive Hashing (LSH) index.
 *
 * <p>Every item is described by a signature of {@code numHashFunctions} hash values. The
 * signature is cut into {@code numBands} consecutive bands of
 * {@code numHashFunctions / numBands} values each (integer division; any trailing values
 * that do not fill a whole band are ignored). Two signatures are candidate matches when at
 * least one of their bands is identical.
 *
 * <p>This class performs no synchronization.
 */
public class LSH {
    /**
     * Number of hash values expected in every signature.
     */
    private final int numHashFunctions;

    /**
     * Number of bands each signature is split into.
     */
    private final int numBands;

    /**
     * Number of consecutive hash values that make up one band.
     */
    private final int rowsPerBand;

    /**
     * One table per band, mapping a band's hash values to the items stored under them.
     * Keys are lists rather than joined strings so that neither the hash values nor the
     * items can collide with a separator character.
     */
    private final List<Map<List<String>, Set<String>>> buckets;

    /**
     * Creates an empty index.
     *
     * @param numHashFunctions number of hash values per signature.
     * @param numBands         number of bands the signature is split into.
     * @throws IllegalArgumentException if {@code numBands} is not positive, or if it exceeds
     *                                  {@code numHashFunctions} (each band would then be
     *                                  empty and every query would match every item).
     */
    public LSH(int numHashFunctions, int numBands) {
        if (numBands <= 0) {
            throw new IllegalArgumentException("numBands must be positive, got " + numBands);
        }
        if (numHashFunctions < numBands) {
            throw new IllegalArgumentException(
                    "numHashFunctions (" + numHashFunctions + ") must be at least numBands (" + numBands + ")");
        }
        this.numHashFunctions = numHashFunctions;
        this.numBands = numBands;
        this.rowsPerBand = numHashFunctions / numBands;
        this.buckets = new ArrayList<>(numBands);
        for (int i = 0; i < numBands; i++) {
            buckets.add(new HashMap<List<String>, Set<String>>());
        }
    }

    /**
     * Stores an item under each band of its signature.
     *
     * @param item   the item to index.
     * @param hashes the item's signature, one value per hash function.
     * @throws IllegalArgumentException if {@code hashes.length} differs from the number of
     *                                  hash functions.
     */
    public void add(String item, String[] hashes) {
        checkSignatureLength(hashes);
        for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
            Map<List<String>, Set<String>> band = buckets.get(bandIndex);
            List<String> key = bandKey(hashes, bandIndex);
            Set<String> bucket = band.get(key);
            if (bucket == null) {
                bucket = new HashSet<>();
                band.put(key, bucket);
            }
            bucket.add(item);
        }
    }

    /**
     * Returns every stored item that shares at least one complete band with the query
     * signature.
     *
     * @param hashes the query signature, one value per hash function.
     * @return a new, mutable set of candidate items; empty when nothing matches.
     * @throws IllegalArgumentException if {@code hashes.length} differs from the number of
     *                                  hash functions.
     */
    public Set<String> query(String[] hashes) {
        checkSignatureLength(hashes);
        Set<String> results = new HashSet<>();
        for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
            // Exact lookup of the band key; no scan over the stored entries is needed.
            Set<String> bucket = buckets.get(bandIndex).get(bandKey(hashes, bandIndex));
            if (bucket != null) {
                results.addAll(bucket);
            }
        }
        return results;
    }

    /**
     * Returns the number of hash values expected in every signature.
     *
     * @return the number of hash functions.
     */
    public int getNumHashFunctions() {
        return numHashFunctions;
    }

    /** Rejects signatures whose length does not match the configured number of hash functions. */
    private void checkSignatureLength(String[] hashes) {
        if (hashes.length != numHashFunctions) {
            throw new IllegalArgumentException("Number of hashes must match the number of hash functions");
        }
    }

    /** Copies the hash values of one band into a list usable as a map key. */
    private List<String> bandKey(String[] hashes, int bandIndex) {
        int from = bandIndex * rowsPerBand;
        // Copy so that later changes to the caller's array cannot corrupt a stored key.
        return Arrays.asList(Arrays.copyOfRange(hashes, from, from + rowsPerBand));
    }
}


import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/**
 * Hashes strings with SHA-256 and returns the digest as lowercase hexadecimal text.
 *
 * <p>One {@link MessageDigest} instance is reused for every call, so a single
 * {@code SimpleHashFunction} should not be shared between threads without external
 * synchronization.
 */
public class SimpleHashFunction {
    private final MessageDigest digest;

    /**
     * Creates a hasher backed by the SHA-256 algorithm.
     *
     * @throws NoSuchAlgorithmException if the runtime provides no SHA-256 implementation.
     */
    public SimpleHashFunction() throws NoSuchAlgorithmException {
        this.digest = MessageDigest.getInstance("SHA-256");
    }

    /**
     * Computes the SHA-256 digest of the UTF-8 encoding of {@code input}.
     *
     * @param input the text to hash; must not be {@code null}.
     * @return the digest as a 64-character lowercase hexadecimal string.
     * @throws NullPointerException if {@code input} is {@code null}.
     */
    public String hash(String input) {
        // Encode with a fixed charset: the platform default would make the hash of
        // non-ASCII text depend on the machine the code runs on.
        byte[] hashBytes = digest.digest(input.getBytes(StandardCharsets.UTF_8));
        StringBuilder hexString = new StringBuilder(hashBytes.length * 2);
        for (byte b : hashBytes) {
            int unsigned = b & 0xff;
            // Integer.toHexString drops the leading zero of values below 0x10.
            if (unsigned < 0x10) {
                hexString.append('0');
            }
            hexString.append(Integer.toHexString(unsigned));
        }
        return hexString.toString();
    }
}

package com.collmall.shortlink;

import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

/**
 * Demo entry point: indexes a few words in an {@code LSH} instance and prints the
 * candidates returned for the query word {@code "apple"}.
 */
public class Main {
    /**
     * Builds the index, runs one query and prints the result set to standard output.
     *
     * @param args unused.
     * @throws NoSuchAlgorithmException if {@code SimpleHashFunction} cannot be created.
     */
    public static void main(String[] args) throws NoSuchAlgorithmException {
        LSH index = new LSH(4, 2);
        SimpleHashFunction hasher = new SimpleHashFunction();

        // NOTE(review): the signature values come from SimpleHashFunction applied to
        // word + position; whether near-duplicates such as "apples" can ever share a band
        // with "apple" depends on that hash - confirm it is suitable for similarity search.
        List<String> words = Arrays.asList("apple", "apples", "banana", "bananas", "grape", "grapes");
        for (String word : words) {
            index.add(word, signatureOf(hasher, word, index.getNumHashFunctions()));
        }

        String[] querySignature = signatureOf(hasher, "apple", index.getNumHashFunctions());
        Set<String> results = index.query(querySignature);
        System.out.println("Results: " + results);
    }

    /** Builds a signature of the given length by hashing the text suffixed with each position. */
    private static String[] signatureOf(SimpleHashFunction hasher, String text, int length) {
        String[] signature = new String[length];
        int position = 0;
        while (position < length) {
            signature[position] = hasher.hash(text + position);
            position++;
        }
        return signature;
    }
}