基于DeepSeek-R1 的RAG智能问答系统开发攻略

发布于:2025-03-22 ⋅ 阅读:(18) ⋅ 点赞:(0)

RAG为何成为技术热点?
RAG(Retrieval-Augmented Generation,检索增强生成) 结合了信息检索与生成模型的优势,能有效解决大模型“幻觉”问题,在智能客服、知识管理等领域应用广泛。本文将手把手教你搭建一个支持多格式文档上传、本地化部署的RAG系统,完整代码已开源,适合有一定Python基础的技术爱好者。

在这里插入图片描述

项目架构与核心技术

系统架构图

在这里插入图片描述

前端:Gradio交互界面 / FastAPI服务。

核心模块:文档加载 → 向量索引 → 检索增强生成。

底层支持:FAISS向量库、SentenceTransformer嵌入模型、本地LLM(如DeepSeek/ChatGLM)。

技术栈

在这里插入图片描述

向量检索:FAISS(Facebook开源的相似性搜索库)。

嵌入模型:BAAI/bge-m3(支持中英文的高效文本向量化)。

生成模型:DeepSeek-R1-1.5B(轻量级开源模型,适合本地部署)。

开发框架:Gradio(快速构建UI)、FastAPI(高性能API服务)。

环境搭建与模型准备

项目结构

Rag-System/
├─ model/					  	       # 模型存放文件夹
│   ├─ BAAI/bge-m3/                    # 本地Embedding模型的文件夹
│   └─ DeepSeek-R1-Distill-Qwen-1.5B/  # 本地LLM大语言模型的文件夹
│   └─ ChatGlm3-6B/                    # 本地LLM大语言模型的文件夹
├─ knowledge_base/
│   ├─ some_text.txt          # 你本地知识库中的各种文件(txt/pdf/docx等)
│   ├─ identity.md            # 自我认知文件
│   └─ ...
├─ cache/
│   └─ faiss_index/           # FAISS 索引缓存目录(会自动生成)
├─ icon/
│   ├─ bot.png                # agent 头像
│   └─ user.png               # 用户 头像
├─ config.py                  # 配置文件
├─ loader.py                  # 索引构建
├─ main.py                    # 多线程加载文档
├─ rag.py                     # 文档 检索、回答生成
├─ app.py                     # Gradio交互界面
├─ api.py                     # FastAPI服务端,提供REST接口
└── ... 其他文件 ...
在根目录创建 cache、icon、knowledge_base、model文件夹。

在这里插入图片描述

cache:用于存放 FAISS 索引缓存目录。

在这里插入图片描述

icon:用于存放 agent 和用户头像。

在这里插入图片描述

knowledge_base:用于存放知识库文件(里面要有一个名为 identity.md 的自我认知文件)。

在这里插入图片描述

model:用于存放模型文件。

在这里插入图片描述

Anaconda 环境搭建以及

1. 打开 Anaconda 重新创建一个环境。

在这里插入图片描述

2. 打开执行终端。

在这里插入图片描述

3. 导航到根目录。

在这里插入图片描述

4. 在根目录创建一个名为 requirements.txt 的文件,这个是依赖安装文件。

在这里插入图片描述

requirements.txt 内容。
langchain
openai
faiss-cpu
transformers
sentence-transformers
gradio
fastapi
uvicorn
python-multipart
pypdf
python-docx
pandas
openpyxl
python-pptx
fastapi
uvicorn
sentencepiece
5. 依赖安装
   pip install -r requirements.txt
   没有报错就证明没问题,有报错就看一下是哪个,单独下载就行。

在这里插入图片描述

6. CUDA 安装
   pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

在这里插入图片描述

模型准备

需要下载三个模型分别是:bge-m3、ChatGlm3-6B、DeepSeek-R1-Distill-Qwen-1.5B。
在model 文件夹下建立这样的三个文件夹,以便于存放模型文件。

在这里插入图片描述

BAAI/bge-m3 模型下载

BAAI/bge-m3:支持中英文的高效文本向量化
1. 在命令行键入:python 进入python 模式

在这里插入图片描述

2. 直接拷贝代码到命令行,等待模型下载。
from sentence_transformers import SentenceTransformer

model_name = "BAAI/bge-m3"
embedding_model = SentenceTransformer(model_name)

# 保存到本地
embedding_model.save("model/BAAI/bge-m3")

在这里插入图片描述

3. 这就是下载好的 bge-m3 模型。

在这里插入图片描述

4. Ctrl + Z 然后回车 返回 Anaconda 终端。

在这里插入图片描述

DeepSeek-R1-1.5B & ChatGlm3-6B 模型下载

下载地址: DeepSeek-R1-Distill-Qwen-1.5B
下载地址: ChatGlm3-6B

1. 模型下载。

在这里插入图片描述
在这里插入图片描述

2. DeepSeek-R1-Distill-Qwen-1.5B 模型 存放地址。

在这里插入图片描述

3. ChatGlm3-6B 存放地址。

在这里插入图片描述

核心代码解析

上传文档与索引重建

代码逻辑(app.py → upload_files函数):
	用户上传文件后,文件会被移动到 knowledge_base 目录。
	调用 rebuild_index()(main.py)重新构建FAISS索引。

问答流程

检索阶段(rag.py → retrieve_top_k_documents):
	用户问题 → 向量化 → FAISS检索最相关的3个文档。

生成阶段(rag.py → generate_answer):
	拼接检索到的文档和身份设定(identity.md)→ 生成回答(流式输出)。

模型切换

代码逻辑(app.py → switch_model函数):
	从 config.MODEL_PATHS 加载新模型,更新全局变量 tokenizer 和 model。

文档加载与索引构建(loader.py & main.py)

