实现从 Milvus 中获取数据,并基于嵌入向量重新排序的功能

发布于:2024-09-18 ⋅ 阅读:(48) ⋅ 点赞:(0)

为了实现从 Milvus 中获取数据,并基于嵌入向量重新排序的功能,你可以参考以下步骤对原代码进行完善和修改。

关键问题和修改:

  1. Milvus 数据获取:确保 query() 能获取插入的数据,尤其是向量。确保 Collection 正确连接,并且 output_fields 中字段名匹配 Milvus 中的定义。
  2. 加载数据:插入数据后,需要通过 load() 加载数据,确保数据在 Milvus 中可见。
  3. 优化向量查询逻辑:可以考虑使用 search() 方法,而不是直接从 Milvus 查询所有数据,再计算相似度。search() 可以直接根据查询文本的向量进行向量相似度搜索。

代码修改:

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from pymilvus import Collection, connections


# 计算余弦相似度的函数
def cos_sim(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


# 生成文本的嵌入向量
def embed_texts(texts, model, tokenizer):
    """
    使用 Jina-embeddings-v2 模型将文本转化为嵌入向量
    """
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[:, 0, :]  # 提取CLS token嵌入
    return embeddings.numpy()


# 从 Milvus 获取候选文本和嵌入向量
def get_candidates_from_milvus(collection_name, limit=10):
    """
    从 Milvus 中获取候选文本及其对应的嵌入
    """
    # 连接到 Milvus
    connections.connect(alias="default", host="localhost", port="19530")
    collection = Collection(collection_name)

    # 加载数据确保可用
    collection.load()

    # 简单查询所有候选文本和嵌入向量
    results = collection.query(expr="", output_fields=["text", "embedding"], limit=limit)

    # 提取文本和嵌入向量
    candidate_texts = [result['text'] for result in results]
    candidate_embeddings = np.array([result['embedding'] for result in results])

    return candidate_texts, candidate_embeddings


def rerank_candidates(query_text, model, tokenizer, candidate_texts, candidate_embeddings):
    """
    重新排序候选文本,基于与查询文本的相似度
    """
    # 为查询文本生成嵌入向量
    query_embedding = embed_texts([query_text], model, tokenizer)[0]

    # 计算查询与候选文本的相似度
    similarities = [cos_sim(query_embedding, embedding) for embedding in candidate_embeddings]

    # 根据相似度进行重排序
    sorted_candidates = sorted(zip(candidate_texts, similarities), key=lambda x: x[1], reverse=True)

    # 返回重排序后的结果
    return sorted_candidates


def main():
    # 加载 Jina-embeddings-v2-base-zh 模型和对应的 tokenizer
    model_name = 'jinaai/jina-embeddings-v2-base-zh'
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    # 获取候选文本和嵌入向量(从 Milvus)
    candidate_texts, candidate_embeddings = get_candidates_from_milvus("weather", limit=10)

    # 查询文本
    query_text = "天气咋样?"

    # 重新排序候选文本
    sorted_candidates = rerank_candidates(query_text, model, tokenizer, candidate_texts, candidate_embeddings)

    # 输出重排序结果
    print("Ranked results based on similarity:")
    for text, score in sorted_candidates:
        print(f"Candidate: {text}, Similarity: {score}")


if __name__ == "__main__":
    main()

关键改动:

  1. collection.load():确保在查询前加载集合中的数据,使其可被查询。Milvus 使用延迟加载,所以需要显式调用 load()

  2. query():获取集合中的 textembedding 字段。根据你在 Milvus 集合中插入的字段来确定 output_fields

运行步骤:

  1. 确保 Milvus 服务器已经启动并连接正常。
  2. 插入的数据已成功提交并在 Milvus 集合中可见。
  3. 执行脚本,通过从 Milvus 中查询文本和嵌入数据,然后对查询文本进行重排序。

小提示:

如果你希望直接在 Milvus 中进行相似度搜索(不手动计算余弦相似度),可以使用 search() 方法,Milvus 本身支持基于向量的相似度查询,会更加高效。