"""基于 Python 调用 Ollama 的 API 模型接口，使用 Streamlit 构建网页端，并通过向量模型对相关文档进行处理。"""
import asyncio
import os
import re
import time
from datetime import datetime
from io import BytesIO
from typing import AsyncGenerator, Dict, Generator, List, Optional, Tuple
from uuid import uuid4

import aiohttp
import chardet
import chromadb
import openai
import requests
import streamlit as st
import yaml
from docx import Document
from PyPDF2 import PdfReader
from streamlit.runtime.uploaded_file_manager import UploadedFile
# --------------------------
# 配置管理
# --------------------------
class AppConfig:
    """Lazily load ``config.yaml`` from this script's directory and cache it.

    The parsed mapping is re-read once it is older than ``CACHE_TTL_SECONDS``.
    Load/validation errors are not propagated: they are shown with ``st.error``
    and the Streamlit run is halted with ``st.stop()``.
    """

    # Maximum age (seconds) of the cached configuration before it is re-read.
    CACHE_TTL_SECONDS = 300

    def __init__(self):
        self._config: Optional[dict] = None  # parsed YAML mapping; None until first load
        self.load_time: Optional[datetime] = None  # when _config was last (re)loaded

    @property
    def config(self) -> dict:
        """Return the configuration, (re)loading it when absent or older than the TTL."""
        if not self._config or self.load_time is None or self._cache_expired():
            self._load_config()
        return self._config

    def _cache_expired(self) -> bool:
        """Return True when the cached configuration is older than ``CACHE_TTL_SECONDS``."""
        # total_seconds(), not .seconds: timedelta.seconds is only the 0..86399 seconds
        # component and wraps after one day, which would make a stale cache look fresh.
        age = datetime.now() - self.load_time
        return age.total_seconds() > self.CACHE_TTL_SECONDS

    def _load_config(self):
        """Read, parse and validate ``config.yaml``.

        On success stores the mapping and the load time. On any failure (missing
        file, non-mapping YAML, missing required keys, YAML/IO errors) the error is
        shown with ``st.error`` and ``st.stop()`` is called.
        """
        try:
            config_path = os.path.join(os.path.dirname(__file__), "config.yaml")
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"配置文件不存在: {config_path}")
            # Explicit encoding: the platform default may not be UTF-8.
            with open(config_path, encoding="utf-8") as f:
                config = yaml.safe_load(f)
            # safe_load returns None for an empty file and may return a list or scalar.
            if not isinstance(config, dict):
                raise ValueError("配置文件格式错误: 顶层必须是键值映射")
            required_keys = {"ollama", "embed_model", "vector_db_path"}
            if missing := required_keys - config.keys():
                raise KeyError(f"缺少必要配置项: {missing}")
            self._config = config
            self.load_time = datetime.now()
        except Exception as e:
            st.error(f"配置加载失败: {str(e)}")
            st.stop()


config = AppConfig()
# --------------------------
# 向量数据库管理
# --------------------------
class VectorDBManager:
    """Owns the persistent Chroma client and its ``chat_history`` collection."""

    def __init__(self, path: str):
        self.path = path  # on-disk directory handed to chromadb.PersistentClient
        self.client = None  # set by initialize()
        self.collection = None  # "chat_history" collection (cosine space), set by initialize()

    def initialize(self, max_retries: int = 3):
        """Open the persistent store and get/create the collection, retrying on failure.

        Args:
            max_retries: total number of attempts (default 3, minimum 1). Between
                attempts it sleeps 1s, 2s, 4s, ... (exponential backoff).

        Raises:
            RuntimeError: when every attempt failed; the last error is chained as ``__cause__``.
        """
        attempts = max(1, max_retries)
        for retry in range(attempts):
            try:
                self.client = chromadb.PersistentClient(path=self.path)
                self.collection = self.client.get_or_create_collection(
                    name="chat_history",
                    metadata={"hnsw:space": "cosine"},
                )
                return
            except Exception as e:
                if retry == attempts - 1:
                    raise RuntimeError(f"向量数据库初始化失败: {str(e)}") from e
                time.sleep(2 ** retry)