# loader.py:多线程加载文档(关键函数)
def load_documents(directory):
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(load_file, filepath, filename) 
                   for filename in os.listdir(directory)]
        # 处理结果...

# main.py:构建FAISS索引
def rebuild_index():
    embeddings = embedding_model.encode(texts)  # 文本向量化
    index = faiss.IndexFlatL2(dim)             # 创建L2距离索引
    index.add(embeddings)                      # 添加向量
    faiss.write_index(index, "cache/faiss_index/docs.index")  # 保存索引

检索增强生成(rag.py)

def generate_answer(query):
    # 1. 检索Top-K文档
    related_docs = retrieve_top_k_documents(query, k=3)
    
    # 2. 拼接Prompt(身份设定+检索内容)
    prompt = f"""
    [系统角色设定]{identity_content}
    [检索结果]{related_docs}
    问题:{query}
    回答:
    """
    
    # 3. 流式生成回答
    for token in model.generate(**inputs, stream=True):
        yield token  # 逐词输出,提升用户体验

Gradio交互界面(app.py)

# 文件上传与索引更新
def upload_files(files, chatbot):
    shutil.move(file.name, "knowledge_base/new_file.pdf")  # 移动文件
    rebuild_index()  # 触发索引重建
    chatbot.append(("📁 上传完成", "索引已更新!"))

# 流式对话演示
with gr.Blocks() as interface:
    chatbot = gr.Chatbot(height=600)
    msg_input = gr.Textbox(label="请输入问题")
    submit_btn.click(chat_with_rag, [msg_input, chatbot], [chatbot])

相关完整代码

config.py

配置文件,定义模型路径、索引目录、知识库位置等参数。
# config.py
from pathlib import Path
import os

# 基础目录,默认为当前文件所在目录
BASE_DIR = Path(__file__).parent

class Config:
    # 嵌入模型路径:你可以替换成更适合中文的模型,例如 GanymedeNil/text2vec-base-chinese
    EMBEDDING_MODEL = str(BASE_DIR / "model" / "BAAI" / "bge-m3")

    # 默认模型相关参数
    DEFAULT_MAX_LENGTH = 4096
    CHUNK_SIZE = 1000
    OVERLAP = 200

    # FAISS 索引缓存目录
    FAISS_CACHE = BASE_DIR / "cache" / "faiss_index"
    
    # 可用的 LLM 模型路径,存放在 MODEL_PATHS 字典中
    MODEL_PATHS = {
        'DeepSeek-R1-1.5B': os.path.abspath(BASE_DIR / "model" / "DeepSeek-R1-Distill-Qwen-1.5B"),
        'ChatGLM3-6B': os.path.abspath(BASE_DIR / "model" / "ChatGlm3-6B")
    }

    # 默认使用的 LLM 模型(可以在此更改为其他键)
    DEFAULT_LLM_MODEL = 'DeepSeek-R1-1.5B'
    # 从 MODEL_PATHS 中取出默认模型路径,并转换为字符串
    LLM_MODEL_PATH = str(MODEL_PATHS[DEFAULT_LLM_MODEL])
    
    # 参考文档文件夹,用于存放系统的身份设定、常见文档等
    REFERENCE_FOLDER = BASE_DIR / "knowledge_base"
    # 确保目录存在
    REFERENCE_FOLDER.mkdir(parents=True, exist_ok=True) 
    # 系统身份设定文件(例如 identity.md),放在参考文档文件夹中
    IDENTITY_FILE = REFERENCE_FOLDER / "identity.md"

    # 其他参数设置
    MAX_HISTORY = 5
    STREAM_SEGMENT_SIZE = 5
    STREAM_DELAY = 0.1
    ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "Maddie")

    def __init__(self):
        # 确保缓存目录和参考文档文件夹存在
        self.FAISS_CACHE.mkdir(parents=True, exist_ok=True)
        self.REFERENCE_FOLDER.mkdir(parents=True, exist_ok=True)

# 实例化配置对象,后续代码中直接导入 cfg 使用
config = Config()

loader.py

多线程加载文档(PDF/DOCX/TXT等),返回文本内容。
# loader.py
import os
import logging
from typing import List
import pandas as pd
from pypdf import PdfReader
from docx import Document
from pptx import Presentation
from concurrent.futures import ThreadPoolExecutor, as_completed

# ========================
# 模块初始化配置
# 配置日志输出
# ========================
# 将第三方库(httpcore, urllib3)的日志等级设为WARNING,避免输出过多调试信息
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)

# 基本日志配置,设置日志记录等级和输出格式
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

"""
加载并读取TXT文本文件内容,返回文件中的全部文本。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_txt(filepath):
    try:
        with open(filepath, "r", encoding="utf-8") as file:
            return file.read()
    except Exception as e:
        logging.error(f"读取TXT文件 {filepath} 失败: {e}")
        return ""


"""
使用PdfReader读取PDF文件中的文本内容。
将PDF的每个页面提取出的文本进行拼接并返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_pdf(filepath):
    try:
        reader = PdfReader(filepath)
        text = ""
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text
        return text
    except Exception as e:
        logging.error(f"读取PDF文件 {filepath} 失败: {e}")
        return ""

"""
使用docx.Document读取DOCX文件中的段落文本。
将所有段落的文本用换行符拼接起来并返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_docx(filepath):
    try:
        doc = Document(filepath)
        return "\n".join([para.text for para in doc.paragraphs])
    except Exception as e:
        logging.error(f"读取DOCX文件 {filepath} 失败: {e}")
        return ""

"""
使用pandas读取Excel文件,并将DataFrame转换为制表符分隔的字符串返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_excel(filepath):
    try:
        df = pd.read_excel(filepath, engine='openpyxl')
        return df_to_text(df)
    except Exception as e:
        logging.error(f"读取Excel文件 {filepath} 失败: {e}")
        return ""

