基于DeepSeek-R1 的RAG智能问答系统开发攻略
RAG为何成为技术热点?
RAG(Retrieval-Augmented Generation,检索增强生成) 结合了信息检索与生成模型的优势,能有效解决大模型“幻觉”问题,在智能客服、知识管理等领域应用广泛。本文将手把手教你搭建一个支持多格式文档上传、本地化部署的RAG系统,完整代码已开源,适合有一定Python基础的技术爱好者。
项目架构与核心技术
系统架构图
前端:Gradio交互界面 / FastAPI服务。
核心模块:文档加载 → 向量索引 → 检索增强生成。
底层支持:FAISS向量库、SentenceTransformer嵌入模型、本地LLM(如DeepSeek/ChatGLM)。
技术栈
向量检索:FAISS(Facebook开源的相似性搜索库)。
嵌入模型:BAAI/bge-m3(支持中英文的高效文本向量化)。
生成模型:DeepSeek-R1-1.5B(轻量级开源模型,适合本地部署)。
开发框架:Gradio(快速构建UI)、FastAPI(高性能API服务)。
环境搭建与模型准备
项目结构
Rag-System/
├─ model/ # 模型存放文件夹
│ ├─ BAAI/bge-m3/ # 本地Embedding模型的文件夹
│ └─ DeepSeek-R1-Distill-Qwen-1.5B/ # 本地LLM大语言模型的文件夹
│ └─ ChatGlm3-6B/ # 本地LLM大语言模型的文件夹
├─ knowledge_base/
│ ├─ some_text.txt # 你本地知识库中的各种文件(txt/pdf/docx等)
│ ├─ identity.md # 自我认知文件
│ └─ ...
├─ cache/
│ └─ faiss_index/ # FAISS 索引缓存目录(会自动生成)
├─ icon/
│ ├─ bot.png # agent 头像
│ └─ user.png # 用户 头像
├─ config.py # 配置文件
├─ loader.py # 索引构建
├─ main.py # 多线程加载文档
├─ rag.py # 文档 检索、回答生成
├─ app.py # Gradio交互界面
├─ api.py # FastAPI服务端,提供REST接口
└── ... 其他文件 ...
在根目录创建 cache、icon、knowledge_base、model文件夹。
cache:用于存放 FAISS 索引缓存目录。
icon:用于存放 agent 和用户头像。
knowledge_base:用于存放知识库文件(里面要有一个名为 identity.md 的自我认知文件)。
model:用于存放模型文件。
Anaconda 环境搭建以及
1. 打开 Anaconda 重新创建一个环境。
2. 打开执行终端。
3. 导航到根目录。
4. 在根目录创建一个名为 requirements.txt 的文件,这个是依赖安装文件。
requirements.txt 内容。
langchain
openai
faiss-cpu
transformers
sentence-transformers
gradio
fastapi
uvicorn
python-multipart
pypdf
python-docx
pandas
openpyxl
python-pptx
fastapi
uvicorn
sentencepiece
5. 依赖安装
pip install -r requirements.txt
没有报错就证明没问题,有报错就看一下是哪个,单独下载就行。
6. CUDA 安装
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
模型准备
需要下载三个模型分别是:bge-m3、ChatGlm3-6B、DeepSeek-R1-Distill-Qwen-1.5B。
在model 文件夹下建立这样的三个文件夹,以便于存放模型文件。
BAAI/bge-m3 模型下载
BAAI/bge-m3:支持中英文的高效文本向量化
1. 在命令行键入:python 进入python 模式
2. 直接拷贝代码到命令行,等待模型下载。
from sentence_transformers import SentenceTransformer
model_name = "BAAI/bge-m3"
embedding_model = SentenceTransformer(model_name)
# 保存到本地
embedding_model.save("model/BAAI/bge-m3")
3. 这就是下载好的 bge-m3 模型。
4. Ctrl + Z 然后回车 返回 Anaconda 终端。
DeepSeek-R1-1.5B & ChatGlm3-6B 模型下载
下载地址: DeepSeek-R1-Distill-Qwen-1.5B
下载地址: ChatGlm3-6B
1. 模型下载。
2. DeepSeek-R1-Distill-Qwen-1.5B 模型 存放地址。
3. ChatGlm3-6B 存放地址。
核心代码解析
上传文档与索引重建
代码逻辑(app.py → upload_files函数):
用户上传文件后,文件会被移动到 knowledge_base 目录。
调用 rebuild_index()(main.py)重新构建FAISS索引。
问答流程
检索阶段(rag.py → retrieve_top_k_documents):
用户问题 → 向量化 → FAISS检索最相关的3个文档。
生成阶段(rag.py → generate_answer):
拼接检索到的文档和身份设定(identity.md)→ 生成回答(流式输出)。
模型切换
代码逻辑(app.py → switch_model函数):
从 config.MODEL_PATHS 加载新模型,更新全局变量 tokenizer 和 model。
文档加载与索引构建(loader.py & main.py)
# loader.py:多线程加载文档(关键函数)
def load_documents(directory):
with ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(load_file, filepath, filename)
for filename in os.listdir(directory)]
# 处理结果...
# main.py:构建FAISS索引
def rebuild_index():
embeddings = embedding_model.encode(texts) # 文本向量化
index = faiss.IndexFlatL2(dim) # 创建L2距离索引
index.add(embeddings) # 添加向量
faiss.write_index(index, "cache/faiss_index/docs.index") # 保存索引
检索增强生成(rag.py)
def generate_answer(query):
# 1. 检索Top-K文档
related_docs = retrieve_top_k_documents(query, k=3)
# 2. 拼接Prompt(身份设定+检索内容)
prompt = f"""
[系统角色设定]:{identity_content}
[检索结果]:{related_docs}
问题:{query}
回答:
"""
# 3. 流式生成回答
for token in model.generate(**inputs, stream=True):
yield token # 逐词输出,提升用户体验
Gradio交互界面(app.py)
# 文件上传与索引更新
def upload_files(files, chatbot):
shutil.move(file.name, "knowledge_base/new_file.pdf") # 移动文件
rebuild_index() # 触发索引重建
chatbot.append(("📁 上传完成", "索引已更新!"))
# 流式对话演示
with gr.Blocks() as interface:
chatbot = gr.Chatbot(height=600)
msg_input = gr.Textbox(label="请输入问题")
submit_btn.click(chat_with_rag, [msg_input, chatbot], [chatbot])
相关完整代码
config.py
配置文件,定义模型路径、索引目录、知识库位置等参数。
# config.py
from pathlib import Path
import os
# 基础目录,默认为当前文件所在目录
BASE_DIR = Path(__file__).parent
class Config:
# 嵌入模型路径:你可以替换成更适合中文的模型,例如 GanymedeNil/text2vec-base-chinese
EMBEDDING_MODEL = str(BASE_DIR / "model" / "BAAI" / "bge-m3")
# 默认模型相关参数
DEFAULT_MAX_LENGTH = 4096
CHUNK_SIZE = 1000
OVERLAP = 200
# FAISS 索引缓存目录
FAISS_CACHE = BASE_DIR / "cache" / "faiss_index"
# 可用的 LLM 模型路径,存放在 MODEL_PATHS 字典中
MODEL_PATHS = {
'DeepSeek-R1-1.5B': os.path.abspath(BASE_DIR / "model" / "DeepSeek-R1-Distill-Qwen-1.5B"),
'ChatGLM3-6B': os.path.abspath(BASE_DIR / "model" / "ChatGlm3-6B")
}
# 默认使用的 LLM 模型(可以在此更改为其他键)
DEFAULT_LLM_MODEL = 'DeepSeek-R1-1.5B'
# 从 MODEL_PATHS 中取出默认模型路径,并转换为字符串
LLM_MODEL_PATH = str(MODEL_PATHS[DEFAULT_LLM_MODEL])
# 参考文档文件夹,用于存放系统的身份设定、常见文档等
REFERENCE_FOLDER = BASE_DIR / "knowledge_base"
# 确保目录存在
REFERENCE_FOLDER.mkdir(parents=True, exist_ok=True)
# 系统身份设定文件(例如 identity.md),放在参考文档文件夹中
IDENTITY_FILE = REFERENCE_FOLDER / "identity.md"
# 其他参数设置
MAX_HISTORY = 5
STREAM_SEGMENT_SIZE = 5
STREAM_DELAY = 0.1
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "Maddie")
def __init__(self):
# 确保缓存目录和参考文档文件夹存在
self.FAISS_CACHE.mkdir(parents=True, exist_ok=True)
self.REFERENCE_FOLDER.mkdir(parents=True, exist_ok=True)
# 实例化配置对象,后续代码中直接导入 cfg 使用
config = Config()
loader.py
多线程加载文档(PDF/DOCX/TXT等),返回文本内容。
# loader.py
import os
import logging
from typing import List
import pandas as pd
from pypdf import PdfReader
from docx import Document
from pptx import Presentation
from concurrent.futures import ThreadPoolExecutor, as_completed
# ========================
# 模块初始化配置
# 配置日志输出
# ========================
# 将第三方库(httpcore, urllib3)的日志等级设为WARNING,避免输出过多调试信息
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 基本日志配置,设置日志记录等级和输出格式
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
"""
加载并读取TXT文本文件内容,返回文件中的全部文本。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_txt(filepath):
try:
with open(filepath, "r", encoding="utf-8") as file:
return file.read()
except Exception as e:
logging.error(f"读取TXT文件 {filepath} 失败: {e}")
return ""
"""
使用PdfReader读取PDF文件中的文本内容。
将PDF的每个页面提取出的文本进行拼接并返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_pdf(filepath):
try:
reader = PdfReader(filepath)
text = ""
for page in reader.pages:
page_text = page.extract_text()
if page_text:
text += page_text
return text
except Exception as e:
logging.error(f"读取PDF文件 {filepath} 失败: {e}")
return ""
"""
使用docx.Document读取DOCX文件中的段落文本。
将所有段落的文本用换行符拼接起来并返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_docx(filepath):
try:
doc = Document(filepath)
return "\n".join([para.text for para in doc.paragraphs])
except Exception as e:
logging.error(f"读取DOCX文件 {filepath} 失败: {e}")
return ""
"""
使用pandas读取Excel文件,并将DataFrame转换为制表符分隔的字符串返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_excel(filepath):
try:
df = pd.read_excel(filepath, engine='openpyxl')
return df_to_text(df)
except Exception as e:
logging.error(f"读取Excel文件 {filepath} 失败: {e}")
return ""
"""
使用pptx.Presentation读取PPTX文件。
遍历每一页(slide)和每个文本框(shape)的段落,将其文本内容按换行符拼接并返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_pptx(filepath):
try:
prs = Presentation(filepath)
text = ""
for slide in prs.slides:
for shape in slide.shapes:
if shape.has_text_frame:
for paragraph in shape.text_frame.paragraphs:
text += paragraph.text + "\n"
return text
except Exception as e:
logging.error(f"读取PPTX文件 {filepath} 失败: {e}")
return ""
"""
将DataFrame转换为以制表符(\t)分隔的CSV格式字符串,并返回。
不包含索引(index)列。
"""
def df_to_text(df):
return df.to_csv(index=False, sep='\t')
# 扩展名与对应的加载函数映射字典
# 根据文件后缀名选择合适的读取函数
LOADER_FUNCTIONS = {
".txt": load_txt,
".md": load_txt, # Markdown 文件按文本文件处理,复用load_txt
".pdf": load_pdf,
".docx": load_docx,
".xlsx": load_excel,
".xls": load_excel,
".pptx": load_pptx,
}
"""
根据文件名的扩展名(ext)从LOADER_FUNCTIONS中查找对应的加载函数进行读取。
如果不支持的文件类型,则记录警告并返回空字符串。
"""
def load_file(filepath, filename):
ext = os.path.splitext(filename)[1].lower() # 提取文件后缀并转换为小写
loader = LOADER_FUNCTIONS.get(ext)
if loader:
return loader(filepath)
else:
logging.warning(f"不支持的文件类型: {filename}")
return ""
"""
从指定目录(directory)中批量加载所有文件的内容,并返回一个包含字典的列表。
每个字典结构为:{"filename": 文件名, "content": 文件内容}
其中使用线程池(ThreadPoolExecutor)并行地读取文件以提高效率。
"""
def load_documents(directory):
docs = []
files = os.listdir(directory)
logging.info(f"在目录 {directory} 找到 {len(files)} 个文件")
# 使用线程池并发加载文件,max_workers=4 表示最多4个并发线程
with ThreadPoolExecutor(max_workers=4) as executor:
future_to_filename = {}
# 将每个文件的加载任务提交给线程池
for filename in files:
filepath = os.path.join(directory, filename)
if os.path.isfile(filepath): # 只处理文件,不处理子目录
future = executor.submit(load_file, filepath, filename)
future_to_filename[future] = filename
# 收集加载结果
for future in as_completed(future_to_filename):
filename = future_to_filename[future]
try:
content = future.result() # 获取文件读取结果
if content:
# 只有当读取到非空内容时,才追加到docs列表
docs.append({"filename": filename, "content": content})
else:
logging.info(f"文件 {filename} 没有加载到内容")
except Exception as e:
logging.error(f"加载文件 {filename} 时出错: {e}")
return docs
main.py
重建FAISS索引,将文档内容向量化存储。
# main.py
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import os
import pickle
from loader import load_documents # 使用改进后的 loader 模块
from config import config
import logging
# ========================
# 模块初始化配置
# 配置日志输出
# ========================
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 加载预训练的嵌入模型(Embedding Model)
embedding_model_path = os.path.abspath(config.EMBEDDING_MODEL)
embedding_model = SentenceTransformer(embedding_model_path)
def rebuild_index():
"""重新加载所有文档,并重建 FAISS 索引"""
print("🔄 开始重建 FAISS 索引...")
# 使用配置中的知识库目录
knowledge_dir = config.REFERENCE_FOLDER # 这里改为 `REFERENCE_FOLDER`
# 通过 loader 并发加载所有文档
docs = list(load_documents(knowledge_dir))
if not docs:
print("⚠️ 没有找到文档,索引未更新。")
return "⚠️ 没有找到文档,索引未更新。"
print(f"📂 加载了 {len(docs)} 个文档") # 打印文档数量
# 提取每篇文档的内容和文件名
texts = [doc["content"] for doc in docs]
filenames = [doc["filename"] for doc in docs]
# 将文档内容转化为向量表示
embeddings = np.array(embedding_model.encode(texts)) # 确保是 NumPy 数组
# 使用 FAISS 建立索引
dim = embeddings.shape[1] # 向量维度
index = faiss.IndexFlatL2(dim)
index.add(embeddings)
# 保存 FAISS 索引和文档文件名列表
faiss.write_index(index, str(config.FAISS_CACHE / "docs.index"))
with open(str(config.FAISS_CACHE / "filenames.pkl"), "wb") as f:
pickle.dump(filenames, f)
print("✅ FAISS 索引已成功重建!")
return "✅ FAISS 索引已成功重建!"
# 如果 `main.py` 直接运行,则自动创建索引
if __name__ == "__main__":
rebuild_index()
rag.py
核心逻辑:检索文档、生成回答、总结/改写文本。
# rag.py
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # 禁用 Hugging Face 的并行警告
os.environ["BITSANDBYTES_NOWELCOME"] = "1" # 禁止 bitsandbytes 欢迎信息
os.environ["WANDB_DISABLED"] = "true" # 禁用 wandb 日志(如果有)
import faiss
import numpy as np
import pickle
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import pandas as pd
from config import config
import logging
from docx import Document # 以防后面需要 docx 解析
##############################################################################
# 1. 加载模型与索引
##############################################################################
# 加载本地 Embedding 模型(SentenceTransformer),用于 FAISS 检索
embedding_model_path = os.path.abspath(config.EMBEDDING_MODEL)
embedding_model = SentenceTransformer(embedding_model_path)
# 加载 FAISS 索引
index = faiss.read_index(str(config.FAISS_CACHE / "docs.index"))
print(f"✅ FAISS 索引维度:{index.d}")
# 加载文件名列表
with open(str(config.FAISS_CACHE / "filenames.pkl"), "rb") as f:
filenames = pickle.load(f)
# 加载本地 LLM,并使用 GPU 加速
llm_model_path = os.path.abspath(config.LLM_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(llm_model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
llm_model_path,
torch_dtype=torch.float16,
device_map="auto"
)
# 如果想手动指定 device,也可用 model.to(config.DEVICE),这里 device_map="auto" 通常就能工作
print("✅ 本地 LLM 模型加载完成!")
##############################################################################
# 2. 文档加载函数
##############################################################################
def load_document_content(filename):
"""
根据文件名加载文本内容。identity.md 作为 AI 角色设定。
"""
# 使用配置中的知识库目录
file_path = os.path.join(config.REFERENCE_FOLDER, filename)
# 特殊处理 identity.md
if filename == "identity.md":
with open(file_path, "r", encoding="utf-8") as file:
return f"[系统角色提示]\n{file.read()}\n\n"
if filename.endswith(".txt") or filename.endswith(".md"):
with open(file_path, "r", encoding="utf-8") as file:
return file.read()
elif filename.endswith(".pdf"):
from pypdf import PdfReader
reader = PdfReader(file_path)
return "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
elif filename.endswith(".docx"):
# 利用 python-docx 读取
doc = Document(file_path)
return "\n".join([para.text for para in doc.paragraphs])
elif filename.endswith(".xlsx") or filename.endswith(".csv"):
try:
df = pd.read_excel(file_path) if filename.endswith(".xlsx") else pd.read_csv(file_path)
return df.to_string(index=False)
except Exception as e:
return f"❌ 读取 {filename} 时出错: {str(e)}"
elif filename.endswith(".pptx"):
try:
from pptx import Presentation
prs = Presentation(file_path)
text = []
for slide in prs.slides:
for shape in slide.shapes:
if shape.has_text_frame:
for paragraph in shape.text_frame.paragraphs:
text.append(paragraph.text)
return "\n".join(text)
except Exception as e:
return f"❌ 读取 {filename} 时出错: {str(e)}"
return "❌ 无法读取该文档格式:" + filename
##############################################################################
# 3. 检索函数
##############################################################################
def retrieve_top_k_documents(query, k=3):
"""
根据查询语句在索引中找到最相关的 k 个文档,并返回其内容(截取)。
"""
query_embedding = np.array(embedding_model.encode([query]))
print(f"✅ 查询向量维度:{query_embedding.shape[1]}")
_, idxs = index.search(query_embedding, k)
retrieved_docs = []
for i in idxs[0]:
filename = filenames[i]
content = load_document_content(filename)
# 这里可以只截取前1000字符,避免 Prompt 过长
retrieved_docs.append(f"📄【{filename}】\n{content[:100]}...")
return retrieved_docs
##############################################################################
# 4. 分块 (Chunk) + 摘要 / 改写
##############################################################################
def chunk_text(text, chunk_size=1000, overlap=200):
"""
将长文本分块,每块 chunk_size 个字符,并在块之间保留 overlap 个字符重叠,避免关键信息被切断。
"""
chunks = []
start = 0
text_len = len(text)
while start < text_len:
end = min(start + chunk_size, text_len) # 避免越界
chunk = text[start:end]
chunks.append(chunk)
start = end - overlap
if start < 0:
start = 0
if start >= text_len:
break
return chunks
def summarize_long_text(text):
"""
对长文本进行多段式摘要,然后合并。
"""
# 分块
text_chunks = chunk_text(text, chunk_size=1500, overlap=200)
chunk_summaries = []
# 逐块摘要
for idx, chunk in enumerate(text_chunks):
prompt = f"请阅读以下文本内容,并进行简要总结:\n{chunk}\n\n总结:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
summary_chunk = tokenizer.decode(output[0], skip_special_tokens=True)
chunk_summaries.append(summary_chunk.strip())
# 合并所有块的摘要,再让模型做一次“总总结”
combined_summary = "\n".join(chunk_summaries)
final_prompt = f"以下是多个分块的总结,请将其合并为一个简洁的整体总结:\n{combined_summary}\n\n整体总结:"
final_inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
final_output = model.generate(
**final_inputs,
max_new_tokens=500,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
final_summary = tokenizer.decode(final_output[0], skip_special_tokens=True)
return final_summary.strip()
def rewrite_long_text(text):
"""
对长文本进行“改写 / 润色”,与摘要类似的思路,先分块再合并。
"""
# 先粗暴示例,也可以根据需要拆分做多次改写
text_chunks = chunk_text(text, chunk_size=1500, overlap=200)
rewrite_results = []
for chunk in text_chunks:
prompt = f"请对以下文本进行语言润色或改写,使其更通顺、简洁:\n{chunk}\n\n改写后:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
rewrite_chunk = tokenizer.decode(output[0], skip_special_tokens=True)
rewrite_results.append(rewrite_chunk.strip())
# 简单拼接,如果想再做最终合并,可以再来一次生成
return "\n".join(rewrite_results).strip()
##############################################################################
# 5. 最终回答:generate_answer
# - 如果用户问 "总结xx文件" 或 "改写xx文件":
# -> 直接找到文件内容做 summarize/rewrite
# - 否则做普通RAG问答
##############################################################################
def generate_answer(query):
# 1) 检索文档
related_docs = retrieve_top_k_documents(query, k=3)
# 2) 加载 identity.md(如果存在)
identity_content = ""
if "identity.md" in filenames:
identity_content = load_document_content("identity.md")
# 3) 拼接上下文
context = f"AI Identity/Persona\n{identity_content}\n\n【知识库检索结果】\n" + "\n\n".join(related_docs)
# 构造 Prompt
prompt = f"""
你是一名智能问答助手(扮演DeepSeek知识管家 Theodore 西奥-多尔),以下是你的身份描述和检索到的文档内容:
{context}
请遵守identity.md中的所有“沟通风格准则”和“特殊行为准则”,并根据文档做出回答。
如果在知识库中找不到答案,请回答“对不起,我在知识库中没有找到相关信息”。
请使用简洁且专业的口吻。
用户的问题:{query}
请直接给出简洁且专业的回答:
""".strip()
# 流式生成回答
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# 使用模型的流式生成参数(需transformers >=4.21.0)
output_stream = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
do_sample=True
)
# 初始化输出
generated_text = ""
# 流式逐token输出
for token_id in output_stream[0]:
token_text = tokenizer.decode(token_id, skip_special_tokens=True)
generated_text += token_text
yield generated_text # 实时返回当前生成内容
# 完整的调试信息(生成完成后)
debug_info = (
"\n\n### 检索与推理过程\n\n"
f"**用户问题**: {query}\n\n"
f"**Prompt 内容**: \n{prompt}\n\n"
"——以上信息仅供调试或进阶查看——\n"
)
# 最终输出带有调试信息(加上</think>标记)
#print(f"{generated_text}</think>{debug_info}")
yield f"{generated_text}</think>{debug_info}"
##############################################################################
# 6. 简单函数:_simple_summarize / _simple_rewrite (对短文本)
##############################################################################
def _simple_summarize(text):
"""
对短文本做一次性摘要。如果文本不大,可直接用。
"""
prompt = f"请阅读以下文本并进行简要总结:\n{text}\n\n总结:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
return tokenizer.decode(output[0], skip_special_tokens=True).strip()
def _simple_rewrite(text):
"""
对短文本做一次性改写。
"""
prompt = f"请对以下文本进行语言润色或改写,使其更通顺、简洁:\n{text}\n\n改写后:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
return tokenizer.decode(output[0], skip_special_tokens=True).strip()
##############################################################################
# 7. 命令行交互入口
##############################################################################
import re
if __name__ == "__main__":
print("📚 RAG + Summarize/Rewrite 示例程序启动...")
try:
while True:
query = input("\n请输入您的问题(输入 'exit' 退出): ")
if query.lower().strip() == "exit":
print("\n👋 退出 RAG 系统,再见!")
break
# 初始化变量
bot_response = ""
debug_info = "暂无推理过程"
current_output_buffer = ""
# 流式处理生成器输出
stream = generate_answer(query)
for current_output in stream:
current_output_buffer += current_output
if "</think>" in current_output_buffer:
parts = current_output_buffer.split("</think>")
if len(parts) >= 2:
bot_response = parts[1].strip() # 模型的回答
if len(parts) > 2:
debug_info = parts[2].strip() # 调试信息
else:
debug_info = "暂无推理过程"
# 清空缓冲区
current_output_buffer = ""
# 打印流式输出的回答
print("\nAI 回答(流式输出):", bot_response, "\n")
else:
bot_response = parts[0].strip()
current_output_buffer = ""
else:
# 如果没有找到 </think>,继续积累输出
continue
# 如果缓冲区中还有内容,处理剩余部分
if current_output_buffer:
bot_response = current_output_buffer.strip()
print("\nAI 回答(流式输出):", bot_response, "\n")
except KeyboardInterrupt:
print("\n👋 检测到 Ctrl + C,退出 RAG 系统!")
""" 模型回答 一次性输出
if __name__ == "__main__":
print("📚 RAG + Summarize/Rewrite 示例程序启动...")
try:
while True:
query = input("\n请输入您的问题(输入 'exit' 退出): ")
if query.lower().strip() == "exit":
print("\n👋 退出 RAG 系统,再见!")
break
# 初始化变量
bot_response = ""
debug_info = "暂无推理过程"
# 流式处理生成器输出
stream = generate_answer(query)
for current_output in stream:
if "</think>" in current_output:
parts = current_output.split("</think>")
if len(parts) >= 2:
bot_response = parts[1].strip() # 模型的回答
if len(parts) > 2:
debug_info = parts[2].strip() # 调试信息
else:
bot_response = parts[0].strip()
else:
bot_response = current_output.strip()
# 输出最终回答
print("\nAI 回答:", bot_response, "\n")
except KeyboardInterrupt:
print("\n👋 检测到 Ctrl + C,退出 RAG 系统!")
"""
app.py
Gradio交互界面,支持上传文档、提问、切换模型、监控系统资源。
# app.py
import gradio as gr
import os
import torch
import psutil # 用于系统监控
from rag import generate_answer
from config import config
from loader import load_documents
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import json
from main import rebuild_index # ✅ 直接导入 `rebuild_index`
import shutil
import time
# 头像路径
USER_ICON_PATH = "icon/user.png"
BOT_ICON_PATH = "icon/bot.png"
# 1️ 加载嵌入模型(SentenceTransformer)
embedding_model = SentenceTransformer(config.EMBEDDING_MODEL)
# 2️ 加载 LLM 模型
def load_llm_model(model_name):
"""动态加载 LLM"""
model_path = str(config.MODEL_PATHS.get(model_name, config.LLM_MODEL_PATH))
# ✅ 确保 `trust_remote_code=True`
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
return tokenizer, model
# 预加载默认 LLM
current_model_name = config.DEFAULT_LLM_MODEL
tokenizer, model = load_llm_model(current_model_name)
# 3️ 处理文件上传
def upload_files(files, chatbot):
"""上传文件,更新 Chatbot 并确保返回符合 Gradio Chatbot 格式"""
if not isinstance(files, list):
files = [files]
saved_files = []
failed_files = []
for file in files:
try:
original_filename = file.orig_name if hasattr(file, "orig_name") else os.path.basename(file.name)
dest_path = os.path.join(config.REFERENCE_FOLDER, original_filename)
# ✅ 将临时路径文件移动到 `knowledge_base` 目录
shutil.move(file.name, dest_path)
saved_files.append(original_filename)
except Exception as e:
failed_files.append(original_filename)
print(f"❌ 上传失败: {original_filename}, 错误: {e}")
# ✅ 只有至少有一个文件上传成功,才重建索引
if saved_files:
print("🔄 至少一个文件上传成功,开始重建索引...")
index_message = rebuild_index()
else:
index_message = "⚠️ 所有文件上传失败,索引未更新。"
# ✅ 构建返回消息
message = f"📂 上传成功 {len(saved_files)} 个文件: {', '.join(saved_files)}"
if failed_files:
message += f"\n❌ 上传失败 {len(failed_files)} 个文件: {', '.join(failed_files)}"
message += f"\n{index_message}"
print(message)
# ✅ 让 Chatbot 显示消息(格式必须是 `[(用户输入, 机器人回复)]`)
chatbot.append(("📁 文件上传", message))
return chatbot
# 4️ 处理对话
def chat_with_rag(question, chatbot, max_tokens, temperature, top_p, show_debug, topk_retrieval, dist_threshold):
chatbot.append((question, ""))
start_time = time.time()
stream = generate_answer(question)
bot_response = ""
for current_output in stream:
if "</think>" in current_output:
parts = current_output.split("</think>")
if len(parts) >= 2:
bot_response = parts[1].strip() # 模型的回答
debug_info = parts[2].strip() if len(parts) > 2 else "暂无推理过程"
else:
bot_response = parts[0].strip()
debug_info = "暂无推理过程"
else:
bot_response = current_output.strip()
chatbot[-1] = (question, bot_response)
yield "", chatbot
# 推理时间计算
elapsed_time = time.time() - start_time
elapsed_str = f"🔍 点击查看推理过程,耗时 {elapsed_time:.2f} 秒 ⌄"
# ✅ 仅对推理过程的字体进行淡化(去除对 bot_response 的额外颜色修改)
if show_debug and debug_info:
bot_response = (
f"<details>"
f"<summary style='color:#888;font-size:12px;'>{elapsed_str}</summary>"
f"<div style='color:#ccc;background:#f5f5f5;padding:10px;border-radius:5px;'>\n\n"
f"{debug_info}\n"
f"</div></details>\n\n"
f"{bot_response}" # ✅ 去掉了 bot_response 外部的额外颜色设置
)
chatbot[-1] = (question, bot_response)
yield "", chatbot
# 5️ 切换 LLM 模型
def switch_model(new_model):
"""切换 LLM 模型"""
global tokenizer, model, current_model_name
tokenizer, model = load_llm_model(new_model)
current_model_name = new_model
return f"✅ 已切换到 {new_model} 模型"
# 6️ 系统监控
def system_diagnosis():
"""返回 CPU、RAM、GPU 资源使用情况"""
cpu_usage = psutil.cpu_percent()
ram_usage = psutil.virtual_memory().percent
gpu_usage = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
return {"CPU 使用率": f"{cpu_usage}%", "内存使用率": f"{ram_usage}%", "GPU 占用": f"{gpu_usage:.2f}GB"}
# 7️ 导出 & 导入对话历史
def export_chat_history(chatbot):
"""导出对话历史为 JSON 文件"""
history = [{"用户": msg[0], "AI": msg[1]} for msg in chatbot]
file_path = "chat_history.json"
with open(file_path, "w", encoding="utf-8") as f:
json.dump(history, f, ensure_ascii=False, indent=2)
return file_path
def import_chat_history(file):
"""导入对话历史"""
if file is None:
return []
with open(file.name, "r", encoding="utf-8") as f:
history = json.load(f)
return [(msg["用户"], msg["AI"]) for msg in history]
# ========================
# 🚀 Gradio 界面构建
# ========================
def create_gradio_interface():
"""构建交互界面"""
theme = gr.themes.Default(
primary_hue="orange",
secondary_hue="blue"
)
with gr.Blocks(theme=theme, title="DeepSeek RAG System 2.0") as interface:
gr.Markdown("# 🔍 DeepSeek RAG 知识管理系统 (改进版)")
# 对话区
chatbot = gr.Chatbot(
value=[(None, "您好!我是 Theodore(西奥-多尔),您的智能助手 🚀")],
height=680,
avatar_images=(USER_ICON_PATH, BOT_ICON_PATH)
)
msg_input = gr.Textbox(placeholder="输入您的问题...", lines=3)
with gr.Row():
submit_btn = gr.Button("💬 发送", variant="primary")
os.environ["GRADIO_MAX_FILE_SIZE"] = "100mb"
upload_btn = gr.UploadButton(
"📁 上传文档",
file_types=[".pdf", ".docx", ".txt", ".md", ".pptx", ".xlsx"],
file_count="multiple"
)
clear_btn = gr.Button("🔄 清空对话")
# 💻 系统监控
with gr.Accordion("💻 系统监控", open=False):
gr.Markdown("### 实时系统指标")
diagnose_btn = gr.Button("🔄 刷新状态")
status_panel = gr.JSON(label="系统状态", value={"状态": "正在获取..."})
interface.load(system_diagnosis, inputs=None, outputs=status_panel)
# 🤖 模型管理
with gr.Accordion("🤖 模型管理", open=False):
model_selector = gr.Dropdown(
label="选择模型",
choices=list(config.MODEL_PATHS.keys()),
value=current_model_name
)
model_status = gr.Textbox(label="模型状态", interactive=False, value="正在初始化模型...")
interface.load(lambda: switch_model(current_model_name), inputs=None, outputs=model_status)
# 💬 对话历史
with gr.Accordion("💬 对话历史", open=False):
export_btn = gr.Button("导出历史")
import_btn = gr.UploadButton("导入历史", file_types=[".json"])
export_btn.click(export_chat_history, inputs=chatbot, outputs=gr.File())
import_btn.upload(import_chat_history, inputs=import_btn, outputs=chatbot)
# 📊 生成参数
with gr.Accordion("📊 生成参数", open=False):
max_tokens = gr.Slider(128, 4096, value=512, label="生成长度限制")
temperature = gr.Slider(0.1, 1.0, value=0.7, label="创造性")
top_p = gr.Slider(0.1, 1.0, value=0.9, label="核心采样")
topk_retrieval = gr.Slider(1, 10, value=3, step=1, label="检索文档数量 top_k")
dist_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="检索距离阈值")
show_debug = gr.Checkbox(label="显示推理过程", value=True)
# 绑定事件
msg_input.submit(
chat_with_rag,
inputs=[msg_input, chatbot, max_tokens, temperature, top_p, show_debug, topk_retrieval, dist_threshold],
outputs=[msg_input, chatbot]
)
submit_btn.click(
chat_with_rag,
inputs=[msg_input, chatbot, max_tokens, temperature, top_p, show_debug, topk_retrieval, dist_threshold],
outputs=[msg_input, chatbot]
)
clear_btn.click(lambda: [(None, "对话已清空")], outputs=chatbot)
upload_btn.upload(upload_files, inputs=[upload_btn, chatbot], outputs=[chatbot])
diagnose_btn.click(system_diagnosis, outputs=status_panel)
model_selector.change(switch_model, inputs=model_selector, outputs=model_status)
return interface
# 启动 Gradio
if __name__ == "__main__":
interface = create_gradio_interface()
interface.launch(server_name="127.0.0.1", server_port=7860,share=True)
快速启动以及功能说明
文档向量生成
文档向量生成:
将示例文档放入 knowledge_base 文件夹,运行以下命令构建索引:
python main.py
直接执行就会把索引构建。
会在 cahe\faiss_index 文件夹生成这两个文件。
docs.index:FAISS 索引。
filenames.pkl:文档文件名列表。
API 服务访问
启动API服务:
python api.py
访问 http://localhost:8000/ask 测试API接口。
这样显示就证明 API 正常启动。
API 测试
便捷测试工具: Apifox 网址
1. 在Apifox 新建一个项目。
2. 创建一个项目名称。
3. 新建快捷请求。
4. 请求类型更改为Post。
5. 填入请求地址以及更改 Headers 类型和参数。
6. 更改 Body 并把raw 值添加:{"question": "你是谁"}。
7. 保存当前快捷请求。
8. 保存当前快捷请求名称以及 存储根目录。
9. 点击发送测试当前接口是否正常返回,下面是正常状态。
10. 终端返回 200 OK 证明当前API 可正常访问。
11. Ctrl+C 退出 API。
命令行 交互访问
启动 命令行交互界面
python rag.py
如果出现这个错误就执行:conda install faiss
下载完毕会这样显示。
再次执行 python rag.py
直接输出问题就行,就可以正常问答了。
等待一会就会输出答案了,目前是流式输出,想要一次性输出的话可以去 rag.py “命令行交互入口”替换。
当然退出也是 Ctrl+C 或直接键入exit 回车退出。
Gradio 服务访问
启动Gradio界面
python app.py
访问 http://localhost:7860 使用交互界面。
这样显示就证明可以正常使用,如果没有自动跳出浏览器
就在浏览器地址栏输出 http://localhost:7860 进行访问。
模型对话功能
默认显示界面。
在Textbox 输入问题点击发送就行模型问答。
文件上传功能
点击上传文档功能,选择想要上传的文件,目前支持.txt、.md、.pdf、.docx、.pptx、.xlsx文件。
上传成功会在终端显示索引构建完成。
当然 Gradio 也会有相应的显示。
清空对话功能
点击清空对话按钮会清空当前对话记录。
系统监控功能
打开系统监控会输出:"CPU 使用率","内存使用率","GPU 占用" 信息,可刷新状态。
模型切换功能
打开会展示当前对话模型。
可切换当前模型列表中的模型 模型列表在:config.py MODEL_PATHS下添加。
对话历史功能
可选择导出历史和导入历史:
导出历史:导出当前对话历史。
导入历史:将以往的对话记录进行导入,进行持续对话。
这是导出的json 数据,由于刚才点击了清空对话,所以没有记录。
生成参数功能
具体参数含义(:
1. 生成长度限制(:控制生成回答的最大长度(如最多生成多少字/词)。
调整建议:
回答太短 → 增大数值、回答冗余 → 减小数值。
2. 创造性(:调节生成文本的随机性和多样性(类似 temperature 参数)。
调整建议:
需要严谨回答(如技术文档)→ 降低(如 0.3)
需要创意内容(如故事生成)→ 提高(如 0.9)
3. 核心采样(:通常指 top-p 采样,从概率最高的词中采样,平衡生成质量与多样性(例如 top_p=0.9 表示仅从概率最高的 90% 词汇中选择)。
调整建议:
需要精准回答 → 降低(如 0.7)
需要多样化表达 → 提高(如 0.95)
4. 检索文档数量top_k (:控制每次检索时返回的最相关文档数量。
调整建议:
需要更全面的上下文 → 增大(如 top_k=5)
需要更精准的回答 → 减小(如 top_k=2)
6. 检索距离阈值(:控制检索结果的相似度阈值,仅返回与问题向量距离小于该阈值的文档。
调整建议:
需要更严格的匹配 → 降低(如 0.2)
需要更宽松的匹配 → 提高(如 0.5)
7. 显示推理过程(:控制是否在回答中显示推理过程的详细信息(如检索到的文档、模型生成逻辑等)。
调整建议:
调试或开发阶段 → 开启(便于分析系统行为)
生产环境或普通用户使用 → 关闭(提升用户体验)
常见问题以及使用技巧
常见问题
1. 上传PDF后检索无结果:
检查PyPDF版本,尝试pypdf==3.17。
2. 回答内容重复:
Top-K值过高 减少k=3 → k=2。
3. 检索速度慢:
FAISS未启用GPU,安装faiss-gpu并配置CUDA环境。
4. 上传文件失败:
检查文件格式是否支持(PDF/DOCX/TXT等)。
确保 knowledge_base 目录有写入权限。
5. 回答不准确:
尝试增加 top_k(检索更多文档)或降低 dist_threshold(提高相似度要求)。
6. GPU内存不足:
启用4bit量化或在 config.py 中更换更小的模型(如 ChatGLM3-6B → DeepSeek-R1-1.5B)。
使用技巧
1. 修改身份设定:
编辑 knowledge_base/identity.md,修改AI的沟通风格(例如:“你是一个幽默的助手”)。
2. 调整检索参数:
在 app.py 的Gradio界面中,调整 top_k(检索文档数量)和 dist_threshold(相似度阈值)。
3. 自定义模型:
在 config.py 的 MODEL_PATHS 中添加新模型路径,重启服务后可在界面切换。
4. 批量上传与自动索引:
将多个文档(PDF、DOCX、TXT等)一次性拖入Gradio界面,系统会自动触发索引重建。
5. 特殊文档优先级:
将核心文档(如identity.md)命名为易识别的名称(如00_identity.md),确保其在检索中优先级更高。
6. 客户支持自动化:
在 identity.md 中设定客服话术如:开头使用“您好,感谢咨询!” 结尾添加“请问还有其他问题吗?😊”
7. 内部知识检索加速:(需要自己写哈)
为高频关键词(如“报销流程”)添加向量缓存。在 rag.py 中增加缓存逻辑。
8. 敏感信息过滤:(需要自己写哈)
在 generate_answer 函数中添加正则过滤。
9. 极端场景应对:(需要自己写哈)
在 loader.py 中处理大文件。(需要自己写哈)
比如:处理100MB+的巨型PDF文档?
可以使用 PyPDF2 按章节拆分PDF为小文件。
为每个章节生成摘要,存入索引。
用户提问时先检索摘要,再定位具体章节。
10. 混合语言文档处理:(需要自己写哈)
在 config.py 中切换多语言嵌入模型,
如 sentence-transformers/paraphrase-multilingual-mpnet-base-v2
在 retrieve_top_k_documents 中添加语言判断,
修改检索逻辑,按语言过滤文档。
11. 图文混合问答(扩展功能):(需要自己写哈)
从包含图的PDF中提取信息(如产品结构图),在 loader.py 中扩展PDF加载函数。
12. 语音问答功能:(需要自己写哈)
用户语音 --STT--> 文本问题 --RAG系统--> 文本回答 --TTS--> 语音输出
执行命令合集
1. 依赖下载:
pip install -r requirements.txt
2. CUDA torch 下载:
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
3. faiss 下载:
conda install faiss
4. 构建索引:
python main.py
5.python 进入:
python
6. BAAI/bge-m3 模型下载
from sentence_transformers import SentenceTransformer
model_name = "BAAI/bge-m3"
embedding_model = SentenceTransformer(model_name)
# 保存到本地
embedding_model.save("model/BAAI/bge-m3")
7. 退出 python 环境
Ctrl+Z 然后回车
8. API 服务访问
python api.py
9. 命令行 交互访问
python rag.py
10. Gradio 界面访问
python rag.py
11. 返回终端
exit 或者 Ctrl+C
暂时先这样吧,如果实在看不明白就留言,看到我会回复的。希望这个教程对您有帮助!
裁云为纸书胸臆,汲泉煮茶养性灵。
磨砚不觉时光逝,笔落犹存墨韵清。
忽闻沧海生龙啸,且驭长风破浪行。
踏峰一笑千帆过,万壑松风为我鸣。与君共勉。