# Build the shared vector-store manager once, at import time; on any failure the
# error is shown on the page and this Streamlit run is stopped.
try:
    db_manager = VectorDBManager(path=config.config["vector_db_path"])
    db_manager.initialize()
except Exception as init_error:
    st.error(str(init_error))
    st.stop()
# --------------------------
# 嵌入服务
# --------------------------
class EmbeddingService:
    """Client for Ollama's ``/api/embeddings`` endpoint with an in-memory result cache."""

    def __init__(self, base_url: str, model: str):
        # Drop a trailing slash so the request URL never contains "//".
        self.base_url = base_url.rstrip("/")
        self.model = model
        # sanitized text -> embedding vector; unbounded, lives as long as this instance.
        self.cache: Dict[str, List[float]] = {}
        # Held for the whole request/retry sequence, so one instance sends one request at a time.
        self.lock = asyncio.Lock()

    async def get_embedding(self, text: str, max_attempts: int = 3) -> Optional[List[float]]:
        """Return the embedding of ``text`` (after ``_sanitize``), or ``None``.

        Args:
            text: raw text; it is sanitized and truncated before being sent.
            max_attempts: total attempts on transport/HTTP errors (default 3), with
                ``1.5 ** attempt`` seconds of back-off between attempts.

        Returns:
            The ``"embedding"`` field of the response, served from the cache when the
            same sanitized text was embedded before. ``None`` when the sanitized text
            is empty, the response carries no embedding, or every attempt failed (the
            last error is shown with ``st.error``).
        """
        sanitized_text = self._sanitize(text)
        if not sanitized_text:
            # Nothing left to embed; skip the network round-trip.
            return None

        cached = self.cache.get(sanitized_text)
        if cached is not None:
            return cached

        async with self.lock:
            # Another coroutine may have fetched this text while we waited for the lock.
            cached = self.cache.get(sanitized_text)
            if cached is not None:
                return cached

            for attempt in range(max_attempts):
                try:
                    async with aiohttp.ClientSession() as session:
                        async with session.post(
                            f"{self.base_url}/api/embeddings",
                            json={"model": self.model, "prompt": sanitized_text},
                            timeout=aiohttp.ClientTimeout(total=15),
                        ) as resp:
                            resp.raise_for_status()
                            data = await resp.json()
                    embedding = data.get("embedding")
                    if embedding:
                        self.cache[sanitized_text] = embedding
                        return embedding
                    # The request succeeded but carried no vector; retrying will not change that.
                    return None
                except Exception as e:
                    if attempt == max_attempts - 1:
                        st.error(f"嵌入获取失败: {str(e)}")
                    else:
                        # Back off only between attempts, not after the final failure.
                        await asyncio.sleep(1.5 ** attempt)
        return None

    @staticmethod
    def _sanitize(text: str) -> str:
        """Remove the characters <, >, {, } and backtick, trim whitespace, cap at 1000 chars."""
        return re.sub(r'[<>{}`]', '', text).strip()[:1000]