"""
使用pptx.Presentation读取PPTX文件。
遍历每一页(slide)和每个文本框(shape)的段落,将其文本内容按换行符拼接并返回。
如果读取过程中出现异常,会记录错误并返回空字符串。
"""
def load_pptx(filepath):
    try:
        prs = Presentation(filepath)
        text = ""
        for slide in prs.slides:
            for shape in slide.shapes:
                if shape.has_text_frame:
                    for paragraph in shape.text_frame.paragraphs:
                        text += paragraph.text + "\n"
        return text
    except Exception as e:
        logging.error(f"读取PPTX文件 {filepath} 失败: {e}")
        return ""

"""
将DataFrame转换为以制表符(\t)分隔的CSV格式字符串,并返回。
不包含索引(index)列。
"""
def df_to_text(df):
    return df.to_csv(index=False, sep='\t')

# 扩展名与对应的加载函数映射字典
# 根据文件后缀名选择合适的读取函数
LOADER_FUNCTIONS = {
    ".txt": load_txt,
    ".md": load_txt,   # Markdown 文件按文本文件处理,复用load_txt
    ".pdf": load_pdf,
    ".docx": load_docx,
    ".xlsx": load_excel,
    ".xls": load_excel,
    ".pptx": load_pptx,
}

"""
根据文件名的扩展名(ext)LOADER_FUNCTIONS中查找对应的加载函数进行读取。
如果不支持的文件类型,则记录警告并返回空字符串。
"""
def load_file(filepath, filename):
    ext = os.path.splitext(filename)[1].lower()  # 提取文件后缀并转换为小写
    loader = LOADER_FUNCTIONS.get(ext)
    if loader:
        return loader(filepath)
    else:
        logging.warning(f"不支持的文件类型: {filename}")
        return ""

"""
从指定目录(directory)中批量加载所有文件的内容,并返回一个包含字典的列表。
每个字典结构为:{"filename": 文件名, "content": 文件内容}

其中使用线程池(ThreadPoolExecutor)并行地读取文件以提高效率。
"""
def load_documents(directory):
    docs = []
    files = os.listdir(directory)
    logging.info(f"在目录 {directory} 找到 {len(files)} 个文件")

    # 使用线程池并发加载文件,max_workers=4 表示最多4个并发线程
    with ThreadPoolExecutor(max_workers=4) as executor:
        future_to_filename = {}
        
        # 将每个文件的加载任务提交给线程池
        for filename in files:
            filepath = os.path.join(directory, filename)
            if os.path.isfile(filepath):  # 只处理文件,不处理子目录
                future = executor.submit(load_file, filepath, filename)
                future_to_filename[future] = filename

        # 收集加载结果
        for future in as_completed(future_to_filename):
            filename = future_to_filename[future]
            try:
                content = future.result()  # 获取文件读取结果
                if content:
                    # 只有当读取到非空内容时,才追加到docs列表
                    docs.append({"filename": filename, "content": content})
                else:
                    logging.info(f"文件 {filename} 没有加载到内容")
            except Exception as e:
                logging.error(f"加载文件 {filename} 时出错: {e}")

    return docs

main.py

重建FAISS索引,将文档内容向量化存储。
# main.py
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import os
import pickle
from loader import load_documents  # 使用改进后的 loader 模块
from config import config
import logging

# ========================
# 模块初始化配置
# 配置日志输出
# ========================
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# 加载预训练的嵌入模型(Embedding Model)
embedding_model_path = os.path.abspath(config.EMBEDDING_MODEL)
embedding_model = SentenceTransformer(embedding_model_path)

def rebuild_index():
    """重新加载所有文档,并重建 FAISS 索引"""
    print("🔄 开始重建 FAISS 索引...")

    # 使用配置中的知识库目录
    knowledge_dir = config.REFERENCE_FOLDER  # 这里改为 `REFERENCE_FOLDER`

    # 通过 loader 并发加载所有文档
    docs = list(load_documents(knowledge_dir))
    if not docs:
        print("⚠️ 没有找到文档,索引未更新。")
        return "⚠️ 没有找到文档,索引未更新。"

    print(f"📂 加载了 {len(docs)} 个文档")  # 打印文档数量

    # 提取每篇文档的内容和文件名
    texts = [doc["content"] for doc in docs]
    filenames = [doc["filename"] for doc in docs]

    # 将文档内容转化为向量表示
    embeddings = np.array(embedding_model.encode(texts))  # 确保是 NumPy 数组

    # 使用 FAISS 建立索引
    dim = embeddings.shape[1]  # 向量维度
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)

    # 保存 FAISS 索引和文档文件名列表
    faiss.write_index(index, str(config.FAISS_CACHE / "docs.index"))
    
    with open(str(config.FAISS_CACHE / "filenames.pkl"), "wb") as f:
        pickle.dump(filenames, f)

    print("✅ FAISS 索引已成功重建!")
    return "✅ FAISS 索引已成功重建!"

# 如果 `main.py` 直接运行,则自动创建索引
if __name__ == "__main__":
    rebuild_index()

rag.py

核心逻辑:检索文档、生成回答、总结/改写文本。
# rag.py
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # 禁用 Hugging Face 的并行警告
os.environ["BITSANDBYTES_NOWELCOME"] = "1"      # 禁止 bitsandbytes 欢迎信息
os.environ["WANDB_DISABLED"] = "true"           # 禁用 wandb 日志(如果有)

import faiss
import numpy as np
import pickle
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import pandas as pd
from config import config
import logging
from docx import Document  # 以防后面需要 docx 解析

##############################################################################
# 1. 加载模型与索引
##############################################################################

# 加载本地 Embedding 模型(SentenceTransformer),用于 FAISS 检索
embedding_model_path = os.path.abspath(config.EMBEDDING_MODEL)
embedding_model = SentenceTransformer(embedding_model_path)

