吴恩达MCP课程(1):chat_bot

发布于:2025-05-31 ⋅ 阅读:(23) ⋅ 点赞:(0)

原课程代码基于 Anthropic SDK 编写,下面的代码改用 OpenAI SDK 重写(通过阿里云百炼的 OpenAI 兼容接口调用),并使用阿里巴巴的通义千问模型(qwen-turbo)进行测试。
.env 文件为:

OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
OPENAI_API_BASE=https://dashscope.aliyuncs.com/compatible-mode/v1

完整代码

import arxiv
import json
import os
from typing import List
from dotenv import load_dotenv
import openai

# Root folder for saved paper metadata: search_papers writes
# PAPER_DIR/<topic>/papers_info.json and extract_info scans its subfolders.
PAPER_DIR = "papers"

def _topic_dir_name(topic: str) -> str:
    """Return a single-component directory name for *topic*.

    The topic is lower-cased and spaces become underscores
    ("Machine Learning" -> "machine_learning"). Path separators are also
    replaced, and "", "." and ".." are mapped to "_", so a model-supplied
    topic can never create nested folders or escape PAPER_DIR.
    """
    name = topic.lower().replace(" ", "_")
    for separator in ("/", "\\"):
        name = name.replace(separator, "_")
    if name in ("", ".", ".."):
        name = "_"
    return name