# --------------------------
# 文档处理管道
# --------------------------
class DocumentProcessor:
    """Parses uploaded files, embeds their chunks and stores them in ``db_manager.collection``."""

    @staticmethod
    async def process_files(files: List[UploadedFile], embed_service: EmbeddingService,
                            max_concurrency: int = 3):
        """Process ``files`` concurrently (at most ``max_concurrency`` at a time).

        For each file: parse it, split it with ``TextSplitter``, embed every chunk and
        add the successfully embedded chunks to the collection in one batch, with
        ``source`` / ``type`` / ``timestamp`` metadata. Per-file outcome is reported
        with ``st.toast`` (stored), ``st.warning`` (nothing embeddable) or ``st.error``.
        """
        semaphore = asyncio.Semaphore(max_concurrency)

        async def process_file(file: UploadedFile):
            async with semaphore:
                try:
                    content = await DocumentParser.parse(file)
                    chunks: List[str] = []
                    embeddings: List[List[float]] = []
                    for chunk in TextSplitter.split(content):
                        if embedding := await embed_service.get_embedding(chunk):
                            chunks.append(chunk)
                            embeddings.append(embedding)

                    if not chunks:
                        # Empty document or every embedding failed: nothing was stored.
                        st.warning(f"⚠️ 未能从 {file.name} 中提取可入库的内容")
                        return

                    timestamp = datetime.now().isoformat()
                    # One batched add per file; metadatas is a list parallel to ids/documents.
                    db_manager.collection.add(
                        ids=[f"doc_{uuid4()}" for _ in chunks],
                        embeddings=embeddings,
                        documents=chunks,
                        metadatas=[
                            {"source": file.name, "type": "uploaded_doc", "timestamp": timestamp}
                            for _ in chunks
                        ],
                    )
                    st.toast(f"✅ 成功处理: {file.name}")
                except Exception as e:
                    st.error(f"❌ 处理失败 {file.name}: {str(e)}")

        await asyncio.gather(*(process_file(f) for f in files))
class DocumentParser:
    """Extract plain text from uploaded PDF, DOCX and TXT files."""

    # Browser-reported MIME type -> parser kind.
    _MIME_KINDS = {
        "application/pdf": "pdf",
        "text/plain": "txt",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
    }
    # Extension fallback for generic (e.g. application/octet-stream) or missing MIME types.
    _SUFFIX_KINDS = {".pdf": "pdf", ".txt": "txt", ".docx": "docx"}

    @staticmethod
    async def parse(file: "UploadedFile") -> str:
        """Parse ``file`` on the default thread-pool executor and return its text.

        Raises:
            ValueError: unsupported format or parser failure (see ``_sync_parse``).
        """
        # Inside a coroutine a loop is always running; get_event_loop() is deprecated here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, DocumentParser._sync_parse, file)

    @staticmethod
    def _sync_parse(file: "UploadedFile") -> str:
        """Blocking parser: choose the format by MIME type (or extension) and extract text.

        Raises:
            ValueError: "不支持的文档格式" for unknown formats, or "文档解析失败: ..."
                wrapping (and chaining) any error raised by the underlying parser.
        """
        kind = DocumentParser._detect_kind(file)
        if kind is None:
            raise ValueError("不支持的文档格式")
        content_bytes = file.getvalue()
        try:
            if kind == "pdf":
                pages = PdfReader(BytesIO(content_bytes)).pages
                # Defensive: treat a None result from extract_text() as an empty page.
                return "\n".join(page.extract_text() or "" for page in pages)
            if kind == "docx":
                return "\n".join(p.text for p in Document(BytesIO(content_bytes)).paragraphs)
            return DocumentParser._decode_text(content_bytes)
        except Exception as e:
            raise ValueError(f"文档解析失败: {str(e)}") from e

    @staticmethod
    def _detect_kind(file: "UploadedFile") -> Optional[str]:
        """Return "pdf", "docx", "txt" or None, preferring the MIME type over the extension."""
        kind = DocumentParser._MIME_KINDS.get(file.type)
        if kind is None:
            suffix = os.path.splitext(file.name or "")[1].lower()
            kind = DocumentParser._SUFFIX_KINDS.get(suffix)
        return kind

    @staticmethod
    def _decode_text(data: bytes) -> str:
        """Decode text bytes: UTF-8 (BOM stripped), then GB18030 (superset of GBK), then ISO-8859-1."""
        for encoding in ("utf-8-sig", "gb18030"):
            try:
                return data.decode(encoding)
            except UnicodeDecodeError:
                continue
        # ISO-8859-1 maps every byte to a code point, so this last resort cannot fail.
        return data.decode("iso-8859-1")