# 加载 FAISS 索引
index = faiss.read_index(str(config.FAISS_CACHE / "docs.index"))
print(f"✅ FAISS 索引维度:{index.d}")

# 加载文件名列表
with open(str(config.FAISS_CACHE / "filenames.pkl"), "rb") as f:
    filenames = pickle.load(f)

# 加载本地 LLM,并使用 GPU 加速
llm_model_path = os.path.abspath(config.LLM_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(llm_model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    llm_model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# 如果想手动指定 device,也可用 model.to(config.DEVICE),这里 device_map="auto" 通常就能工作

print("✅ 本地 LLM 模型加载完成!")

##############################################################################
# 2. 文档加载函数
##############################################################################

def load_document_content(filename):
    """
    根据文件名加载文本内容。identity.md 作为 AI 角色设定。
    """
    # 使用配置中的知识库目录
    file_path = os.path.join(config.REFERENCE_FOLDER, filename)


    # 特殊处理 identity.md
    if filename == "identity.md":
        with open(file_path, "r", encoding="utf-8") as file:
            return f"[系统角色提示]\n{file.read()}\n\n"

    if filename.endswith(".txt") or filename.endswith(".md"):
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()
    elif filename.endswith(".pdf"):
        from pypdf import PdfReader
        reader = PdfReader(file_path)
        return "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
    elif filename.endswith(".docx"):
        # 利用 python-docx 读取
        doc = Document(file_path)
        return "\n".join([para.text for para in doc.paragraphs])
    elif filename.endswith(".xlsx") or filename.endswith(".csv"):
        try:
            df = pd.read_excel(file_path) if filename.endswith(".xlsx") else pd.read_csv(file_path)
            return df.to_string(index=False)
        except Exception as e:
            return f"❌ 读取 {filename} 时出错: {str(e)}"
    elif filename.endswith(".pptx"):
        try:
            from pptx import Presentation
            prs = Presentation(file_path)
            text = []
            for slide in prs.slides:
                for shape in slide.shapes:
                    if shape.has_text_frame:
                        for paragraph in shape.text_frame.paragraphs:
                            text.append(paragraph.text)
            return "\n".join(text)
        except Exception as e:
            return f"❌ 读取 {filename} 时出错: {str(e)}"

    return "❌ 无法读取该文档格式:" + filename


##############################################################################
# 3. 检索函数
##############################################################################

def retrieve_top_k_documents(query, k=3):
    """
    根据查询语句在索引中找到最相关的 k 个文档,并返回其内容(截取)。
    """
    query_embedding = np.array(embedding_model.encode([query]))

    print(f"✅ 查询向量维度:{query_embedding.shape[1]}")

    _, idxs = index.search(query_embedding, k)

    retrieved_docs = []
    for i in idxs[0]:
        filename = filenames[i]
        content = load_document_content(filename)
        # 这里可以只截取前1000字符,避免 Prompt 过长
        retrieved_docs.append(f"📄【{filename}】\n{content[:100]}...")
    return retrieved_docs


##############################################################################
# 4. 分块 (Chunk) + 摘要 / 改写
##############################################################################

def chunk_text(text, chunk_size=1000, overlap=200):
    """
    将长文本分块,每块 chunk_size 个字符,并在块之间保留 overlap 个字符重叠,避免关键信息被切断。
    """
    chunks = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = min(start + chunk_size, text_len)  # 避免越界
        chunk = text[start:end]
        chunks.append(chunk)
        start = end - overlap

        if start < 0:
            start = 0
        if start >= text_len:
            break
    return chunks

def summarize_long_text(text):
    """
    对长文本进行多段式摘要,然后合并。
    """
    # 分块
    text_chunks = chunk_text(text, chunk_size=1500, overlap=200)
    chunk_summaries = []

    # 逐块摘要
    for idx, chunk in enumerate(text_chunks):
        prompt = f"请阅读以下文本内容,并进行简要总结:\n{chunk}\n\n总结:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )
        summary_chunk = tokenizer.decode(output[0], skip_special_tokens=True)
        chunk_summaries.append(summary_chunk.strip())

    # 合并所有块的摘要,再让模型做一次“总总结”
    combined_summary = "\n".join(chunk_summaries)
    final_prompt = f"以下是多个分块的总结,请将其合并为一个简洁的整体总结:\n{combined_summary}\n\n整体总结:"
    final_inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
    final_output = model.generate(
        **final_inputs,
        max_new_tokens=500,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )
    final_summary = tokenizer.decode(final_output[0], skip_special_tokens=True)
    return final_summary.strip()

def rewrite_long_text(text):
    """
    对长文本进行“改写 / 润色”,与摘要类似的思路,先分块再合并。
    """
    # 先粗暴示例,也可以根据需要拆分做多次改写
    text_chunks = chunk_text(text, chunk_size=1500, overlap=200)
    rewrite_results = []

    for chunk in text_chunks:
        prompt = f"请对以下文本进行语言润色或改写,使其更通顺、简洁:\n{chunk}\n\n改写后:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )
        rewrite_chunk = tokenizer.decode(output[0], skip_special_tokens=True)
        rewrite_results.append(rewrite_chunk.strip())

    # 简单拼接,如果想再做最终合并,可以再来一次生成
    return "\n".join(rewrite_results).strip()


##############################################################################
# 5. 最终回答:generate_answer
#    - 如果用户问 "总结xx文件""改写xx文件":
#         -> 直接找到文件内容做 summarize/rewrite
#    - 否则做普通RAG问答
##############################################################################

def generate_answer(query):
    # 1) 检索文档
    related_docs = retrieve_top_k_documents(query, k=3)

    # 2) 加载 identity.md(如果存在)
    identity_content = ""
    if "identity.md" in filenames:
        identity_content = load_document_content("identity.md")

    # 3) 拼接上下文
    context = f"AI Identity/Persona\n{identity_content}\n\n【知识库检索结果】\n" + "\n\n".join(related_docs)

    # 构造 Prompt
    prompt = f"""
你是一名智能问答助手(扮演DeepSeek知识管家 Theodore 西奥-多尔),以下是你的身份描述和检索到的文档内容:
{context}

请遵守identity.md中的所有“沟通风格准则”和“特殊行为准则”,并根据文档做出回答。
如果在知识库中找不到答案,请回答“对不起,我在知识库中没有找到相关信息”。
请使用简洁且专业的口吻。
用户的问题:{query}

请直接给出简洁且专业的回答:
""".strip()

    # 流式生成回答
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # 使用模型的流式生成参数(需transformers >=4.21.0)
    output_stream = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True
    )

    # 初始化输出
    generated_text = ""

    # 流式逐token输出
    for token_id in output_stream[0]:
        token_text = tokenizer.decode(token_id, skip_special_tokens=True)
        generated_text += token_text
        yield generated_text  # 实时返回当前生成内容

    # 完整的调试信息(生成完成后)
    debug_info = (
        "\n\n### 检索与推理过程\n\n"
        f"**用户问题**: {query}\n\n"
        f"**Prompt 内容**: \n{prompt}\n\n"
        "——以上信息仅供调试或进阶查看——\n"
    )

    # 最终输出带有调试信息(加上</think>标记)
    #print(f"{generated_text}</think>{debug_info}")
    yield f"{generated_text}</think>{debug_info}"



##############################################################################
# 6. 简单函数:_simple_summarize / _simple_rewrite (对短文本)
##############################################################################

def _simple_summarize(text):
    """
    对短文本做一次性摘要。如果文本不大,可直接用。
    """
    prompt = f"请阅读以下文本并进行简要总结:\n{text}\n\n总结:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs, 
        max_new_tokens=300,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()

def _simple_rewrite(text):
    """
    对短文本做一次性改写。
    """
    prompt = f"请对以下文本进行语言润色或改写,使其更通顺、简洁:\n{text}\n\n改写后:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()


##############################################################################
# 7. 命令行交互入口
##############################################################################


import re

if __name__ == "__main__":
    print("📚 RAG + Summarize/Rewrite 示例程序启动...")
    try:
        while True:
            query = input("\n请输入您的问题(输入 'exit' 退出): ")
            if query.lower().strip() == "exit":
                print("\n👋 退出 RAG 系统,再见!")
                break

            # 初始化变量
            bot_response = ""
            debug_info = "暂无推理过程"
            current_output_buffer = ""

            # 流式处理生成器输出
            stream = generate_answer(query)
            for current_output in stream:
                current_output_buffer += current_output
                if "</think>" in current_output_buffer:
                    parts = current_output_buffer.split("</think>")
                    if len(parts) >= 2:
                        bot_response = parts[1].strip()  # 模型的回答
                        if len(parts) > 2:
                            debug_info = parts[2].strip()  # 调试信息
                        else:
                            debug_info = "暂无推理过程"
                        # 清空缓冲区
                        current_output_buffer = ""
                        # 打印流式输出的回答
                        print("\nAI 回答(流式输出):", bot_response, "\n")
                    else:
                        bot_response = parts[0].strip()
                        current_output_buffer = ""
                else:
                    # 如果没有找到 </think>,继续积累输出
                    continue

            # 如果缓冲区中还有内容,处理剩余部分
            if current_output_buffer:
                bot_response = current_output_buffer.strip()
                print("\nAI 回答(流式输出):", bot_response, "\n")

    except KeyboardInterrupt:
        print("\n👋 检测到 Ctrl + C,退出 RAG 系统!")


"""  模型回答 一次性输出
if __name__ == "__main__":
    print("📚 RAG + Summarize/Rewrite 示例程序启动...")
    try:
        while True:
            query = input("\n请输入您的问题(输入 'exit' 退出): ")
            if query.lower().strip() == "exit":
                print("\n👋 退出 RAG 系统,再见!")
                break

            # 初始化变量
            bot_response = ""
            debug_info = "暂无推理过程"

            # 流式处理生成器输出
            stream = generate_answer(query)
            for current_output in stream:
                if "</think>" in current_output:
                    parts = current_output.split("</think>")
                    if len(parts) >= 2:
                        bot_response = parts[1].strip()  # 模型的回答
                        if len(parts) > 2:
                            debug_info = parts[2].strip()  # 调试信息
                    else:
                        bot_response = parts[0].strip()
                else:
                    bot_response = current_output.strip()

            # 输出最终回答
            print("\nAI 回答:", bot_response, "\n")

    except KeyboardInterrupt:
        print("\n👋 检测到 Ctrl + C,退出 RAG 系统!")
"""

app.py

Gradio交互界面,支持上传文档、提问、切换模型、监控系统资源。
# app.py
import gradio as gr
import os
import torch
import psutil  # 用于系统监控
from rag import generate_answer
from config import config
from loader import load_documents
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import json
from main import rebuild_index  # ✅ 直接导入 `rebuild_index`
import shutil
import time

# 头像路径
USER_ICON_PATH = "icon/user.png"
BOT_ICON_PATH = "icon/bot.png"

# 1️ 加载嵌入模型(SentenceTransformer)
embedding_model = SentenceTransformer(config.EMBEDDING_MODEL)

# 2️ 加载 LLM 模型
def load_llm_model(model_name):
    """动态加载 LLM"""
    model_path = str(config.MODEL_PATHS.get(model_name, config.LLM_MODEL_PATH))

    # ✅ 确保 `trust_remote_code=True`
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")

    return tokenizer, model


# 预加载默认 LLM
current_model_name = config.DEFAULT_LLM_MODEL
tokenizer, model = load_llm_model(current_model_name)

# 3️ 处理文件上传
def upload_files(files, chatbot):
    """上传文件,更新 Chatbot 并确保返回符合 Gradio Chatbot 格式"""
    if not isinstance(files, list):
        files = [files]

    saved_files = []
    failed_files = []

    for file in files:
        try:
            original_filename = file.orig_name if hasattr(file, "orig_name") else os.path.basename(file.name)
            dest_path = os.path.join(config.REFERENCE_FOLDER, original_filename)

            # ✅ 将临时路径文件移动到 `knowledge_base` 目录
            shutil.move(file.name, dest_path)

            saved_files.append(original_filename)
        except Exception as e:
            failed_files.append(original_filename)
            print(f"❌ 上传失败: {original_filename}, 错误: {e}")

    # ✅ 只有至少有一个文件上传成功,才重建索引
    if saved_files:
        print("🔄 至少一个文件上传成功,开始重建索引...")
        index_message = rebuild_index()
    else:
        index_message = "⚠️ 所有文件上传失败,索引未更新。"

    # ✅ 构建返回消息
    message = f"📂 上传成功 {len(saved_files)} 个文件: {', '.join(saved_files)}"
    if failed_files:
        message += f"\n❌ 上传失败 {len(failed_files)} 个文件: {', '.join(failed_files)}"
    
    message += f"\n{index_message}"
    print(message)

    # ✅ 让 Chatbot 显示消息(格式必须是 `[(用户输入, 机器人回复)]`)
    chatbot.append(("📁 文件上传", message))
    return chatbot



# 4️ 处理对话
def chat_with_rag(question, chatbot, max_tokens, temperature, top_p, show_debug, topk_retrieval, dist_threshold):
    chatbot.append((question, ""))
    start_time = time.time()

    stream = generate_answer(question)
    bot_response = ""

    for current_output in stream:
        if "</think>" in current_output:
            parts = current_output.split("</think>")
            if len(parts) >= 2:
                bot_response = parts[1].strip()  # 模型的回答
                debug_info = parts[2].strip() if len(parts) > 2 else "暂无推理过程"
            else:
                bot_response = parts[0].strip()
                debug_info = "暂无推理过程"
        else:
            bot_response = current_output.strip()

        chatbot[-1] = (question, bot_response)
        yield "", chatbot

    # 推理时间计算
    elapsed_time = time.time() - start_time
    elapsed_str = f"🔍 点击查看推理过程,耗时 {elapsed_time:.2f} 秒 ⌄"

    # ✅ 仅对推理过程的字体进行淡化(去除对 bot_response 的额外颜色修改)
    if show_debug and debug_info:
        bot_response = (
            f"<details>"
            f"<summary style='color:#888;font-size:12px;'>{elapsed_str}</summary>"
            f"<div style='color:#ccc;background:#f5f5f5;padding:10px;border-radius:5px;'>\n\n"
            f"{debug_info}\n"
            f"</div></details>\n\n"
            f"{bot_response}"  # ✅ 去掉了 bot_response 外部的额外颜色设置
        )
        chatbot[-1] = (question, bot_response)
        yield "", chatbot


# 5️ 切换 LLM 模型
def switch_model(new_model):
    """切换 LLM 模型"""
    global tokenizer, model, current_model_name
    tokenizer, model = load_llm_model(new_model)
    current_model_name = new_model
    return f"✅ 已切换到 {new_model} 模型"

# 6️ 系统监控
def system_diagnosis():
    """返回 CPU、RAM、GPU 资源使用情况"""
    cpu_usage = psutil.cpu_percent()
    ram_usage = psutil.virtual_memory().percent
    gpu_usage = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
    return {"CPU 使用率": f"{cpu_usage}%", "内存使用率": f"{ram_usage}%", "GPU 占用": f"{gpu_usage:.2f}GB"}

# 7️ 导出 & 导入对话历史
def export_chat_history(chatbot):
    """导出对话历史为 JSON 文件"""
    history = [{"用户": msg[0], "AI": msg[1]} for msg in chatbot]
    file_path = "chat_history.json"
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)
    return file_path