def search_papers(topic: str, max_results: int = 5) -> List[str]:
    """
    Search for papers on arXiv based on a topic and store their information.

    Title, authors, summary, PDF URL and publication date of every result are
    merged into PAPER_DIR/<topic_dir>/papers_info.json, keyed by the arXiv
    short ID. Entries saved by earlier searches on the same topic are kept.

    Args:
        topic: The topic to search for
        max_results: Maximum number of results to retrieve (default: 5)

    Returns:
        List of paper IDs found in the search
    """
    # Named arxiv_client so it does not shadow the module-level chat client.
    arxiv_client = arxiv.Client()

    # Search for the most relevant articles matching the queried topic.
    search = arxiv.Search(
        query=topic,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    papers = arxiv_client.results(search)

    # One directory per topic; the name is sanitised because it comes from the model.
    path = os.path.join(PAPER_DIR, _topic_dir_name(topic))
    os.makedirs(path, exist_ok=True)
    file_path = os.path.join(path, "papers_info.json")

    # Merge into any existing file; a missing, corrupt or non-object file starts fresh.
    try:
        with open(file_path, "r", encoding="utf-8") as json_file:
            papers_info = json.load(json_file)
    except (FileNotFoundError, json.JSONDecodeError):
        papers_info = {}
    if not isinstance(papers_info, dict):
        papers_info = {}

    paper_ids = []
    for paper in papers:
        paper_id = paper.get_short_id()
        paper_ids.append(paper_id)
        papers_info[paper_id] = {
            'title': paper.title,
            'authors': [author.name for author in paper.authors],
            'summary': paper.summary,
            'pdf_url': paper.pdf_url,
            'published': str(paper.published.date())
        }

    # Save updated papers_info to the json file.
    with open(file_path, "w", encoding="utf-8") as json_file:
        json.dump(papers_info, json_file, indent=2)

    print(f"Results are saved in: {file_path}")

    return paper_ids

def extract_info(paper_id: str) -> str:
    """
    Search for information about a specific paper across all topic directories.

    Every PAPER_DIR/<topic>/papers_info.json is checked (topic folders in
    sorted order). The first entry keyed by *paper_id* is returned.
    Unreadable or corrupt files are reported and skipped.

    Args:
        paper_id: The ID of the paper to look for

    Returns:
        JSON string with paper information if found. Otherwise a "no saved
        information" message, which is also returned when PAPER_DIR does
        not exist yet because no search has been run.
    """
    not_found = f"There's no saved information related to paper {paper_id}."

    # Before the first search PAPER_DIR does not exist; that just means "not found".
    try:
        topic_names = sorted(os.listdir(PAPER_DIR))
    except (FileNotFoundError, NotADirectoryError):
        return not_found

    for item in topic_names:
        item_path = os.path.join(PAPER_DIR, item)
        if not os.path.isdir(item_path):
            continue
        file_path = os.path.join(item_path, "papers_info.json")
        if not os.path.isfile(file_path):
            continue
        try:
            with open(file_path, "r", encoding="utf-8") as json_file:
                papers_info = json.load(json_file)
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            print(f"Error reading {file_path}: {str(e)}")
            continue
        if isinstance(papers_info, dict) and paper_id in papers_info:
            return json.dumps(papers_info[paper_id], indent=2)

    return not_found

def _function_tool(name, description, properties, required):
    """Wrap one function's JSON-schema parameters in the OpenAI "function" tool format."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool schemas advertised to the chat model; each name matches a local function.
tools = [
    _function_tool(
        "search_papers",
        "Search for papers on arXiv based on a topic and store their information",
        {
            "topic": {
                "type": "string",
                "description": "The topic to search for",
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to retrieve",
                "default": 5,
            },
        },
        ["topic"],
    ),
    _function_tool(
        "extract_info",
        "Search for information about a specific paper across all topic directories",
        {
            "paper_id": {
                "type": "string",
                "description": "The ID of the paper to look for",
            },
        },
        ["paper_id"],
    ),
]

# Dispatch table from the tool names advertised in `tools` to the Python
# functions implementing them; execute_tool looks tools up here by name.
mapping_tool_function = dict(
    search_papers=search_papers,
    extract_info=extract_info,
)

def execute_tool(tool_name, tool_args):
    """
    Run the local function registered under *tool_name* and return its result as text.

    Tool messages sent back to the chat API must be strings, so the result
    is normalised:
    - None becomes a placeholder sentence.
    - A list is joined with ", "; an empty list becomes "No results found.".
    - A dict is pretty-printed as JSON.
    - Anything else goes through str().

    Failures are returned as an "Error: ..." string instead of being raised,
    so the model can see what went wrong and the conversation can continue.
    This covers an unknown tool name, arguments that do not fit the
    function's signature, and exceptions raised by the tool itself.

    Args:
        tool_name: Key into mapping_tool_function.
        tool_args: Keyword arguments for the tool, as decoded from the model's JSON.

    Returns:
        The tool's result as a string.
    """
    tool_function = mapping_tool_function.get(tool_name)
    if tool_function is None:
        return f"Error: unknown tool '{tool_name}'."

    try:
        result = tool_function(**tool_args)
    except Exception as err:  # tool boundary: report to the model instead of aborting the chat
        print(f"Tool {tool_name} failed: {err}")
        return f"Error: {tool_name} failed: {err}"

    if result is None:
        return "The operation completed but didn't return any results."
    if isinstance(result, list):
        if not result:
            return "No results found."
        # str() each item so non-string list elements do not break join().
        return ', '.join(str(item) for item in result)
    if isinstance(result, dict):
        # Convert dictionaries to formatted JSON strings.
        return json.dumps(result, indent=2)
    # For any other type, convert using str().
    return str(result)

# Load settings from the .env file into the process environment, then build
# a chat client for the OpenAI-compatible endpoint configured there.
load_dotenv()
client = openai.OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_API_BASE"),
)

def _run_tool_call(tool_call):
    """Decode one model-requested tool call's JSON arguments, run the tool, and return text.

    Malformed arguments are reported back as an "Error: ..." string rather
    than raised, so one bad call does not abort the whole query.
    """
    tool_name = tool_call.function.name
    # Treat an empty argument string as "no arguments" rather than a parse error.
    raw_args = tool_call.function.arguments or "{}"
    try:
        tool_args = json.loads(raw_args)
    except json.JSONDecodeError as err:
        return f"Error: could not parse arguments for {tool_name}: {err}"
    if not isinstance(tool_args, dict):
        return f"Error: arguments for {tool_name} must be a JSON object."

    print(f"Calling tool {tool_name} with args {tool_args}")
    return execute_tool(tool_name, tool_args)


def process_query(query, *, model="qwen-turbo", max_tokens=2024, max_tool_rounds=10):
    """
    Answer one user query, letting the model call the local tools as needed.

    The conversation is sent to the chat-completions endpoint. Whenever the
    reply requests tool calls, each tool is run and its result is appended
    as a "tool" message before the model is asked again. Any text the model
    returns is printed. The loop ends when a reply contains no tool calls,
    or once the model asks for tools after max_tool_rounds rounds.

    Args:
        query: The user's question.
        model: Chat model name (qwen-turbo by default, or any other model
            served by the configured OpenAI-compatible endpoint).
        max_tokens: Completion token limit for each request.
        max_tool_rounds: Upper bound on tool-call rounds, so a model that
            keeps requesting tools cannot loop forever.

    Returns:
        The text of the last reply (None if the model returned no text).
    """
    messages = [{"role": "user", "content": query}]
    tool_rounds = 0

    while True:
        response = client.chat.completions.create(
            model=model,
            max_tokens=max_tokens,
            tools=tools,
            messages=messages,
        )
        message = response.choices[0].message

        if message.content:
            print(message.content)

        # Deciding on tool_calls (not content) means a reply carrying both
        # text and tool calls still gets its tools run, and an empty reply
        # ends the loop instead of spinning forever.
        if not message.tool_calls:
            return message.content

        if tool_rounds >= max_tool_rounds:
            print(f"Stopping: the model kept requesting tools after {max_tool_rounds} rounds.")
            return message.content
        tool_rounds += 1

        # Record the assistant turn (as plain dicts) so each following
        # "tool" message refers to a tool_call id present in the history.
        messages.append({
            "role": "assistant",
            "content": message.content,
            "tool_calls": [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
                for tool_call in message.tool_calls
            ],
        })

        for tool_call in message.tool_calls:
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": _run_tool_call(tool_call),
            })

def chat_loop():
    """
    Read queries from standard input and answer each one with process_query.

    Typing 'quit' (any letter case) exits. End of input (EOFError) or
    KeyboardInterrupt at the prompt also exits. Blank lines are ignored.
    Errors raised while answering a query are printed and the loop continues.
    """
    print("Type your queries or 'quit' to exit.")
    while True:
        try:
            query = input("\nQuery: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Without this, a closed stdin makes input() raise EOFError on
            # every iteration and the loop would print errors forever.
            print()
            break

        if not query:
            continue
        if query.lower() == 'quit':
            break

        try:
            process_query(query)
            print("\n")
        except Exception as e:
            print(f"\nError: {str(e)}")

# Start the interactive chat loop only when run as a script, not on import.
if __name__ == "__main__":
    chat_loop()


代码解释

导入模块

import arxiv        # 用于访问arXiv API搜索论文
import json         # 处理JSON数据
import os           # 操作系统功能,如文件路径处理
from typing import List  # 类型提示
from dotenv import load_dotenv  # 加载环境变量
import openai       # OpenAI API客户端

核心功能函数

1. search_papers 函数

这个函数用于在arXiv上搜索特定主题的论文并保存信息:

def search_papers(topic: str, max_results: int = 5) -> List[str]:
  • 参数
    • topic: 要搜索的主题
    • max_results: 最大结果数量(默认5个)
  • 返回值:找到的论文ID列表

功能流程

  1. 创建arXiv客户端
  2. 按相关性搜索主题相关论文
  3. 为该主题创建目录(如papers/machine_learning)
  4. 尝试加载已有的论文信息(如果存在)
  5. 处理每篇论文,提取标题、作者、摘要等信息
  6. 将论文信息保存到JSON文件中
  7. 返回论文ID列表
2. extract_info 函数

这个函数用于在所有主题目录中搜索特定论文的信息:

def extract_info(paper_id: str) -> str:
  • 参数:paper_id - 要查找的论文ID
  • 返回值:包含论文信息的JSON字符串(如果找到),否则返回错误信息

功能流程

  1. 遍历papers目录下的所有子目录
  2. 在每个子目录中查找papers_info.json文件
  3. 如果找到文件,检查是否包含指定的论文ID
  4. 如果找到论文信息,返回格式化的JSON字符串
  5. 如果未找到,返回未找到的提示信息

工具定义

tools = [...]

定义了两个函数工具,用于OpenAI API的工具调用:

  1. search_papers - 搜索论文
  2. extract_info - 提取论文信息

每个工具都定义了名称、描述和参数规范。

工具执行函数

def execute_tool(tool_name, tool_args):

这个函数负责执行指定的工具函数,并处理返回结果:

  • 将None结果转换为提示信息
  • 将列表结果转换为逗号分隔的字符串
  • 将字典结果转换为格式化的JSON字符串
  • 其他类型转换为字符串

OpenAI客户端初始化

load_dotenv()
client = openai.OpenAI(
    api_key = os.getenv("OPENAI_API_KEY"),
    base_url= os.getenv("OPENAI_API_BASE")
)

从环境变量加载API密钥和基础URL,初始化OpenAI客户端。

查询处理函数

def process_query(query):

这个函数处理用户的查询:

  1. 创建包含用户查询的消息列表
  2. 调用OpenAI API创建聊天完成
  3. 处理助手的回复:
    • 如果有普通文本内容,直接打印
    • 如果有工具调用,执行工具并将结果添加到消息历史
  4. 如果执行了工具调用,获取下一个回复
  5. 如果最终回复只有文本,打印并结束处理

聊天循环函数

def chat_loop():

这个函数实现了一个简单的聊天循环:

  1. 提示用户输入查询,或输入 'quit' 退出
  2. 处理用户的查询
  3. 捕获并显示任何错误

主程序

if __name__ == "__main__":
    chat_loop()

当脚本直接运行时,启动聊天循环。

总结

这个脚本实现了一个基于OpenAI API的聊天机器人,它可以:

  1. 搜索arXiv上的论文并保存信息
  2. 提取已保存的论文信息
  3. 通过OpenAI API处理用户查询
  4. 支持工具调用功能,实现与arXiv的交互

运行示例

目录结构
在这里插入图片描述

运行结果
在这里插入图片描述
在这里插入图片描述