class TextSplitter:
    """Sentence-aware text chunker used before embedding."""

    # Zero-width split point right after a sentence terminator (full-width 。！？ or ASCII !?).
    _SENTENCE_END = re.compile(r'(?<=[。！？!?])')

    @staticmethod
    def split(text: str, chunk_size=500, overlap=100) -> List[str]:
        """Split ``text`` into non-empty chunks of at most ``chunk_size`` characters.

        Sentences (stripped of surrounding whitespace) are packed greedily; a sentence
        longer than ``chunk_size`` is cut into ``chunk_size``-character pieces. When a
        chunk is closed, its trailing sentences totalling at most ``overlap`` characters
        are repeated at the start of the next chunk, as long as the next chunk still fits.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if not text:
            return []

        pieces: List[str] = []
        for sentence in TextSplitter._SENTENCE_END.split(text):
            sentence = sentence.strip()
            # Hard-wrap over-long sentences (e.g. text without terminators); blank ones add nothing.
            pieces.extend(sentence[i:i + chunk_size] for i in range(0, len(sentence), chunk_size))

        chunks: List[str] = []
        current: List[str] = []
        current_length = 0
        for piece in pieces:
            if current and current_length + len(piece) > chunk_size:
                chunks.append("".join(current))
                current, current_length = TextSplitter._tail_overlap(
                    current, overlap, chunk_size - len(piece)
                )
            current.append(piece)
            current_length += len(piece)
        if current:
            chunks.append("".join(current))
        return chunks

    @staticmethod
    def _tail_overlap(sentences: List[str], overlap: int, budget: int) -> Tuple[List[str], int]:
        """Return the longest suffix of ``sentences`` no longer than min(overlap, budget), plus its length."""
        limit = min(overlap, budget)
        kept: List[str] = []
        kept_length = 0
        for sentence in reversed(sentences):
            if kept_length + len(sentence) > limit:
                break
            kept.append(sentence)
            kept_length += len(sentence)
        kept.reverse()
        return kept, kept_length
# --------------------------
# LLM交互模块
# --------------------------
class ChatService:
    """Streams chat completions from an OpenAI-compatible ``/v1`` endpoint under ``base_url``."""

    def __init__(self, base_url: str):
        # Drop a trailing slash so the client base URL never contains "//".
        self.base_url = base_url.rstrip("/")
        # The key is a placeholder value; presumably the local server does not check it.
        self.client = openai.AsyncOpenAI(base_url=f"{self.base_url}/v1", api_key="no-key-required")

    async def stream_response(self, messages: List[Dict], model: str,
                              temperature: float) -> AsyncGenerator[str, None]:
        """Yield the assistant reply to ``messages`` piece by piece.

        When the stream completes, the full reply is appended to
        ``st.session_state.history`` as an assistant message. The user turn is not
        appended here: callers add it to the history before calling (main() does),
        so appending it again would duplicate it.

        On any error a single "⚠️ 生成错误: ..." text piece is yielded and nothing
        is recorded.
        """
        try:
            stream = await self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                stream=True,
            )
            parts: List[str] = []
            async for chunk in stream:
                # Some stream events may carry no choices; skip them instead of raising IndexError.
                if chunk.choices and (content := chunk.choices[0].delta.content):
                    parts.append(content)
                    yield content
            st.session_state.history.append({"role": "assistant", "content": "".join(parts)})
        except Exception as e:
            yield f"⚠️ 生成错误: {str(e)}"
# --------------------------
# 用户界面组件
# --------------------------
class UIComponents:
    """Streamlit rendering helpers for page setup, the sidebar status badge and the input row."""

    @staticmethod
    def setup_page():
        """Configure the page (title, icon, wide layout, expanded sidebar) and inject custom CSS."""
        st.set_page_config(
            page_title="智能研究助手",
            page_icon="🧠",
            layout="wide",
            initial_sidebar_state="expanded",
        )
        st.markdown("""