def import_chat_history(file):
    """导入对话历史"""
    if file is None:
        return []
    
    with open(file.name, "r", encoding="utf-8") as f:
        history = json.load(f)
    
    return [(msg["用户"], msg["AI"]) for msg in history]

# ========================
# 🚀 Gradio 界面构建
# ========================
def create_gradio_interface():
    """构建交互界面"""
    theme = gr.themes.Default(
        primary_hue="orange",
        secondary_hue="blue"
    )

    with gr.Blocks(theme=theme, title="DeepSeek RAG System 2.0") as interface:
        gr.Markdown("# 🔍 DeepSeek RAG 知识管理系统 (改进版)")

        # 对话区
        chatbot = gr.Chatbot(
            value=[(None, "您好!我是 Theodore(西奥-多尔),您的智能助手 🚀")],
            height=680,
            avatar_images=(USER_ICON_PATH, BOT_ICON_PATH)
        )
        msg_input = gr.Textbox(placeholder="输入您的问题...", lines=3)

        with gr.Row():
            submit_btn = gr.Button("💬 发送", variant="primary")
            os.environ["GRADIO_MAX_FILE_SIZE"] = "100mb"
            upload_btn = gr.UploadButton(
                "📁 上传文档",
                file_types=[".pdf", ".docx", ".txt", ".md", ".pptx", ".xlsx"],
                file_count="multiple"
            )
            clear_btn = gr.Button("🔄 清空对话")

        # 💻 系统监控
        with gr.Accordion("💻 系统监控", open=False):
            gr.Markdown("### 实时系统指标")
            diagnose_btn = gr.Button("🔄 刷新状态")
            status_panel = gr.JSON(label="系统状态", value={"状态": "正在获取..."})

        interface.load(system_diagnosis, inputs=None, outputs=status_panel)

        # 🤖 模型管理
        with gr.Accordion("🤖 模型管理", open=False):
            model_selector = gr.Dropdown(
                label="选择模型",
                choices=list(config.MODEL_PATHS.keys()),
                value=current_model_name
            )
            model_status = gr.Textbox(label="模型状态", interactive=False, value="正在初始化模型...")

        interface.load(lambda: switch_model(current_model_name), inputs=None, outputs=model_status)

        # 💬 对话历史
        with gr.Accordion("💬 对话历史", open=False):
            export_btn = gr.Button("导出历史")
            import_btn = gr.UploadButton("导入历史", file_types=[".json"])
            export_btn.click(export_chat_history, inputs=chatbot, outputs=gr.File())
            import_btn.upload(import_chat_history, inputs=import_btn, outputs=chatbot)

        # 📊 生成参数
        with gr.Accordion("📊 生成参数", open=False):
            max_tokens = gr.Slider(128, 4096, value=512, label="生成长度限制")
            temperature = gr.Slider(0.1, 1.0, value=0.7, label="创造性")
            top_p = gr.Slider(0.1, 1.0, value=0.9, label="核心采样")
            topk_retrieval = gr.Slider(1, 10, value=3, step=1, label="检索文档数量 top_k")
            dist_threshold = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="检索距离阈值")
            show_debug = gr.Checkbox(label="显示推理过程", value=True)


        # 绑定事件
        msg_input.submit(
            chat_with_rag,
            inputs=[msg_input, chatbot, max_tokens, temperature, top_p, show_debug, topk_retrieval, dist_threshold],
            outputs=[msg_input, chatbot]
        )
        submit_btn.click(
            chat_with_rag,
            inputs=[msg_input, chatbot, max_tokens, temperature, top_p, show_debug, topk_retrieval, dist_threshold],
            outputs=[msg_input, chatbot]
        )

        clear_btn.click(lambda: [(None, "对话已清空")], outputs=chatbot)


        upload_btn.upload(upload_files, inputs=[upload_btn, chatbot], outputs=[chatbot])

        diagnose_btn.click(system_diagnosis, outputs=status_panel)
        model_selector.change(switch_model, inputs=model_selector, outputs=model_status)

    return interface

