1. Project Directory
civil_qa_system/
├── docs/              # project documentation
├── config/            # configuration files
├── core/              # core functionality
├── knowledge_base/    # knowledge-base code
├── web/               # web application
├── cli/               # command-line tools
├── tests/             # test code
├── scripts/           # helper scripts
├── requirements/      # dependency management
└── README.md          # project overview
2. Naming Conventions
- Class names: UpperCamelCase (PascalCase), e.g. MyClass
- Function names: snake_case, e.g. my_function
- Variable names: snake_case, e.g. my_variable
- Folder names: lowercase snake_case, as in the project tree above. A quick illustration follows.
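Here is what these conventions look like in practice (the names are invented purely for demonstration):

# Class name: UpperCamelCase
class PdfIndexer:
    # Variable name: snake_case
    max_batch_size = 64

# Function name: snake_case
def load_documents(source_dir):
    """Demo stub: load documents from source_dir."""
    return []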
3. Connecting the LLM
Create core/llm/qwen_client.py; this file centralizes all of the LLM-related code.
I use the Tongyi Qianwen (Qwen) model here. You can of course use a different LLM, but the code and configuration will need to change accordingly.
Under the core folder, create conf/.qwen.
Configure the model's API key in it, replacing your_api_key_here below with your own key. If you have never used the service before, apply for a key on Alibaba Cloud: application link
# Tongyi Qianwen API configuration
DASHSCOPE_API_KEY=your_api_key_here
The following packages are required; installing them with pip in a terminal is enough:
langchain-community>=0.0.28
python-dotenv>=1.0.0
dashscope>=1.14.0
import logging
import os
from typing import Tuple

from dotenv import load_dotenv
from langchain_community.llms.tongyi import Tongyi
from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_qwen_config() -> bool:
    """
    Load the Qwen environment-variable configuration.

    Returns:
        bool: whether loading succeeded
    """
    try:
        current_dir = os.path.dirname(__file__)
        conf_file_path_qwen = os.path.join(current_dir, '..', 'conf', '.qwen')
        if not os.path.exists(conf_file_path_qwen):
            logger.error(f"Qwen config file not found at: {conf_file_path_qwen}")
            return False
        load_dotenv(dotenv_path=conf_file_path_qwen)
        return True
    except Exception:
        logger.exception("Failed to load Qwen configuration")
        return False


def get_qwen_models() -> Tuple[Tongyi, ChatTongyi, DashScopeEmbeddings]:
    """
    Initialize and return the Qwen model components.

    Returns:
        Tuple: (llm, chat, embed)

    Raises:
        RuntimeError: if configuration loading or initialization fails
    """
    if not load_qwen_config():
        raise RuntimeError("Qwen configuration loading failed")
    try:
        # Initialize the completion LLM
        llm = Tongyi(
            model="qwen-max",
            temperature=0.1,
            top_p=0.7,
            max_tokens=1024,
            verbose=True
        )
        # Initialize the chat model
        chat = ChatTongyi(
            model="qwen-max",
            temperature=0.01,
            top_p=0.2,
            max_tokens=1024
        )
        # Initialize the embedding model
        embed = DashScopeEmbeddings(
            model="text-embedding-v3"
        )
        logger.info("Qwen models initialized successfully")
        return llm, chat, embed
    except Exception as e:
        logger.exception("Failed to initialize Qwen models")
        raise RuntimeError(f"Model initialization failed: {str(e)}")
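Before moving on, a quick interactive sanity check helps. A minimal usage sketch, assuming the project root is on PYTHONPATH and a valid API key is configured (the prompt text is just a placeholder):

# Hypothetical smoke test for the client module, not part of the project code
from core.llm.qwen_client import get_qwen_models

llm, chat, embed = get_qwen_models()
print(llm.invoke("How long should concrete test cubes be cured?"))  # plain-text completion
print(embed.embed_query("concrete strength")[:5])                   # first few embedding dimensions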
While writing code, we should stick to the habit of testing every small module as soon as it is written. Create unit/test_qwen_client.py under tests.
Write the test code: the module above establishes the LLM connection, so the test should verify that the connection actually works.
Install the testing tool pytest with pip; it discovers and runs test functions on its own, so you don't have to call them manually.
import pytest
from core.llm.qwen_client import get_qwen_models


class TestQwenClient:
    def test_model_initialization(self):
        """Test that the models initialize successfully."""
        llm, chat, embed = get_qwen_models()
        assert llm is not None
        assert chat is not None
        assert embed is not None

    def test_invalid_config(self, monkeypatch):
        """Test the misconfigured case: an empty API key should raise."""
        # load_dotenv() does not override variables that already exist, so the
        # empty key set here survives config loading and fails model validation
        monkeypatch.setenv("DASHSCOPE_API_KEY", "")
        with pytest.raises(RuntimeError):
            get_qwen_models()
Running the test file returns:
Testing started at 15:40 ...
Launching pytest with arguments test_qwen_client.py::TestQwenClient::test_invalid_config --no-header --no-summary -q in D:\construction_QA_system\tests\unit
============================= test session starts =============================
collecting ... collected 1 item
test_qwen_client.py::TestQwenClient::test_invalid_config PASSED [100%]
============================== 1 passed in 0.85s ==============================
This indicates the test passed.
To help everyone avoid mistakes, here is my own project directory for reference: (screenshot of the project tree omitted)
4. Implementing the Vector Store Class
We use Chroma here; if you haven't come across it before, look it up on Bilibili or the official documentation.
Install with pip:
chromadb>=0.4.15
langchain-chroma>=0.0.4
In knowledge_base, create the module storage/chroma_manager.py. In it we implement creating the Chroma vector store, adding documents to it, and querying documents from it.
import logging
from typing import List, Optional, Union

import chromadb
from chromadb import Settings
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings


class ChromaManager:
    """High-level wrapper around a ChromaDB vector store.

    Features:
    - Supports both local and HTTP connection modes
    - Automatic persistence management
    - Thread-safe connections
    - Thorough error handling
    """

    def __init__(self,
                 chroma_server_type: str = "local",
                 host: str = "localhost",
                 port: int = 8000,
                 persist_path: str = "chroma_db",
                 collection_name: str = "langchain",
                 embed_model: Optional[Embeddings] = None):
        """
        Initialize the ChromaDB connection.

        Args:
            chroma_server_type: connection type ("local" | "http")
            host: server address (required in HTTP mode)
            port: server port (required in HTTP mode)
            persist_path: local persistence path (required in local mode)
            collection_name: collection name
            embed_model: embedding model instance
        """
        self._validate_init_params(chroma_server_type, host, port, persist_path)
        self.client = self._create_client(chroma_server_type, host, port, persist_path)
        self.collection_name = collection_name
        self.embed_model = embed_model
        self.logger = logging.getLogger(__name__)
        try:
            self.store = Chroma(
                collection_name=collection_name,
                embedding_function=embed_model,
                client=self.client,
                persist_directory=persist_path if chroma_server_type == "local" else None
            )
            self.logger.info(f"ChromaDB initialized successfully. Mode: {chroma_server_type}")
        except Exception as e:
            self.logger.error(f"ChromaDB initialization failed: {str(e)}")
            raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}")

    def _validate_init_params(self, server_type: str, host: str, port: int, path: str):
        """Validate constructor parameters."""
        if server_type not in ["local", "http"]:
            raise ValueError(f"Invalid server type: {server_type}. Must be 'local' or 'http'")
        if server_type == "http" and not all([host, port]):
            raise ValueError("Host and port must be specified for HTTP mode")
        if server_type == "local" and not path:
            raise ValueError("Persist path must be specified for local mode")

    def _create_client(self, server_type: str, host: str, port: int, path: str):
        """Create the Chroma client."""
        try:
            if server_type == "http":
                return chromadb.HttpClient(
                    host=host,
                    port=port,
                    settings=Settings(allow_reset=True)
                )
            else:
                return chromadb.PersistentClient(
                    path=path,
                    settings=Settings(
                        anonymized_telemetry=False,
                        allow_reset=True
                    )
                )
        except Exception as e:
            logging.error(f"Chroma client creation failed: {str(e)}")
            raise

    def add_documents(self, docs: Union[List[Document], List[str]]) -> List[str]:
        """
        Add documents to the collection.

        Args:
            docs: list of documents, either Document objects or plain strings

        Returns:
            list of IDs of the inserted documents
        """
        try:
            if not docs:
                self.logger.warning("Attempted to add empty documents list")
                return []
            # Wrap plain strings in Document objects so both input types work
            docs = [d if isinstance(d, Document) else Document(page_content=d) for d in docs]
            doc_ids = self.store.add_documents(documents=docs)
            self.logger.info(f"Added {len(doc_ids)} documents to collection '{self.collection_name}'")
            return doc_ids
        except Exception as e:
            self.logger.error(f"Failed to add documents: {str(e)}")
            raise RuntimeError(f"Document addition failed: {str(e)}")

    def query(self, query_text: str, k: int = 5, **kwargs) -> List[Document]:
        """
        Similarity search.

        Args:
            query_text: query text
            k: number of results to return
            **kwargs: extra query parameters

        Returns:
            list of matching documents
        """
        try:
            results = self.store.similarity_search(query_text, k=k, **kwargs)
            self.logger.debug(f"Query returned {len(results)} results for: {query_text}")
            return results
        except Exception as e:
            self.logger.error(f"Query failed: {str(e)}")
            raise RuntimeError(f"Query operation failed: {str(e)}")

    def get_collection_stats(self) -> dict:
        """Return statistics about the collection."""
        try:
            collection = self.client.get_collection(self.collection_name)
            return {
                "count": collection.count(),
                "metadata": collection.metadata
            }
        except Exception as e:
            self.logger.error(f"Failed to get collection stats: {str(e)}")
            raise RuntimeError(f"Collection stats retrieval failed: {str(e)}")

    @property
    def store(self) -> Chroma:
        """The underlying LangChain Chroma instance."""
        return self._store

    @store.setter
    def store(self, value):
        self._store = value
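With the class in place, here is a short usage sketch that reuses the Qwen embedding model from section 3 (the sample text and query are placeholders):

# Hypothetical usage of ChromaManager with the Qwen embeddings
from core.llm.qwen_client import get_qwen_models
from knowledge_base.storage.chroma_manager import ChromaManager

_, _, embed = get_qwen_models()
db = ChromaManager(chroma_server_type="local", persist_path="chroma_db", embed_model=embed)
db.add_documents(["Concrete curing requires at least 7 days of moisture retention."])
for doc in db.query("concrete curing period", k=1):
    print(doc.page_content)
print(db.get_collection_stats())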
Let's test this module's functionality. In tests, create unit/test_chroma_manager.py:
import pytest
from unittest.mock import MagicMock
from knowledge_base.storage.chroma_manager import ChromaManager


class TestChromaManager:
    @pytest.fixture
    def mock_embedding(self):
        mock = MagicMock()
        # Return one 768-dim vector per input text, however many texts are passed
        mock.embed_documents.side_effect = lambda texts: [[0.1] * 768 for _ in texts]
        mock.embed_query.return_value = [0.1] * 768
        return mock

    def test_local_init(self, tmp_path, mock_embedding):
        """Test initialization in local mode."""
        db = ChromaManager(
            chroma_server_type="local",
            persist_path=str(tmp_path),
            embed_model=mock_embedding
        )
        assert db.store is not None

    def test_add_documents(self, tmp_path, mock_embedding):
        """Test adding documents."""
        db = ChromaManager(
            chroma_server_type="local",
            persist_path=str(tmp_path),
            embed_model=mock_embedding
        )
        test_docs = ["Test document 1", "Test document 2"]
        doc_ids = db.add_documents(test_docs)
        assert len(doc_ids) == 2
Test output:
Testing started at 17:00 ...
Launching pytest with arguments test_chroma_manager.py::TestChromaManager::test_local_init --no-header --no-summary -q in D:\construction_QA_system\tests\unit
============================= test session starts =============================
collecting ... collected 1 item
test_chroma_manager.py::TestChromaManager::test_local_init PASSED [100%]
============================== 1 passed in 1.79s ==============================
Process finished with exit code 0
5. Implementing Ingestion
Under the knowledge_base folder, create builders/pdf_processor.py, which implements the PDFProcessor class. Its main responsibilities:
- Load PDF files from a given directory
- Extract their text content
- Split the text into chunks
- Store the chunks in the vector database
import os
import logging
import time
from tqdm import tqdm
from typing import List, Optional
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from knowledge_base.storage.chroma_manager import ChromaManager


class PDFProcessor:
    """
    PDF document processing pipeline, responsible for:
    - loading PDF files from a given directory
    - extracting their text content
    - splitting the text into chunks
    - storing the chunks in the vector database

    Parameters:
        directory: path to the directory containing the PDF files
        chroma_server_type: ChromaDB server type ("local" or "http")
        persist_path: ChromaDB persistence path (used in local mode)
        embedding_function: text embedding model instance
        file_group_num: number of files processed per group (default 80)
        batch_num: number of documents inserted per batch (default 6)
        chunksize: text chunk size (default 500 characters)
        overlap: chunk overlap size (default 100 characters)
    """

    def __init__(self,
                 directory: str,
                 chroma_server_type: str = "local",
                 persist_path: str = "chroma_db",
                 embedding_function: Optional[object] = None,
                 file_group_num: int = 80,
                 batch_num: int = 6,
                 chunksize: int = 500,
                 overlap: int = 100):
        # Parameter initialization
        self.directory = directory
        self.file_group_num = file_group_num
        self.batch_num = batch_num
        self.chunksize = chunksize
        self.overlap = overlap
        # Initialize the ChromaDB connection
        # (ChromaManager expects the embedding model via its embed_model parameter)
        self.chroma_db = ChromaManager(
            chroma_server_type=chroma_server_type,
            persist_path=persist_path,
            embed_model=embedding_function
        )
        # Configure the logging system (log file path is relative)
        self._setup_logging()
        # Verify the directory exists
        if not os.path.isdir(self.directory):
            raise ValueError(f"Directory does not exist: {self.directory}")

    def _setup_logging(self):
        """Configure the logging system."""
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[
                logging.FileHandler(os.path.join(log_dir, "pdf_processor.log")),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_pdf_files(self) -> List[str]:
        """
        Scan the directory and return all PDF file paths.

        Returns:
            list of full PDF file paths

        Raises:
            ValueError: if the directory contains no PDF files
        """
        pdf_files = []
        for file in os.listdir(self.directory):
            if file.lower().endswith('.pdf'):
                pdf_files.append(os.path.join(self.directory, file))
        if not pdf_files:
            raise ValueError(f"No PDF files found in directory: {self.directory}")
        self.logger.info(f"Found {len(pdf_files)} PDF files")
        return pdf_files

    def load_pdf_content(self, pdf_path: str) -> List[Document]:
        """
        Load a single PDF file's content with PyMuPDF.

        Args:
            pdf_path: path to the PDF file

        Returns:
            list of LangChain Document objects

        Raises:
            RuntimeError: if the file fails to load
        """
        try:
            loader = PyMuPDFLoader(file_path=pdf_path)
            docs = loader.load()
            self.logger.debug(f"Loaded successfully: {pdf_path} ({len(docs)} pages)")
            return docs
        except Exception as e:
            self.logger.error(f"Failed to load PDF {pdf_path}: {str(e)}")
            raise RuntimeError(f"Unable to load PDF file: {pdf_path}")

    def split_text(self, documents: List[Document]) -> List[Document]:
        """
        Split documents into chunks with the recursive character splitter.

        Args:
            documents: list of Documents to split

        Returns:
            list of split Documents
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunksize,
            chunk_overlap=self.overlap,
            length_function=len,
            add_start_index=True,
            separators=["\n\n", "\n", "。", "!", "?", ";", " ", ""]  # Chinese-friendly separators
        )
        try:
            docs = text_splitter.split_documents(documents)
            self.logger.info(f"Text splitting complete: {len(documents)} chunks in → {len(docs)} chunks out")
            return docs
        except Exception as e:
            self.logger.error(f"Text splitting failed: {str(e)}")
            raise RuntimeError("Error occurred while splitting text")

    def insert_docs_chromadb(self, docs: List[Document], batch_size: int = 6) -> None:
        """
        Insert documents into ChromaDB in batches, with a progress bar and throughput monitoring.
        """
        if not docs:
            self.logger.warning("Attempted to insert an empty document list")
            return
        self.logger.info(f"Starting insertion of {len(docs)} documents into ChromaDB")
        start_time = time.time()
        total_docs_inserted = 0
        total_batches = (len(docs) + batch_size - 1) // batch_size
        try:
            with tqdm(total=total_batches, desc="Insertion progress", unit="batch") as pbar:
                for i in range(0, len(docs), batch_size):
                    batch = docs[i:i + batch_size]
                    self.chroma_db.add_documents(batch)
                    total_docs_inserted += len(batch)
                    # Compute throughput (documents per minute)
                    elapsed_time = time.time() - start_time
                    tpm = (total_docs_inserted / elapsed_time) * 60 if elapsed_time > 0 else 0
                    # Update the progress bar
                    pbar.set_postfix({
                        "TPM": f"{tpm:.2f}",
                        "docs": total_docs_inserted
                    })
                    pbar.update(1)
            self.logger.info(f"Document insertion complete! Total time: {time.time() - start_time:.2f}s")
        except Exception as e:
            self.logger.error(f"Document insertion failed: {str(e)}")
            raise RuntimeError(f"Document insertion failed: {str(e)}")

    def process_pdfs_group(self, pdf_files_group: List[str]) -> None:
        """
        Process one group of PDF files (load → split → store).

        Args:
            pdf_files_group: list of PDF file paths
        """
        try:
            # Stage 1: load all PDF content
            pdf_contents = []
            for pdf_path in pdf_files_group:
                documents = self.load_pdf_content(pdf_path)
                pdf_contents.extend(documents)
            # Stage 2: split the text
            if pdf_contents:
                docs = self.split_text(pdf_contents)
                # Stage 3: store in the vector database
                if docs:
                    self.insert_docs_chromadb(docs, self.batch_num)
        except Exception as e:
            self.logger.error(f"Failed to process PDF group: {str(e)}")
            # Optionally continue with the next group instead of aborting
            # raise

    def process_pdfs(self) -> None:
        """
        Main pipeline: scan the directory, then process all PDF files in groups.
        """
        self.logger.info("=== Starting PDF processing pipeline ===")
        start_time = time.time()
        try:
            pdf_files = self.load_pdf_files()
            # Process the PDF files in groups
            for i in range(0, len(pdf_files), self.file_group_num):
                group = pdf_files[i:i + self.file_group_num]
                self.logger.info(
                    f"Processing file group {i // self.file_group_num + 1}/{(len(pdf_files) - 1) // self.file_group_num + 1}")
                self.process_pdfs_group(group)
            self.logger.info(f"=== Done! Total time: {time.time() - start_time:.2f}s ===")
            print("PDF processing completed successfully!")
        except Exception as e:
            self.logger.error(f"PDF processing pipeline failed: {str(e)}")
            raise RuntimeError(f"PDF processing failed: {str(e)}")
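To run the ingestion end to end, a small driver script along these lines should work; the data/pdfs location is an assumption, and the embedding model comes from section 3:

# Hypothetical ingestion driver; adjust the paths to your project
from core.llm.qwen_client import get_qwen_models
from knowledge_base.builders.pdf_processor import PDFProcessor

_, _, embed = get_qwen_models()
processor = PDFProcessor(
    directory="data/pdfs",        # assumed location of the source PDFs
    persist_path="chroma_db",
    embedding_function=embed
)
processor.process_pdfs()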
Let's test this module.
Under tests, create unit/test_pdf_processor.py:
import os
import pytest
from knowledge_base.builders.pdf_processor import PDFProcessor


@pytest.fixture
def test_resources(tmp_path):
    """Prepare test resources."""
    # Create the PDF test directory
    pdf_dir = tmp_path / "pdfs"
    pdf_dir.mkdir()
    # Copy a pre-made PDF (or generate one dynamically)
    test_pdf = os.path.join(os.path.dirname(__file__), "test_files", "sample.pdf")
    target_pdf = pdf_dir / "test.pdf"
    with open(test_pdf, "rb") as src, open(target_pdf, "wb") as dst:
        dst.write(src.read())
    return {
        "pdf_dir": str(pdf_dir),
        "db_dir": str(tmp_path / "chroma_db"),
        "pdf_path": str(target_pdf)
    }


def test_pdf_processing(test_resources):
    processor = PDFProcessor(
        directory=test_resources["pdf_dir"],
        persist_path=test_resources["db_dir"]
    )
    processor.process_pdfs()
    # Verify the database was created
    assert os.path.exists(test_resources["db_dir"])
    assert any(os.listdir(test_resources["db_dir"]))
Output after running:
Testing started at 19:35 ...
Launching pytest with arguments D:\construction_QA_system\tests\unit\test_pdf_processor.py --no-header --no-summary -q in D:\construction_QA_system\tests\unit
============================= test session starts =============================
collecting ... collected 1 item
test_pdf_processor.py::test_pdf_processing PASSED [100%] PDF processing completed successfully!
============================== 1 passed in 2.78s ==============================
Process finished with exit code 0
7. Vector Retrieval Module
In the knowledge_base folder, create retrieval/vector_retriever.py:
from typing import Dict, List, Optional
from langchain_core.documents import Document
from ..storage.chroma_manager import ChromaManager


class VectorRetriever:
    def __init__(
        self,
        chroma_server_type: str = "local",
        persist_path: str = "chroma_db",
        collection_name: str = "construction_docs",
        embedding_function: Optional[object] = None,
        top_k: int = 5
    ):
        """Vector retriever.

        Args:
            top_k: number of most relevant results to return
        """
        # ChromaManager expects the embedding model via its embed_model parameter
        self.chroma_db = ChromaManager(
            chroma_server_type=chroma_server_type,
            persist_path=persist_path,
            collection_name=collection_name,
            embed_model=embedding_function
        )
        self.top_k = top_k

    def similarity_search(
        self,
        query: str,
        filter_conditions: Optional[Dict] = None
    ) -> List[Document]:
        """Similarity search, delegated to the ChromaManager wrapper."""
        # langchain_chroma's similarity_search accepts a metadata filter via
        # its `filter` keyword, which ChromaManager.query passes through
        return self.chroma_db.query(
            query,
            k=self.top_k,
            filter=filter_conditions
        )

    def hybrid_search(
        self,
        query: str,
        keyword: Optional[str] = None,
        filter_conditions: Optional[Dict] = None
    ) -> List[Document]:
        """Hybrid retrieval (vector + keyword)."""
        # Run the vector search first
        vector_results = self.similarity_search(query, filter_conditions)
        # If a keyword is given, filter the results by it
        if keyword:
            filtered = [
                doc for doc in vector_results
                if keyword.lower() in doc.page_content.lower()
            ]
            return filtered[:self.top_k]
        return vector_results

    def get_by_id(self, doc_id: str) -> Optional[Document]:
        """Fetch a document by its ID."""
        # The LangChain Chroma store proxies the underlying collection's get()
        result = self.chroma_db.store.get(ids=[doc_id])
        if not result['documents']:
            return None
        return Document(
            page_content=result['documents'][0],
            metadata=result['metadatas'][0] or {}
        )
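A quick retrieval sketch, assuming the store was populated by the ingestion step above. Note that PDFProcessor writes to ChromaManager's default collection ("langchain"), so the retriever must point at the same collection; the query strings are placeholders:

# Hypothetical retrieval example reusing the Qwen embedding model
from core.llm.qwen_client import get_qwen_models
from knowledge_base.retrieval.vector_retriever import VectorRetriever

_, _, embed = get_qwen_models()
retriever = VectorRetriever(
    persist_path="chroma_db",
    collection_name="langchain",   # must match the collection used at ingestion time
    embedding_function=embed,
    top_k=3
)
for doc in retriever.hybrid_search("concrete strength standard", keyword="rebar"):
    print(doc.metadata.get("source"), doc.page_content[:80])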
Wrapping the retriever in a FastAPI service
In the project root, create api/retrieval_api.py:
import logging
from typing import List, Optional
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Construction QA Retrieval API",
    description="Retrieval API for the construction-engineering knowledge base",
    version="1.0.0",
    openapi_tags=[{
        "name": "retrieval",
        "description": "Knowledge-base retrieval endpoints"
    }]
)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Data models ---
class DocumentMetadata(BaseModel):
    """Document metadata model."""
    source: Optional[str] = Field(None, example="GB/T 50081-2019")
    page: Optional[int] = Field(None, example=12)
    timestamp: Optional[datetime] = Field(None, example="2023-01-01T00:00:00")


class DocumentResponse(BaseModel):
    """Retrieval result model."""
    id: str = Field(..., example="doc_123")
    content: str = Field(..., example="Concrete strength testing standard...")
    metadata: DocumentMetadata
    score: float = Field(..., ge=0, le=1, example=0.85)


class QueryRequest(BaseModel):
    """Query request model."""
    query: str = Field(..., min_length=1, example="concrete strength standard")
    top_k: Optional[int] = Field(5, gt=0, le=20, example=3)
    keyword_filter: Optional[str] = Field(None, example="rebar")
    metadata_filter: Optional[dict] = Field(None, example={"source": "GB"})


class HealthCheckResponse(BaseModel):
    """Health check response."""
    status: str = Field(..., example="OK")
    version: str = Field(..., example="1.0.0")


# --- Core logic ---
def initialize_retriever():
    """Initialize the retriever (a real project should use dependency injection)."""
    from knowledge_base.retrieval.vector_retriever import VectorRetriever
    from core.utils.embedding_utils import load_embedding_model
    try:
        return VectorRetriever(
            persist_path="data/vector_db",
            embedding_function=load_embedding_model(),
            top_k=10
        )
    except Exception as e:
        logger.error(f"Retriever initialization failed: {str(e)}")
        raise


retriever = initialize_retriever()


# --- API endpoints ---
@app.get("/health", response_model=HealthCheckResponse, tags=["system"])
async def health_check():
    """Service health check."""
    return {
        "status": "OK",
        "version": "1.0.0"
    }


@app.post("/search",
          response_model=List[DocumentResponse],
          tags=["retrieval"],
          summary="Document retrieval",
          responses={
              200: {"description": "Results returned successfully"},
              400: {"description": "Invalid request parameters"},
              500: {"description": "Internal server error"}
          })
async def search_documents(request: QueryRequest):
    """
    Run document retrieval; the following modes are supported:
    - pure vector retrieval
    - keyword-filtered retrieval
    - metadata-filtered retrieval
    """
    try:
        logger.info(f"Received search request: {request.dict()}")
        # Parameter validation
        if len(request.query) > 500:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query text too long (max 500 characters)"
            )
        # Run the retrieval
        if request.keyword_filter or request.metadata_filter:
            docs = retriever.hybrid_search(
                query=request.query,
                keyword=request.keyword_filter,
                filter_conditions=request.metadata_filter
            )
        else:
            docs = retriever.similarity_search(
                query=request.query,
                filter_conditions=request.metadata_filter
            )
        # Format the results
        results = []
        for doc in docs[:request.top_k]:
            if not hasattr(doc, 'metadata'):
                doc.metadata = {}
            results.append({
                "id": str(hash(doc.page_content)),
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": doc.metadata.get("score", 0.0)
            })
        logger.info(f"Returning {len(results)} results")
        return results
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Retrieval failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Retrieval service temporarily unavailable"
        )


@app.get("/document/{doc_id}",
         response_model=DocumentResponse,
         tags=["retrieval"],
         summary="Fetch a document by ID")
async def get_document(doc_id: str):
    """Fetch a document's full content by its ID."""
    try:
        doc = retriever.get_by_id(doc_id)
        if not doc:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Document not found"
            )
        return {
            "id": doc_id,
            "content": doc.page_content,
            "metadata": doc.metadata or {},
            "score": 1.0
        }
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 404 above) instead of mapping them to 500
        raise
    except Exception as e:
        logger.error(f"Failed to fetch document: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Document retrieval failed"
        )


# --- Startup configuration ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_config={
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "level": "INFO",
                    "formatter": "default"
                }
            },
            "formatters": {
                "default": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
                }
            },
            "root": {
                "handlers": ["console"],
                "level": "INFO"
            }
        }
    )
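Once the service is running (python api/retrieval_api.py), a minimal client-side check might look like this; it assumes the third-party requests library, a service on localhost:8000, and a placeholder query:

# Hypothetical client call against the /search endpoint
import requests

resp = requests.post(
    "http://localhost:8000/search",
    json={"query": "concrete strength standard", "top_k": 3}
)
resp.raise_for_status()
for hit in resp.json():
    print(hit["score"], hit["content"][:60])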
Test case: tests/unit/test_embedding_utils.py
import pytest
from unittest.mock import patch
from core.utils.embedding_utils import load_embedding_model


class TestEmbeddingUtils:
    # Patch the names where embedding_utils looks them up, not in the
    # langchain package itself, so the mocks actually intercept the calls
    @patch('core.utils.embedding_utils.HuggingFaceEmbeddings')
    def test_load_huggingface(self, mock_embeddings):
        """Test loading the HuggingFace model."""
        model = load_embedding_model(model_type="huggingface")
        mock_embeddings.assert_called_once()

    @patch.dict('os.environ', {'OPENAI_API_KEY': 'test_key'})
    @patch('core.utils.embedding_utils.OpenAIEmbeddings')
    def test_load_openai(self, mock_embeddings):
        """Test loading the OpenAI model."""
        model = load_embedding_model(model_type="openai")
        mock_embeddings.assert_called_once_with(
            model="text-embedding-3-small",
            deployment=None,
            openai_api_key="test_key"
        )

    def test_invalid_model_type(self):
        """Test an invalid model type."""
        with pytest.raises(ValueError):
            load_embedding_model(model_type="invalid_type")
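The core/utils/embedding_utils.py module itself is not shown in this section. A minimal sketch consistent with what the tests above expect might look like the following; the exact signature and defaults are my assumptions:

# Hypothetical core/utils/embedding_utils.py, reconstructed from the tests above
import os
from langchain_community.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings


def load_embedding_model(model_type: str = "huggingface"):
    """Return an embedding model instance for the given model_type."""
    if model_type == "huggingface":
        return HuggingFaceEmbeddings()
    if model_type == "openai":
        return OpenAIEmbeddings(
            model="text-embedding-3-small",
            deployment=None,
            openai_api_key=os.environ.get("OPENAI_API_KEY")
        )
    raise ValueError(f"Unsupported model type: {model_type}")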