智能土木通 - Civil Engineering Knowledge Q&A System, Part 02: Building the RAG Retrieval Module

Published: 2025-06-20

1. Project Directory

civil_qa_system/
├── docs/                    # Project documentation
├── config/                  # Configuration files
├── core/                    # Core functionality
├── knowledge_base/          # Knowledge base code
├── web/                     # Web application
├── cli/                     # Command-line tools
├── tests/                   # Test code
├── scripts/                 # Helper scripts
├── requirements/            # Dependency management
└── README.md                # Project readme

2. Naming Conventions

  • Class names: UpperCamelCase, e.g. MyClass
  • Function names: snake_case, e.g. my_function
  • Variable names: snake_case, e.g. my_variable
  • Directory and file names: lowercase snake_case.
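As a quick illustration of these conventions (the names here are made up for the example):

```python
# Class name: UpperCamelCase
class DocumentLoader:
    # Variable name: snake_case
    default_encoding = "utf-8"

    # Function/method name: snake_case
    def count_pages(self, pages):
        page_count = len(pages)
        return page_count


loader = DocumentLoader()
print(loader.count_pages(["p1", "p2", "p3"]))  # 3
```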

3. Connecting to the LLM

Create core/llm/qwen_client.py under the core directory; this file centralizes all LLM-related code.

I use the Tongyi Qianwen (Qwen) model here. You can use a different model, but the code and configuration will need to change accordingly.

Under the core directory, create conf/.qwen.

Configure the model's API key in it, replacing your_api_key_here below with your own key. If you haven't used the service before, apply for a key on Aliyun: application link

# Tongyi Qianwen API configuration
DASHSCOPE_API_KEY=your_api_key_here

Required packages (install them with pip in a terminal):
langchain-community>=0.0.28
python-dotenv>=1.0.0
dashscope>=1.14.0

from dotenv import load_dotenv
import os
from typing import Tuple
from langchain_community.llms.tongyi import Tongyi
from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_qwen_config() -> bool:
    """
    加载千问环境变量配置
    Returns:
        bool: 是否加载成功
    """
    try:
        current_dir = os.path.dirname(__file__)
        conf_file_path_qwen = os.path.join(current_dir, '..', 'conf', '.qwen')
        
        if not os.path.exists(conf_file_path_qwen):
            logger.error(f"Qwen config file not found at: {conf_file_path_qwen}")
            return False
            
        load_dotenv(dotenv_path=conf_file_path_qwen)
        return True
    except Exception:
        logger.exception("Failed to load Qwen configuration")
        return False

def get_qwen_models() -> Tuple[Tongyi, ChatTongyi, DashScopeEmbeddings]:
    """
    初始化并返回千问系列大模型组件
    
    Returns:
        Tuple: (llm, chat, embed) 三元组
    Raises:
        RuntimeError: 当配置加载失败或初始化失败时抛出
    """
    if not load_qwen_config():
        raise RuntimeError("Qwen configuration loading failed")
    
    try:
        # Initialize the LLM
        llm = Tongyi(
            model="qwen-max",
            temperature=0.1,
            top_p=0.7,
            max_tokens=1024,
            verbose=True
        )
        
        # Initialize the chat model
        chat = ChatTongyi(
            model="qwen-max",
            temperature=0.01,
            top_p=0.2,
            max_tokens=1024
        )
        
        # Initialize the embedding model
        embed = DashScopeEmbeddings(
            model="text-embedding-v3"
        )
        
        logger.info("Qwen models initialized successfully")
        return llm, chat, embed
        
    except Exception as e:
        logger.exception("Failed to initialize Qwen models")
        raise RuntimeError(f"Model initialization failed: {str(e)}")

When writing code, we stick to the habit of testing every small module as we write it. Create unit/test_qwen_client.py under tests.

Write the test code: since we just implemented the LLM connection, the test should verify that the connection actually works.

Install the pytest testing tool with pip; it discovers and runs test functions automatically, so you don't have to call them yourself.

import pytest
from core.llm.qwen_client import get_qwen_models


class TestQwenClient:
    def test_model_initialization(self):
        """Test that the models initialize successfully"""
        llm, chat, embed = get_qwen_models()
        assert llm is not None
        assert chat is not None
        assert embed is not None

    def test_invalid_config(self, monkeypatch):
        """Test behavior with a broken configuration"""
        monkeypatch.setenv("DASHSCOPE_API_KEY", "")
        with pytest.raises(RuntimeError):
            get_qwen_models()

Running the test file prints:

Testing started at 15:40 ...
Launching pytest with arguments test_qwen_client.py::TestQwenClient::test_invalid_config --no-header --no-summary -q in D:\construction_QA_system\tests\unit

============================= test session starts =============================
collecting ... collected 1 item

test_qwen_client.py::TestQwenClient::test_invalid_config PASSED          [100%]

============================== 1 passed in 0.85s ==============================

which means the test passed.


4. Implementing the Vector Store Class

We use Chroma here; if you haven't worked with it before, check the tutorials on Bilibili or the official documentation.

Install with pip:

chromadb>=0.4.15
langchain-chroma>=0.0.4

Create the storage/chroma_manager.py module under knowledge_base. It implements creating the Chroma vector store, adding documents to it, and querying documents from it.

from typing import Optional, List, Union
import chromadb
from chromadb import Settings
from langchain_chroma import Chroma
from langchain_core.embeddings import Embeddings
from langchain_core.documents import Document
import logging

class ChromaManager:
    """High-level manager for a ChromaDB vector store.

    Features:
    - Supports both local and HTTP connection modes
    - Automatic persistence management
    - Thread-safe connections
    - Thorough error handling
    """
    
    def __init__(self,
                 chroma_server_type: str = "local",
                 host: str = "localhost",
                 port: int = 8000,
                 persist_path: str = "chroma_db",
                 collection_name: str = "langchain",
                 embed_model: Optional[Embeddings] = None):
        """
        初始化ChromaDB连接
        
        Args:
            chroma_server_type: 连接类型 ("local"|"http")
            host: 服务器地址 (HTTP模式必需)
            port: 服务器端口 (HTTP模式必需)
            persist_path: 本地持久化路径 (本地模式必需)
            collection_name: 集合名称
            embed_model: 嵌入模型实例
        """
        self._validate_init_params(chroma_server_type, host, port, persist_path)
        
        self.client = self._create_client(chroma_server_type, host, port, persist_path)
        self.collection_name = collection_name
        self.embed_model = embed_model
        self.logger = logging.getLogger(__name__)
        
        try:
            self.store = Chroma(
                collection_name=collection_name,
                embedding_function=embed_model,
                client=self.client,
                persist_directory=persist_path if chroma_server_type == "local" else None
            )
            self.logger.info(f"ChromaDB initialized successfully. Mode: {chroma_server_type}")
        except Exception as e:
            self.logger.error(f"ChromaDB initialization failed: {str(e)}")
            raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}")

    def _validate_init_params(self, server_type: str, host: str, port: int, path: str):
        """参数验证"""
        if server_type not in ["local", "http"]:
            raise ValueError(f"Invalid server type: {server_type}. Must be 'local' or 'http'")
            
        if server_type == "http" and not all([host, port]):
            raise ValueError("Host and port must be specified for HTTP mode")
            
        if server_type == "local" and not path:
            raise ValueError("Persist path must be specified for local mode")

    def _create_client(self, server_type: str, host: str, port: int, path: str) -> chromadb.Client:
        """创建Chroma客户端"""
        try:
            if server_type == "http":
                return chromadb.HttpClient(
                    host=host,
                    port=port,
                    settings=Settings(allow_reset=True)
                )
            else:
                return chromadb.PersistentClient(
                    path=path,
                    settings=Settings(
                        anonymized_telemetry=False,
                        allow_reset=True
                    )
                )
        except Exception as e:
            logging.error(f"Chroma client creation failed: {str(e)}")
            raise

    def add_documents(self, docs: Union[List[Document], List[str]]) -> List[str]:
        """
        Add documents to the collection.

        Args:
            docs: list of documents, either Document objects or plain strings

        Returns:
            List of IDs of the inserted documents
        """
        try:
            if not docs:
                self.logger.warning("Attempted to add an empty documents list")
                return []

            # Wrap plain strings as Document objects so add_documents accepts them
            docs = [d if isinstance(d, Document) else Document(page_content=d) for d in docs]

            doc_ids = self.store.add_documents(documents=docs)
            self.logger.info(f"Added {len(doc_ids)} documents to collection '{self.collection_name}'")
            return doc_ids
        except Exception as e:
            self.logger.error(f"Failed to add documents: {str(e)}")
            raise RuntimeError(f"Document addition failed: {str(e)}")

    def query(self, query_text: str, k: int = 5, **kwargs) -> List[Document]:
        """
        Similarity search.

        Args:
            query_text: the query text
            k: number of results to return
            **kwargs: extra query parameters

        Returns:
            List of matching documents
        """
        try:
            results = self.store.similarity_search(query_text, k=k, **kwargs)
            self.logger.debug(f"Query returned {len(results)} results for: {query_text}")
            return results
        except Exception as e:
            self.logger.error(f"Query failed: {str(e)}")
            raise RuntimeError(f"Query operation failed: {str(e)}")

    def get_collection_stats(self) -> dict:
        """获取集合统计信息"""
        try:
            collection = self.client.get_collection(self.collection_name)
            return {
                "count": collection.count(),
                "metadata": collection.metadata
            }
        except Exception as e:
            self.logger.error(f"Failed to get collection stats: {str(e)}")
            raise RuntimeError(f"Collection stats retrieval failed: {str(e)}")

    @property
    def store(self) -> Chroma:
        """获取LangChain Chroma实例"""
        return self._store

    @store.setter
    def store(self, value):
        self._store = value

Let's test this module's functionality:

import pytest
from unittest.mock import MagicMock
from knowledge_base.storage.chroma_manager import ChromaManager

class TestChromaManager:
    @pytest.fixture
    def mock_embedding(self):
        mock = MagicMock()
        mock.embed_documents.return_value = [[0.1]*768]
        return mock

    def test_local_init(self, tmp_path, mock_embedding):
        """测试本地模式初始化"""
        db = ChromaManager(
            chroma_server_type="local",
            persist_path=str(tmp_path),
            embed_model=mock_embedding
        )
        assert db.store is not None

    def test_add_documents(self, tmp_path, mock_embedding):
        """测试文档添加功能"""
        db = ChromaManager(
            chroma_server_type="local",
            persist_path=str(tmp_path),
            embed_model=mock_embedding
        )
        test_docs = ["Test document 1", "Test document 2"]
        doc_ids = db.add_documents(test_docs)
        assert len(doc_ids) == 2

Test output:

Testing started at 17:00 ...
Launching pytest with arguments test_chroma_manager.py::TestChromaManager::test_local_init --no-header --no-summary -q in D:\construction_QA_system\tests\unit

============================= test session starts =============================
collecting ... collected 1 item

test_chroma_manager.py::TestChromaManager::test_local_init PASSED        [100%]

============================== 1 passed in 1.79s ==============================
Process finished with exit code 0

5. Implementing Ingestion

Create builders/pdf_processor.py under the knowledge_base folder and implement the PDFProcessor class in it. Its main responsibilities:

- Load PDF files from a given directory
- Extract their text content
- Split the text into chunks
- Store the chunks in the vector database
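Before looking at the class itself, the chunk_size/overlap interplay used for splitting (the class below defaults to 500 and 100) can be pictured with a fixed-window simplification. The real RecursiveCharacterTextSplitter also respects separator boundaries, so this sketch is only an approximation:

```python
def chunk_spans(text_len, chunk_size=500, overlap=100):
    """Start/end offsets of fixed-size chunks with overlap (simplified view of the splitter)."""
    step = chunk_size - overlap  # each new chunk starts `overlap` chars before the previous one ends
    spans = []
    start = 0
    while start < text_len:
        spans.append((start, min(start + chunk_size, text_len)))
        if start + chunk_size >= text_len:
            break
        start += step
    return spans


print(chunk_spans(1200))  # [(0, 500), (400, 900), (800, 1200)]
```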

 

import os
import logging
import time
from tqdm import tqdm
from typing import List, Optional
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
# Import path follows the new project structure
from knowledge_base.storage.chroma_manager import ChromaManager

class PDFProcessor:
    """
    PDF document processing pipeline, responsible for:
    - Loading PDF files from a given directory
    - Extracting their text content
    - Splitting the text into chunks
    - Storing the chunks in the vector database

    Args:
        directory: path of the directory containing the PDF files
        chroma_server_type: ChromaDB server type ("local" or "http")
        persist_path: ChromaDB persistence path (used in local mode)
        embedding_function: text embedding model instance
        file_group_num: number of files processed per group (default 80)
        batch_num: number of documents inserted per batch (default 6)
        chunksize: text chunk size (default 500 characters)
        overlap: chunk overlap (default 100 characters)
    """

    def __init__(self,
                 directory: str,
                 chroma_server_type: str = "local",
                 persist_path: str = "chroma_db",
                 embedding_function: Optional[object] = None,
                 file_group_num: int = 80,
                 batch_num: int = 6,
                 chunksize: int = 500,
                 overlap: int = 100):

        # Initialize parameters
        self.directory = directory
        self.file_group_num = file_group_num
        self.batch_num = batch_num
        self.chunksize = chunksize
        self.overlap = overlap

        # Initialize the ChromaDB connection (ChromaManager's parameter is embed_model)
        self.chroma_db = ChromaManager(
            chroma_server_type=chroma_server_type,
            persist_path=persist_path,
            embed_model=embedding_function
        )

        # Configure logging (log file path is relative)
        self._setup_logging()

        # Verify the directory exists
        if not os.path.isdir(self.directory):
            raise ValueError(f"Directory does not exist: {self.directory}")

    def _setup_logging(self):
        """Configure the logging system"""
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[
                logging.FileHandler(os.path.join(log_dir, "pdf_processor.log")),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_pdf_files(self) -> List[str]:
        """
        Scan the directory and return all PDF file paths.

        Returns:
            List of full PDF file paths

        Raises:
            ValueError: if the directory contains no PDF files
        """
        pdf_files = []
        for file in os.listdir(self.directory):
            if file.lower().endswith('.pdf'):
                pdf_files.append(os.path.join(self.directory, file))

        if not pdf_files:
            raise ValueError(f"No PDF files found in directory: {self.directory}")

        self.logger.info(f"Found {len(pdf_files)} PDF files")
        return pdf_files

    def load_pdf_content(self, pdf_path: str) -> List[Document]:
        """
        Load a single PDF file with PyMuPDF.

        Args:
            pdf_path: path of the PDF file

        Returns:
            List of LangChain Document objects

        Raises:
            RuntimeError: if the file fails to load
        """
        try:
            loader = PyMuPDFLoader(file_path=pdf_path)
            docs = loader.load()
            self.logger.debug(f"Loaded: {pdf_path} ({len(docs)} pages)")
            return docs
        except Exception as e:
            self.logger.error(f"Failed to load PDF {pdf_path}: {str(e)}")
            raise RuntimeError(f"Cannot load PDF file: {pdf_path}")

    def split_text(self, documents: List[Document]) -> List[Document]:
        """
        Split documents into chunks with a recursive character splitter.

        Args:
            documents: list of Documents to split

        Returns:
            List of split Documents
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunksize,
            chunk_overlap=self.overlap,
            length_function=len,
            add_start_index=True,
            separators=["\n\n", "\n", "。", "!", "?", ";", " ", ""]  # Chinese-friendly separators
        )

        try:
            docs = text_splitter.split_documents(documents)
            self.logger.info(f"Text splitting finished: {len(documents)} chunks in → {len(docs)} chunks out")
            return docs
        except Exception as e:
            self.logger.error(f"Text splitting failed: {str(e)}")
            raise RuntimeError("Error during text splitting")

    def insert_docs_chromadb(self, docs: List[Document], batch_size: int = 6) -> None:
        """
        Insert documents into ChromaDB in batches, with a progress bar and throughput monitoring.
        """
        if not docs:
            self.logger.warning("Attempted to insert an empty document list")
            return

        self.logger.info(f"Inserting {len(docs)} documents into ChromaDB")
        start_time = time.time()
        total_docs_inserted = 0
        total_batches = (len(docs) + batch_size - 1) // batch_size

        try:
            with tqdm(total=total_batches, desc="Insert progress", unit="batch") as pbar:
                for i in range(0, len(docs), batch_size):
                    batch = docs[i:i + batch_size]

                    # Standardized method name (was add_with_langchain)
                    self.chroma_db.add_documents(batch)
                    total_docs_inserted += len(batch)

                    # Throughput (documents per minute)
                    elapsed_time = time.time() - start_time
                    tpm = (total_docs_inserted / elapsed_time) * 60 if elapsed_time > 0 else 0

                    # Update the progress bar
                    pbar.set_postfix({
                        "TPM": f"{tpm:.2f}",
                        "docs": total_docs_inserted
                    })
                    pbar.update(1)

            self.logger.info(f"Document insertion finished! Total time: {time.time() - start_time:.2f}s")
        except Exception as e:
            self.logger.error(f"Document insertion failed: {str(e)}")
            raise RuntimeError(f"Document insertion failed: {str(e)}")

    def process_pdfs_group(self, pdf_files_group: List[str]) -> None:
        """
        Process one group of PDF files (load → split → store).

        Args:
            pdf_files_group: list of PDF file paths
        """
        try:
            # Stage 1: load all PDF contents
            pdf_contents = []
            for pdf_path in pdf_files_group:
                documents = self.load_pdf_content(pdf_path)
                pdf_contents.extend(documents)

            # Stage 2: split the text
            if pdf_contents:
                docs = self.split_text(pdf_contents)

                # Stage 3: store into the vector database
                if docs:
                    self.insert_docs_chromadb(docs, self.batch_num)
        except Exception as e:
            self.logger.error(f"Failed to process PDF group: {str(e)}")
            # Optionally continue with the next group instead of aborting
            # raise

    def process_pdfs(self) -> None:
        """
        Main pipeline: scan the directory, then process all PDF files in groups.
        """
        self.logger.info("=== Starting PDF processing pipeline ===")
        start_time = time.time()

        try:
            pdf_files = self.load_pdf_files()

            # Process PDF files in groups
            for i in range(0, len(pdf_files), self.file_group_num):
                group = pdf_files[i:i + self.file_group_num]
                self.logger.info(
                    f"Processing file group {i // self.file_group_num + 1}/{(len(pdf_files) - 1) // self.file_group_num + 1}")
                self.process_pdfs_group(group)

            self.logger.info(f"=== Done! Total time: {time.time() - start_time:.2f}s ===")
            print("PDF processing completed successfully!")
        except Exception as e:
            self.logger.error(f"PDF processing pipeline failed: {str(e)}")
            raise RuntimeError(f"PDF processing failed: {str(e)}")

Let's test this module.

Create unit/test_pdf_processor.py under the tests folder.

import os
import pytest
from knowledge_base.builders.pdf_processor import PDFProcessor


@pytest.fixture
def test_resources(tmp_path):
    """Prepare test resources"""
    # Create the PDF test directory
    pdf_dir = tmp_path / "pdfs"
    pdf_dir.mkdir()

    # Copy a prepared PDF (or generate one dynamically)
    test_pdf = os.path.join(os.path.dirname(__file__), "test_files", "sample.pdf")
    target_pdf = pdf_dir / "test.pdf"

    with open(test_pdf, "rb") as src, open(target_pdf, "wb") as dst:
        dst.write(src.read())

    return {
        "pdf_dir": str(pdf_dir),
        "db_dir": str(tmp_path / "chroma_db"),
        "pdf_path": str(target_pdf)
    }


def test_pdf_processing(test_resources):
    processor = PDFProcessor(
        directory=test_resources["pdf_dir"],
        persist_path=test_resources["db_dir"]
    )

    processor.process_pdfs()

    # Verify the database
    assert os.path.exists(test_resources["db_dir"])
    assert any(os.listdir(test_resources["db_dir"]))

Run output:

Testing started at 19:35 ...
Launching pytest with arguments D:\construction_QA_system\tests\unit\test_pdf_processor.py --no-header --no-summary -q in D:\construction_QA_system\tests\unit

============================= test session starts =============================
collecting ... collected 1 item

test_pdf_processor.py::test_pdf_processing PASSED                        [100%]
PDF processing completed successfully!

============================== 1 passed in 2.78s ==============================

Process finished with exit code 0

6. Vector Retrieval Module

Create retrieval/vector_retriever.py under the knowledge_base folder.

from typing import List, Dict, Optional
from langchain_core.documents import Document
from ..storage.chroma_manager import ChromaManager


class VectorRetriever:
    def __init__(
            self,
            chroma_server_type: str = "local",
            persist_path: str = "chroma_db",
            collection_name: str = "construction_docs",
            embedding_function: Optional[object] = None,
            top_k: int = 5
    ):
        """Vector retriever

        Args:
            top_k: number of most relevant results to return
        """
        self.chroma_db = ChromaManager(
            chroma_server_type=chroma_server_type,
            persist_path=persist_path,
            collection_name=collection_name,
            embed_model=embedding_function  # ChromaManager's parameter is embed_model
        )
        self.embed_model = embedding_function
        self.top_k = top_k

    @property
    def collection(self):
        """Raw Chroma collection behind the LangChain wrapper"""
        return self.chroma_db.client.get_collection(self.chroma_db.collection_name)

    def similarity_search(
            self,
            query: str,
            filter_conditions: Optional[Dict] = None
    ) -> List[Document]:
        """Similarity search"""
        # Embed the query (assumes an embedding model was configured)
        query_embedding = self.embed_model.embed_query(query)

        # Run the query
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=self.top_k,
            where=filter_conditions
        )

        # Convert to Document objects
        docs = []
        for i in range(len(results['ids'][0])):
            doc = Document(
                page_content=results['documents'][0][i],
                metadata=results['metadatas'][0][i] or {}
            )
            docs.append(doc)

        return docs

    def hybrid_search(
            self,
            query: str,
            keyword: Optional[str] = None,
            filter_conditions: Optional[Dict] = None
    ) -> List[Document]:
        """Hybrid retrieval (vector + keyword)"""
        # Vector search first
        vector_results = self.similarity_search(query, filter_conditions)

        # Filter by keyword if one was given
        if keyword:
            filtered = [
                doc for doc in vector_results
                if keyword.lower() in doc.page_content.lower()
            ]
            return filtered[:self.top_k]

        return vector_results

    def get_by_id(self, doc_id: str) -> Optional[Document]:
        """Fetch a document by ID"""
        result = self.collection.get(ids=[doc_id])
        if not result['documents']:
            return None

        return Document(
            page_content=result['documents'][0],
            metadata=result['metadatas'][0] or {}
        )
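The keyword stage of hybrid_search is just a case-insensitive substring filter over the vector hits, capped at top_k. A stand-alone sketch (SimpleDoc is a stand-in for LangChain's Document, used only so the example runs without dependencies):

```python
from dataclasses import dataclass


@dataclass
class SimpleDoc:
    page_content: str


def keyword_filter(vector_results, keyword, top_k=5):
    """Mirror the keyword-filter stage of VectorRetriever.hybrid_search."""
    if not keyword:
        return vector_results
    # Case-insensitive substring match, then cap at top_k
    filtered = [
        doc for doc in vector_results
        if keyword.lower() in doc.page_content.lower()
    ]
    return filtered[:top_k]


hits = [SimpleDoc("C30 concrete strength grades"), SimpleDoc("Steel rebar spacing")]
print([d.page_content for d in keyword_filter(hits, "CONCRETE")])
# ['C30 concrete strength grades']
```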

FastAPI interface wrapper

Create api/retrieval_api.py under the project root.

import hashlib
import logging
from typing import List, Optional
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Construction QA Retrieval API",
    description="Retrieval API for the construction knowledge base",
    version="1.0.0",
    openapi_tags=[{
        "name": "retrieval",
        "description": "Knowledge base retrieval endpoints"
    }]
)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Data models ---
class DocumentMetadata(BaseModel):
    """Document metadata model"""
    source: Optional[str] = Field(None, example="GB/T 50081-2019")
    page: Optional[int] = Field(None, example=12)
    timestamp: Optional[datetime] = Field(None, example="2023-01-01T00:00:00")


class DocumentResponse(BaseModel):
    """Retrieval result model"""
    id: str = Field(..., example="doc_123")
    content: str = Field(..., example="混凝土强度检测标准...")
    metadata: DocumentMetadata
    score: float = Field(..., ge=0, le=1, example=0.85)


class QueryRequest(BaseModel):
    """Query request model"""
    query: str = Field(..., min_length=1, example="混凝土强度标准")
    top_k: Optional[int] = Field(5, gt=0, le=20, example=3)
    keyword_filter: Optional[str] = Field(None, example="钢筋")
    metadata_filter: Optional[dict] = Field(None, example={"source": "GB"})


class HealthCheckResponse(BaseModel):
    """Health check response"""
    status: str = Field(..., example="OK")
    version: str = Field(..., example="1.0.0")


# --- Core logic ---
def initialize_retriever():
    """Initialize the retriever (a real project should use dependency injection)"""
    from knowledge_base.retrieval.vector_retriever import VectorRetriever
    from core.utils.embedding_utils import load_embedding_model

    try:
        return VectorRetriever(
            persist_path="data/vector_db",
            embedding_function=load_embedding_model(),
            top_k=10
        )
    except Exception as e:
        logger.error(f"Retriever initialization failed: {str(e)}")
        raise


retriever = initialize_retriever()


# --- API endpoints ---
@app.get("/health", response_model=HealthCheckResponse, tags=["system"])
async def health_check():
    """Service health check"""
    return {
        "status": "OK",
        "version": "1.0.0"
    }


@app.post("/search",
          response_model=List[DocumentResponse],
          tags=["retrieval"],
          summary="Document retrieval",
          responses={
              200: {"description": "Results returned successfully"},
              400: {"description": "Invalid request parameters"},
              500: {"description": "Internal server error"}
          })
async def search_documents(request: QueryRequest):
    """
    Run document retrieval, supporting:
    - pure vector search
    - keyword-filtered search
    - metadata-filtered search
    """
    try:
        logger.info(f"Received search request: {request.dict()}")

        # Parameter validation
        if len(request.query) > 500:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query text too long (max 500 characters)"
            )

        # Run retrieval
        if request.keyword_filter or request.metadata_filter:
            docs = retriever.hybrid_search(
                query=request.query,
                keyword=request.keyword_filter,
                filter_conditions=request.metadata_filter
            )
        else:
            docs = retriever.similarity_search(
                query=request.query,
                filter_conditions=request.metadata_filter
            )

        # Format results
        results = []
        for doc in docs[:request.top_k]:
            if not hasattr(doc, 'metadata'):
                doc.metadata = {}

            results.append({
                # Stable content hash; builtin hash() is randomized per process
                "id": hashlib.md5(doc.page_content.encode("utf-8")).hexdigest(),
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": doc.metadata.get("score", 0.0)
            })

        logger.info(f"Returning {len(results)} results")
        return results

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Search failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Retrieval service temporarily unavailable"
        )


@app.get("/document/{doc_id}",
         response_model=DocumentResponse,
         tags=["retrieval"],
         summary="Fetch a document by ID")
async def get_document(doc_id: str):
    """Fetch the full content of a document by its ID"""
    try:
        doc = retriever.get_by_id(doc_id)
        if not doc:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Document not found"
            )

        return {
            "id": doc_id,
            "content": doc.page_content,
            "metadata": doc.metadata or {},
            "score": 1.0
        }
    except HTTPException:
        # Let the 404 propagate instead of turning it into a 500
        raise
    except Exception as e:
        logger.error(f"Failed to fetch document: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Document retrieval failed"
        )


# --- Startup configuration ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_config={
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "level": "INFO",
                    "formatter": "default"
                }
            },
            "formatters": {
                "default": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
                }
            },
            "root": {
                "handlers": ["console"],
                "level": "INFO"
            }
        }
    )
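Given the QueryRequest model above, a call to POST /search (assuming the server runs locally on port 8000) would carry a JSON body like the following; the snippet just builds the payload and sanity-checks it against the model's constraints:

```python
import json

payload = {
    "query": "混凝土强度标准",            # required; min_length=1, max 500 chars enforced in the handler
    "top_k": 3,                           # optional; 0 < top_k <= 20
    "keyword_filter": "钢筋",             # optional keyword filter
    "metadata_filter": {"source": "GB"},  # optional metadata filter
}

# Constraints mirrored from QueryRequest / search_documents
assert 1 <= len(payload["query"]) <= 500
assert 0 < payload["top_k"] <= 20

body = json.dumps(payload, ensure_ascii=False)
print(body)
```

With the server running, any HTTP client can send this body to http://localhost:8000/search.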

Test cases: tests/unit/test_embedding_utils.py

import pytest
from unittest.mock import patch
from core.utils.embedding_utils import load_embedding_model

class TestEmbeddingUtils:
    @patch('langchain.embeddings.HuggingFaceEmbeddings')
    def test_load_huggingface(self, mock_embeddings):
        """Test loading a HuggingFace model"""
        model = load_embedding_model(model_type="huggingface")
        mock_embeddings.assert_called_once()

    @patch.dict('os.environ', {'OPENAI_API_KEY': 'test_key'})
    @patch('langchain.embeddings.OpenAIEmbeddings')
    def test_load_openai(self, mock_embeddings):
        """Test loading an OpenAI model"""
        model = load_embedding_model(model_type="openai")
        mock_embeddings.assert_called_once_with(
            model="text-embedding-3-small",
            deployment=None,
            openai_api_key="test_key"
        )

    def test_invalid_model_type(self):
        """Test an invalid model type"""
        with pytest.raises(ValueError):
            load_embedding_model(model_type="invalid_type")
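The article never shows core/utils/embedding_utils.py itself. A minimal sketch of load_embedding_model consistent with these tests might look like the following; the dashscope default and the exact constructor arguments are assumptions, not the author's actual code:

```python
import os


def load_embedding_model(model_type: str = "dashscope"):
    """Return an embedding model instance for the requested backend (sketch)."""
    if model_type == "dashscope":
        # Assumed default backend, matching the rest of the article
        from langchain_community.embeddings import DashScopeEmbeddings
        return DashScopeEmbeddings(model="text-embedding-v3")
    if model_type == "huggingface":
        # Imported lazily from langchain.embeddings so the tests' patch targets match
        from langchain.embeddings import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings()
    if model_type == "openai":
        from langchain.embeddings import OpenAIEmbeddings
        return OpenAIEmbeddings(
            model="text-embedding-3-small",
            deployment=None,
            openai_api_key=os.environ["OPENAI_API_KEY"],
        )
    raise ValueError(f"Unsupported model_type: {model_type}")
```

The lazy imports keep the invalid-type path importable even when the optional backends aren't installed.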

 