# 启动 Gradio
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch(server_name="127.0.0.1", server_port=7860,share=True)

快速启动以及功能说明

文档向量生成

文档向量生成:
   将示例文档放入 knowledge_base 文件夹,运行以下命令构建索引:
   python main.py
   直接执行就会把索引构建。

在这里插入图片描述

会在 cahe\faiss_index 文件夹生成这两个文件。
	 docs.index:FAISS 索引。
	 filenames.pkl:文档文件名列表。

在这里插入图片描述

API 服务访问

启动API服务:
   python api.py
   访问 http://localhost:8000/ask 测试API接口。
   这样显示就证明 API 正常启动。

在这里插入图片描述

API 测试

便捷测试工具: Apifox 网址

1. 在Apifox 新建一个项目。

在这里插入图片描述

2. 创建一个项目名称。

在这里插入图片描述

3. 新建快捷请求。

请添加图片描述

4. 请求类型更改为Post。

在这里插入图片描述

5. 填入请求地址以及更改 Headers 类型和参数。

在这里插入图片描述

6. 更改 Body 并把raw 值添加:{"question": "你是谁"}

在这里插入图片描述

7. 保存当前快捷请求。

在这里插入图片描述

8. 保存当前快捷请求名称以及 存储根目录。

在这里插入图片描述