<style>
[data-testid="stSidebar"] {
background: #f5f7fb !important;
}
.stChatFloatingInputContainer {
bottom: 20px;
padding: 1rem;
background: white;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
border-radius: 12px;
}
</style>
""", unsafe_allow_html=True)

    @staticmethod
    def model_status(service_ok: bool):
        """Show a green "正常" or red "异常" service-status badge in the sidebar."""
        status_color = "#4CAF50" if service_ok else "#FF5252"
        st.sidebar.markdown(
            f'<div style="padding: 8px; background: {status_color}; color: white; border-radius: 4px;">'
            f'服务状态: {"正常" if service_ok else "异常"}'
            '</div>',
            unsafe_allow_html=True,
        )

    @staticmethod
    def chat_input_area() -> Optional[str]:
        """Render a question box with a "发送" button.

        Returns:
            The question with surrounding whitespace removed, when the button was
            clicked in this run and the question is not blank; otherwise ``None``.
        """
        with st.container():
            input_col, button_col = st.columns([0.85, 0.15])
            with input_col:
                # A non-empty label avoids Streamlit's empty-label accessibility warning;
                # it is still hidden by label_visibility="collapsed".
                prompt = st.text_input(
                    label="输入问题",
                    placeholder="输入您的问题...",
                    key="input",
                    label_visibility="collapsed",
                )
            with button_col:
                clicked = st.button("发送", use_container_width=True)
        if clicked and prompt and prompt.strip():
            return prompt.strip()
        return None
# --------------------------
# 主应用逻辑
# --------------------------
async def fetch_models_from_ollama(base_url: str,
                                   timeout: float = 10.0) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Fetch the model list from the server's OpenAI-compatible ``/v1/models`` endpoint.

    A model whose id contains "embed" (case-insensitive) is classed as an embedding
    model, every other one as a chat model. Both dicts map model id -> the entry's
    "description", falling back to the id. Entries without an "id" are skipped.

    Args:
        base_url: server root URL (a trailing slash is tolerated).
        timeout: seconds to wait for the HTTP request (default 10).

    Returns:
        ``(embed_models, chat_models)``. Both are empty when the request fails; the
        error is then shown with ``st.error``. Empty categories trigger ``st.warning``.
    """
    url = f"{base_url.rstrip('/')}/v1/models"
    try:
        loop = asyncio.get_running_loop()
        # requests is blocking: run it off the event loop, and always with a timeout so
        # an unreachable server cannot hang the page indefinitely.
        response = await loop.run_in_executor(None, lambda: requests.get(url, timeout=timeout))
        response.raise_for_status()
        data = response.json()
        # Assumed shape: {"data": [{"id": "model_id", "object": "model", "created": ..., "owned_by": ...}, ...]}
        embed_models: Dict[str, str] = {}
        chat_models: Dict[str, str] = {}
        for model in data.get('data', []):
            model_id = model.get('id')
            if not model_id:
                continue
            target = embed_models if "embed" in model_id.lower() else chat_models
            target[model_id] = model.get('description', model_id)
        if not embed_models:
            st.warning("API 响应中未找到任何嵌入模型,请检查 API 文档或服务器配置。")
        if not chat_models:
            st.warning("API 响应中未找到任何聊天模型,请检查 API 文档或服务器配置。")
        return embed_models, chat_models
    except requests.exceptions.HTTPError as http_err:
        st.error(f"HTTP 错误: {http_err}")
    except requests.exceptions.RequestException as req_err:
        st.error(f"请求错误: {req_err}")
    except Exception as e:
        st.error(f"无法获取模型列表: {str(e)}")
    return {}, {}