9. 点击发送测试当前接口是否正常返回,下面是正常状态。

在这里插入图片描述

10. 终端返回 200 OK 证明当前API 可正常访问。

在这里插入图片描述

11. Ctrl+C 退出 API

在这里插入图片描述

命令行 交互访问

启动 命令行交互界面
   python rag.py
   如果出现这个错误就执行:conda install faiss

在这里插入图片描述

下载完毕会这样显示。

在这里插入图片描述

再次执行 python rag.py

在这里插入图片描述

直接输出问题就行,就可以正常问答了。

在这里插入图片描述

等待一会就会输出答案了,目前是流式输出,想要一次性输出的话可以去 rag.py “命令行交互入口”替换。

在这里插入图片描述

当然退出也是 Ctrl+C 或直接键入exit 回车退出。

在这里插入图片描述

Gradio 服务访问

启动Gradio界面
   python app.py
   访问 http://localhost:7860 使用交互界面。
   这样显示就证明可以正常使用,如果没有自动跳出浏览器
   就在浏览器地址栏输出 http://localhost:7860 进行访问。

在这里插入图片描述

模型对话功能

默认显示界面。

在这里插入图片描述

在Textbox 输入问题点击发送就行模型问答。

在这里插入图片描述

文件上传功能

点击上传文档功能,选择想要上传的文件,目前支持.txt、.md、.pdf、.docx、.pptx、.xlsx文件。

请添加图片描述

上传成功会在终端显示索引构建完成。

请添加图片描述

当然 Gradio 也会有相应的显示。

请添加图片描述

清空对话功能

点击清空对话按钮会清空当前对话记录。

请添加图片描述

系统监控功能

打开系统监控会输出:"CPU 使用率","内存使用率","GPU 占用" 信息,可刷新状态。

请添加图片描述

模型切换功能

打开会展示当前对话模型。

请添加图片描述

可切换当前模型列表中的模型 模型列表在:config.py MODEL_PATHS下添加。

请添加图片描述

对话历史功能

可选择导出历史和导入历史:
  导出历史:导出当前对话历史。
  导入历史:将以往的对话记录进行导入,进行持续对话。

请添加图片描述

这是导出的json 数据,由于刚才点击了清空对话,所以没有记录。

在这里插入图片描述

生成参数功能

具体参数含义(1. 生成长度限制(:控制生成回答的最大长度(如最多生成多少字/词)。
		调整建议:
			回答太短 → 增大数值、回答冗余 → 减小数值。
			
	2. 创造性(:调节生成文本的随机性和多样性(类似 temperature 参数)。
		调整建议:
			需要严谨回答(如技术文档)→ 降低(如 0.3)
			需要创意内容(如故事生成)→ 提高(如 0.93. 核心采样(:通常指 top-p 采样,从概率最高的词中采样,平衡生成质量与多样性(例如 top_p=0.9 表示仅从概率最高的 90% 词汇中选择)。
		调整建议:
			需要精准回答 → 降低(如 0.7)
			需要多样化表达 → 提高(如 0.954. 检索文档数量top_k (:控制每次检索时返回的最相关文档数量。
		调整建议:
			需要更全面的上下文 → 增大(如 top_k=5)
			需要更精准的回答 → 减小(如 top_k=26. 检索距离阈值(:控制检索结果的相似度阈值,仅返回与问题向量距离小于该阈值的文档。
		调整建议:
			需要更严格的匹配 → 降低(如 0.2)
			需要更宽松的匹配 → 提高(如 0.57. 显示推理过程(:控制是否在回答中显示推理过程的详细信息(如检索到的文档、模型生成逻辑等)。
		调整建议:
			调试或开发阶段 → 开启(便于分析系统行为)
			生产环境或普通用户使用 → 关闭(提升用户体验)			

请添加图片描述

常见问题以及使用技巧

常见问题

1. 上传PDF后检索无结果:
   检查PyPDF版本,尝试pypdf==3.172. 回答内容重复:
   Top-K值过高	减少k=3 → k=23. 检索速度慢:
   FAISS未启用GPU,安装faiss-gpu并配置CUDA环境。

4. 上传文件失败:
   检查文件格式是否支持(PDF/DOCX/TXT等)。
   确保 knowledge_base 目录有写入权限。

5. 回答不准确:
   尝试增加 top_k(检索更多文档)或降低 dist_threshold(提高相似度要求)。

6. GPU内存不足:
   启用4bit量化或在 config.py 中更换更小的模型(如 ChatGLM3-6B → DeepSeek-R1-1.5B)。

使用技巧

1. 修改身份设定:
   编辑 knowledge_base/identity.md,修改AI的沟通风格(例如:“你是一个幽默的助手”)。

2. 调整检索参数:
   在 app.py 的Gradio界面中,调整 top_k(检索文档数量)和 dist_threshold(相似度阈值)。

3. 自定义模型:
   在 config.py 的 MODEL_PATHS 中添加新模型路径,重启服务后可在界面切换。

4. 批量上传与自动索引:
   将多个文档(PDFDOCXTXT等)一次性拖入Gradio界面,系统会自动触发索引重建。

5. 特殊文档优先级:
   将核心文档(如identity.md)命名为易识别的名称(如00_identity.md),确保其在检索中优先级更高。

6. 客户支持自动化:
   在 identity.md 中设定客服话术如:开头使用“您好,感谢咨询!” 结尾添加“请问还有其他问题吗?😊” 

7. 内部知识检索加速:(需要自己写哈)
   为高频关键词(如“报销流程”)添加向量缓存。在 rag.py 中增加缓存逻辑。
   
8. 敏感信息过滤:(需要自己写哈)
   在 generate_answer 函数中添加正则过滤。
    
9. 极端场景应对:(需要自己写哈)
   在 loader.py 中处理大文件。(需要自己写哈)
   比如:处理100MB+的巨型PDF文档?
   可以使用 PyPDF2 按章节拆分PDF为小文件。
   为每个章节生成摘要,存入索引。
   用户提问时先检索摘要,再定位具体章节。
    
10. 混合语言文档处理:(需要自己写哈)
    在 config.py 中切换多语言嵌入模型,  
    如 sentence-transformers/paraphrase-multilingual-mpnet-base-v2
    在 retrieve_top_k_documents 中添加语言判断,
    修改检索逻辑,按语言过滤文档。

11. 图文混合问答(扩展功能):(需要自己写哈)
    从包含图的PDF中提取信息(如产品结构图),在 loader.py 中扩展PDF加载函数。

12. 语音问答功能:(需要自己写哈)
    用户语音 --STT--> 文本问题 --RAG系统--> 文本回答 --TTS--> 语音输出

执行命令合集

1. 依赖下载:
   pip install -r requirements.txt

2. CUDA torch 下载:
   pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

3. faiss 下载:
   conda install faiss

4. 构建索引:
   python main.py
   
5.python 进入:
   python

6. BAAI/bge-m3 模型下载
   from sentence_transformers import SentenceTransformer
   model_name = "BAAI/bge-m3"
   embedding_model = SentenceTransformer(model_name)

   # 保存到本地
   embedding_model.save("model/BAAI/bge-m3")

7. 退出 python 环境
   Ctrl+Z 然后回车

8. API 服务访问
   python api.py
   
9. 命令行 交互访问
   python rag.py

10. Gradio 界面访问
   python rag.py
   
11. 返回终端
   exit 或者 Ctrl+C

暂时先这样吧,如果实在看不明白就留言,看到我会回复的。希望这个教程对您有帮助!
裁云为纸书胸臆,汲泉煮茶养性灵。
磨砚不觉时光逝,笔落犹存墨韵清。
忽闻沧海生龙啸,且驭长风破浪行。
踏峰一笑千帆过,万壑松风为我鸣。与君共勉。