async def main():
    """Render one Streamlit run of the research-assistant page.

    Sets up the page, fetches the model list, initialises per-session state on the
    first run, renders the sidebar (status badge, document upload/ingestion,
    temperature, chat-model picker), replays the chat history and, when the user
    sends a question, streams the assistant reply into the page.

    NOTE(review): uploaded documents are embedded into ``db_manager.collection`` by
    ``DocumentProcessor``, but nothing here queries that collection, so chat replies
    do not draw on the uploaded documents — confirm whether retrieval is intended.
    """
    user_avatar = "https://img.alicdn.com/tfs/TB1oYRYwUT1gK0jSZFhXXaAtVXa-16-16.png"
    assistant_avatar = "https://img.alicdn.com/tfs/TB1ZLrwuET1gK0jSZSyXXXtlpXa-16-16.png"

    # Page configuration and CSS.
    UIComponents.setup_page()

    # Service endpoints and the configured embedding model.
    ollama_config = config.config["ollama"]
    embed_base_url = ollama_config["base_url"]
    embed_model = config.config["embed_model"]
    chat_base_url = ollama_config["base_url"]  # assumes chat models are served from the same base_url

    embed_models, chat_models = await fetch_models_from_ollama(embed_base_url)

    # Per-session state, created on the first run of a browser session.
    if "history" not in st.session_state:
        st.session_state.update({
            "history": [],
            "current_embed_model": None,
            "current_chat_model": next(iter(chat_models), None),
            "temperature": 0.7,
            "processing": False,
            "uploaded_files": [],
        })
    # (Re)resolve the embedding model whenever the stored one is not in the fetched list,
    # e.g. because the server was unreachable when the session started.
    if embed_models and st.session_state.current_embed_model not in embed_models:
        st.session_state.current_embed_model = (
            embed_model if embed_model in embed_models else next(iter(embed_models))
        )

    with st.sidebar:
        st.title("设置")
        # Healthy when the model-list request returned at least one model.
        UIComponents.model_status(bool(embed_models or chat_models))

        uploaded_files = st.file_uploader(
            "上传研究文档",
            type=["pdf", "docx", "txt"],
            accept_multiple_files=True,
            key="file_uploader",
        )
        if uploaded_files and st.button("开始处理"):
            st.session_state.uploaded_files = uploaded_files
            current_embed_model = st.session_state.current_embed_model
            if not current_embed_model:
                st.warning("没有可用的嵌入模型,无法处理文档。")
            else:
                with st.spinner('正在处理文档...'):
                    embed_service = EmbeddingService(embed_base_url, current_embed_model)
                    await DocumentProcessor.process_files(list(uploaded_files), embed_service)

        # Mirror the slider value into session_state.temperature whenever it changes.
        st.slider(
            "温度参数", 0.0, 1.0, st.session_state.temperature, key="temp_slider",
            on_change=lambda: st.session_state.__setitem__("temperature", st.session_state.temp_slider),
        )

        if chat_models:
            selected_chat_model = st.selectbox(
                "选择聊天模型", options=list(chat_models.keys()), index=0,
                key="chat_model_selector",
            )
            if selected_chat_model != st.session_state.current_chat_model:
                st.session_state.current_chat_model = selected_chat_model
        else:
            st.warning("没有可用的聊天模型,请检查Ollama服务器配置。")

    st.title("🧠 智能研究助手")

    # Replay the conversation stored for this session.
    for msg in st.session_state.history:
        avatar = user_avatar if msg["role"] == "user" else assistant_avatar
        with st.chat_message(msg["role"], avatar=avatar):
            st.markdown(msg["content"])

    if prompt := UIComponents.chat_input_area():
        if not st.session_state.current_chat_model:
            st.warning("请选择一个有效的聊天模型。")
            return

        # Record the user turn before streaming so it is part of the request;
        # ChatService.stream_response appends to the history itself when the stream completes.
        st.session_state.history.append({"role": "user", "content": prompt})
        with st.chat_message("user", avatar=user_avatar):
            st.markdown(prompt)

        chat_service = ChatService(chat_base_url)
        with st.chat_message("assistant", avatar=assistant_avatar):
            placeholder = st.empty()
            full_response = ""
            try:
                async for chunk in chat_service.stream_response(
                    messages=st.session_state.history,
                    model=st.session_state.current_chat_model,
                    temperature=st.session_state.temperature,
                ):
                    full_response += chunk
                    placeholder.markdown(full_response + "▌")  # trailing cursor while streaming
                placeholder.markdown(full_response)
            except Exception as e:
                placeholder.error(f"生成失败: {str(e)}")


if __name__ == "__main__":
    asyncio.run(main())