019_工具集成与外部API调用

发布于:2025-07-15 ⋅ 阅读:(13) ⋅ 点赞:(0)

工具集成与外部API调用

目录

工具集成概述

什么是工具集成

工具集成允许Claude与外部系统、API和服务进行交互,从而扩展其基础能力。需要注意的是,Claude本身并不直接执行工具:它在响应中给出工具调用请求(tool_use),由应用程序实际执行工具,再把结果(tool_result)返回给Claude。通过这种方式,Claude可以间接完成执行计算、查询数据库、调用第三方服务等操作。

核心优势

能力扩展
  • 实时数据访问:获取最新的外部数据
  • 计算能力增强:执行复杂的数学和统计计算
  • 系统集成:与企业系统和数据库集成
  • 服务编排:协调多个外部服务
灵活性提升
  • 动态功能添加:根据需要添加新的工具
  • 自定义业务逻辑:实现特定的业务需求
  • 工作流自动化:自动化复杂的工作流程
  • 响应式交互:根据上下文智能选择工具
实用价值
  • 决策支持:基于实时数据做出决策
  • 效率提升:自动化重复性任务
  • 准确性保证:借助外部工具或数据源核验结果,减少计算和事实性错误
  • 用户体验:提供更丰富的交互体验

工具定义与配置

基本工具定义

简单计算工具
import json
import re
import time

import anthropic

def define_calculator_tool():
    """Return the tool definition for the ``calculator`` tool.

    The definition carries the tool name, a description shown to the model,
    and a JSON schema requiring a single string field, ``expression``.
    """

    expression_field = {
        "type": "string",
        "description": "要计算的数学表达式,如 '2 + 3 * 4' 或 'sqrt(16)'",
    }
    schema = {
        "type": "object",
        "properties": {"expression": expression_field},
        "required": ["expression"],
    }
    return {
        "name": "calculator",
        "description": "执行数学计算,支持基本运算、三角函数、对数等",
        "input_schema": schema,
    }

def execute_calculator(expression):
    """Safely evaluate an arithmetic expression for the ``calculator`` tool.

    The expression is parsed with :mod:`ast`, and only a whitelist of node
    types is evaluated:
    - int/float literals;
    - the binary operators ``+ - * / // % **`` and unary ``+``/``-``;
    - the constants ``pi`` and ``e``;
    - calls to the whitelisted functions below.
    Everything else (attribute access, other names, strings, keyword
    arguments, ...) is rejected. Untrusted input therefore never reaches
    ``eval``. Integer powers whose result would exceed ~100k bits are refused,
    so that inputs like ``2 ** 10 ** 10`` cannot hang the process.

    Args:
        expression: Expression text, e.g. ``"2 + 3 * 4"`` or ``"sqrt(16)"``.

    Returns:
        ``{"result", "expression", "success": True}`` on success, or
        ``{"error", "expression", "success": False}`` on any failure
        (syntax error, unsupported element, math domain error, ...).
    """

    import ast
    import math
    import operator

    # Whitelisted callables and named constants.
    safe_functions = {
        'sqrt': math.sqrt,
        'sin': math.sin,
        'cos': math.cos,
        'tan': math.tan,
        'log': math.log,
        'exp': math.exp,
        'abs': abs,
        'round': round,
    }
    safe_constants = {'pi': math.pi, 'e': math.e}

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    # Upper bound on the size of an integer power result, in bits.
    max_result_bits = 100_000

    def evaluate(node):
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        # type() rather than isinstance() so that True/False are rejected.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.Name) and node.id in safe_constants:
            return safe_constants[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (isinstance(node.op, ast.Pow)
                    and isinstance(left, int) and isinstance(right, int)
                    and right > 0
                    and left.bit_length() * right > max_result_bits):
                raise ValueError("结果过大,拒绝计算")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id in safe_functions
                and not node.keywords):
            args = [evaluate(arg) for arg in node.args]
            return safe_functions[node.func.id](*args)
        raise ValueError(f"不支持的表达式元素: {type(node).__name__}")

    try:
        # eval-mode parsing rejects leading indentation, so strip surrounding whitespace.
        tree = ast.parse(expression.strip(), mode="eval")
        result = evaluate(tree)

        return {
            "result": result,
            "expression": expression,
            "success": True
        }

    except Exception as e:
        return {
            "error": str(e),
            "expression": expression,
            "success": False
        }
数据库查询工具
def define_database_tool():
    """Return the tool definition for the read-only ``database_query`` tool.

    Both fields are required: ``query`` (SQL text) and ``database``.
    ``database`` is restricted by an enum to a fixed set of names.
    """

    allowed_databases = ["users", "products", "orders", "analytics"]
    properties = {
        "query": {
            "type": "string",
            "description": "SQL查询语句,仅支持SELECT操作",
        },
        "database": {
            "type": "string",
            "description": "数据库名称",
            "enum": allowed_databases,
        },
    }
    schema = {
        "type": "object",
        "properties": properties,
        "required": ["query", "database"],
    }
    return {
        "name": "database_query",
        "description": "执行SQL查询获取数据库中的信息",
        "input_schema": schema,
    }

def execute_database_query(query, database):
    """Run a SELECT query against the named SQLite database.

    The query is first screened by ``is_safe_query``. ``database`` is mapped to
    a file path by ``get_database_path``.

    Args:
        query: SQL text; only queries accepted by ``is_safe_query`` are run.
        database: Logical database name passed to ``get_database_path``.

    Returns:
        ``{"results", "columns", "row_count", "success": True}``, where
        ``results`` is the list of row tuples from ``fetchall()``; or
        ``{"error", "success": False}`` when the query is rejected or any
        step raises.
    """

    import sqlite3
    from contextlib import closing

    # Validate query safety
    if not is_safe_query(query):
        return {
            "error": "不安全的查询,仅允许SELECT操作",
            "success": False
        }

    try:
        db_path = get_database_path(database)
        # closing() guarantees the connection is released even when execute or
        # fetchall raises. sqlite3's own context manager only commits/rolls back;
        # it does not close the connection.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(query)
            results = cursor.fetchall()
            # description is None for statements that return no result columns.
            column_names = [description[0] for description in (cursor.description or ())]

    except Exception as e:
        return {
            "error": str(e),
            "success": False
        }

    return {
        "results": results,
        "columns": column_names,
        "row_count": len(results),
        "success": True
    }

def is_safe_query(query):
    """Return True if *query* looks like a single read-only SELECT statement.

    The check is purely lexical, so treat it as a coarse filter rather than a
    security boundary. Three rules apply:
    - the (lower-cased, stripped) text must start with ``select``;
    - it must not contain any forbidden keyword as a whole word;
    - non-string input is rejected.

    Whole-word matching (``\\b``) avoids false positives on identifiers such
    as ``created_at``, ``updated_at`` or ``executed`` that merely contain a
    keyword.
    """

    import re

    if not isinstance(query, str):
        return False

    query_lower = query.lower().strip()

    # Only SELECT queries are allowed
    if not query_lower.startswith('select'):
        return False

    # Forbidden keywords. attach/detach/pragma are included because they can open
    # other database files or change connection settings in SQLite.
    forbidden_keywords = (
        'insert', 'update', 'delete', 'drop', 'create',
        'alter', 'truncate', 'exec', 'execute',
        'attach', 'detach', 'pragma'
    )
    pattern = r'\b(?:' + '|'.join(forbidden_keywords) + r')\b'

    return re.search(pattern, query_lower) is None
网络API工具
def define_web_api_tool():
    """Return the tool definition for the ``web_api_call`` tool.

    ``url`` and ``method`` (GET or POST) are required. ``headers`` and
    ``data`` are optional JSON objects.
    """

    def field(kind, description, **extra):
        # Build one schema property; extra keys (e.g. enum) are appended last.
        return {"type": kind, "description": description, **extra}

    return {
        "name": "web_api_call",
        "description": "调用外部Web API获取数据",
        "input_schema": {
            "type": "object",
            "properties": {
                "url": field("string", "API端点URL"),
                "method": field("string", "HTTP方法", enum=["GET", "POST"]),
                "headers": field("object", "HTTP请求头"),
                "data": field("object", "请求数据(POST方法时使用)"),
            },
            "required": ["url", "method"],
        },
    }

def execute_web_api_call(url, method, headers=None, data=None):
    """Perform a GET or POST request to an allow-listed URL.

    Redirects are NOT followed. The allowlist is checked only against *url*,
    so following a 3xx could otherwise send the request to a host that was
    never approved (SSRF). A 3xx response is returned as-is; its ``Location``
    is available in ``headers``.

    Args:
        url: Target URL; must pass ``is_allowed_url``.
        method: "GET" or "POST" (case-insensitive).
        headers: Optional extra request headers. They override the defaults.
        data: Optional JSON body, sent only for POST.

    Returns:
        On a completed request, a dict with these keys:
        - ``status_code`` and ``headers``;
        - ``data``: the parsed JSON body, or None if the body is not JSON;
        - ``text``: the raw body when ``data`` is None;
        - ``success``: True when the status is below 400.
        Otherwise ``{"error", "success": False}``.
    """

    import requests

    # URL allowlist check
    if not is_allowed_url(url):
        return {
            "error": "URL不在允许的白名单中",
            "success": False
        }

    # Reject unsupported methods up front; previously they left `response` unbound.
    verb = method.upper() if isinstance(method, str) else ""
    if verb not in ("GET", "POST"):
        return {
            "error": f"不支持的HTTP方法: {method}",
            "success": False
        }

    try:
        # Default headers
        default_headers = {
            "User-Agent": "Claude-Assistant/1.0",
            "Accept": "application/json"
        }

        if headers:
            default_headers.update(headers)

        if verb == "GET":
            response = requests.get(
                url,
                headers=default_headers,
                timeout=30,
                allow_redirects=False
            )
        else:
            response = requests.post(
                url,
                headers=default_headers,
                json=data,
                timeout=30,
                allow_redirects=False
            )

        # A non-JSON body makes response.json() raise a ValueError subclass.
        try:
            json_data = response.json()
        except ValueError:
            json_data = None

        return {
            "status_code": response.status_code,
            "headers": dict(response.headers),
            "data": json_data,
            # Compare against None so valid but falsy JSON ({} / [] / 0) is not
            # mistaken for "no JSON".
            "text": response.text if json_data is None else None,
            "success": response.status_code < 400
        }

    except Exception as e:
        return {
            "error": str(e),
            "success": False
        }

def is_allowed_url(url):
    """Return True if *url* is an http(s) URL whose host is on the allowlist.

    The comparison uses ``urlparse(...).hostname``, which is lower-cased and
    excludes userinfo and port. ``https://API.GitHub.com:443/x`` therefore
    matches ``api.github.com``, while look-alikes such as
    ``api.github.com.evil.com`` do not. Non-http(s) schemes and unparseable
    input are rejected.
    """

    from urllib.parse import urlparse

    allowed_domains = {
        "api.openweathermap.org",
        "api.github.com",
        "jsonplaceholder.typicode.com",
        "httpbin.org"
    }

    try:
        parsed_url = urlparse(url)
        hostname = parsed_url.hostname
    except (TypeError, ValueError, AttributeError):
        return False

    if parsed_url.scheme not in ("http", "https"):
        return False

    return hostname in allowed_domains

复合工具系统

文件处理工具集
def define_file_tools():
    """Return the tool definitions for ``read_file``, ``write_file`` and ``list_directory``.

    The list is in that order. Each entry has a name, a description and a
    JSON ``input_schema``.
    """

    def prop(kind, description, **extra):
        # One schema property; extra keys (enum/default) follow type and description.
        return {"type": kind, "description": description, **extra}

    def tool(name, description, properties, required):
        return {
            "name": name,
            "description": description,
            "input_schema": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

    read_file = tool(
        "read_file",
        "读取文件内容",
        {
            "file_path": prop("string", "文件路径"),
            "encoding": prop("string", "文件编码", default="utf-8"),
        },
        ["file_path"],
    )
    write_file = tool(
        "write_file",
        "写入文件内容",
        {
            "file_path": prop("string", "文件路径"),
            "content": prop("string", "要写入的内容"),
            "mode": prop("string", "写入模式", enum=["write", "append"], default="write"),
        },
        ["file_path", "content"],
    )
    list_directory = tool(
        "list_directory",
        "列出目录内容",
        {
            "directory_path": prop("string", "目录路径"),
            "include_hidden": prop("boolean", "是否包含隐藏文件", default=False),
        },
        ["directory_path"],
    )

    return [read_file, write_file, list_directory]

class FileToolExecutor:
    """Execute the file tools (read/write/list) inside a sandbox directory.

    Every caller-supplied path is interpreted relative to ``base_path``. After
    ``..`` segments and symlinks are resolved, the path must still lie inside
    that directory; otherwise the call is refused with an error dict.
    """

    def __init__(self, base_path="/safe/workspace"):
        self.base_path = base_path

    def execute_tool(self, tool_name, **kwargs):
        """Dispatch *tool_name* to the matching method, passing the tool input as kwargs.

        Returns the method's result dict, or an error dict for an unknown tool.
        """

        if tool_name == "read_file":
            return self.read_file(**kwargs)
        elif tool_name == "write_file":
            return self.write_file(**kwargs)
        elif tool_name == "list_directory":
            return self.list_directory(**kwargs)
        else:
            return {"error": f"未知工具: {tool_name}", "success": False}

    def read_file(self, file_path, encoding="utf-8"):
        """Read a text file from the sandbox.

        Returns ``{"content", "file_path", "size", "success": True}``, where
        ``size`` is the number of characters, or an error dict.
        """

        safe_path = self.get_safe_path(file_path)
        if not safe_path:
            return {"error": "不安全的文件路径", "success": False}

        try:
            with open(safe_path, 'r', encoding=encoding) as f:
                content = f.read()

            return {
                "content": content,
                "file_path": file_path,
                "size": len(content),
                "success": True
            }

        except Exception as e:
            return {"error": str(e), "success": False}

    def write_file(self, file_path, content, mode="write"):
        """Write or append UTF-8 text to a file in the sandbox.

        *mode* must be "write" (truncate) or "append"; any other value is
        rejected instead of silently appending. Returns
        ``{"file_path", "bytes_written", "mode", "success": True}`` or an
        error dict.
        """

        safe_path = self.get_safe_path(file_path)
        if not safe_path:
            return {"error": "不安全的文件路径", "success": False}

        file_mode = {"write": "w", "append": "a"}.get(mode)
        if file_mode is None:
            return {"error": f"不支持的写入模式: {mode}", "success": False}

        try:
            with open(safe_path, file_mode, encoding='utf-8') as f:
                f.write(content)

            return {
                "file_path": file_path,
                "bytes_written": len(content.encode('utf-8')),
                "mode": mode,
                "success": True
            }

        except Exception as e:
            return {"error": str(e), "success": False}

    def list_directory(self, directory_path, include_hidden=False):
        """List the entries of a sandbox directory.

        Dot-files are skipped unless *include_hidden* is true. Symlinks are
        reported as such and are not followed, so nothing outside the sandbox
        is stat-ed. Returns ``{"directory_path", "entries", "count",
        "success": True}``. Each entry has ``name``, ``type``
        ("file"/"directory"/"symlink") and ``size``; ``size`` is set for files
        only. On failure an error dict is returned.
        """

        import os

        safe_path = self.get_safe_path(directory_path)
        if not safe_path:
            return {"error": "不安全的文件路径", "success": False}

        try:
            entries = []
            with os.scandir(safe_path) as iterator:
                for entry in iterator:
                    if not include_hidden and entry.name.startswith('.'):
                        continue
                    if entry.is_symlink():
                        kind = "symlink"
                    elif entry.is_dir(follow_symlinks=False):
                        kind = "directory"
                    else:
                        kind = "file"
                    size = entry.stat(follow_symlinks=False).st_size if kind == "file" else None
                    entries.append({"name": entry.name, "type": kind, "size": size})
        except Exception as e:
            return {"error": str(e), "success": False}

        entries.sort(key=lambda item: item["name"])
        return {
            "directory_path": directory_path,
            "entries": entries,
            "count": len(entries),
            "success": True
        }

    def get_safe_path(self, file_path):
        """Resolve *file_path* inside ``base_path``, or return None if it escapes.

        Absolute paths and non-strings are refused. Relative paths are joined
        to the base and fully resolved with ``os.path.realpath``, which
        collapses ``..`` and follows symlinks. The result is accepted only if
        it is the base directory itself or lies beneath it.

        Comparing whole components with ``os.path.commonpath`` avoids a prefix
        trap: a string check would accept ``/safe/workspace2`` because it
        "starts with" ``/safe/workspace``.
        """

        import os

        if not isinstance(file_path, str) or os.path.isabs(file_path):
            return None

        try:
            base = os.path.realpath(self.base_path)
            candidate = os.path.realpath(os.path.join(base, file_path))
            if os.path.commonpath([base, candidate]) != base:
                return None
        except ValueError:
            # Embedded NUL bytes, or (on Windows) paths on different drives.
            return None

        return candidate

API集成模式

RESTful API集成

天气API集成
def define_weather_tool():
    """Return the tool definition for the ``get_weather`` tool.

    ``city`` is required. ``units`` is "metric" or "imperial" and is
    declared with default "metric".
    """

    city_field = dict(type="string", description="城市名称,如'北京'或'Beijing'")
    units_field = dict(
        type="string",
        description="温度单位",
        enum=["metric", "imperial"],
        default="metric",
    )

    return dict(
        name="get_weather",
        description="获取指定城市的当前天气信息",
        input_schema=dict(
            type="object",
            properties=dict(city=city_field, units=units_field),
            required=["city"],
        ),
    )

class WeatherAPI:
    """Client for the OpenWeatherMap current-weather endpoint (``/weather``)."""

    def __init__(self, api_key):
        self.api_key = api_key
        self.base_url = "https://api.openweathermap.org/data/2.5"

    def _redact(self, text):
        """Return *text* with every occurrence of the API key replaced by ``***``.

        The key travels as the ``appid`` query parameter, and requests'
        connection/timeout exception messages embed the full request URL.
        Without redaction, the key would be echoed back in the ``error`` field
        to the model or user.
        """
        if not self.api_key:
            return text
        return text.replace(str(self.api_key), "***")

    def get_weather(self, city, units="metric"):
        """Fetch the current weather for *city*.

        Args:
            city: City name passed as the ``q`` query parameter.
            units: "metric" or "imperial" (the tool schema's enum).

        Returns:
            A dict of city/country/temperature/feels_like/humidity/pressure/
            description/wind_speed with ``success: True``, or
            ``{"error", "success": False}``. The error text never contains
            the API key.
        """

        import requests

        try:
            url = f"{self.base_url}/weather"
            params = {
                "q": city,
                "appid": self.api_key,
                "units": units,
                "lang": "zh_cn"
            }

            response = requests.get(url, params=params, timeout=10)

            if response.status_code == 200:
                data = response.json()

                return {
                    "city": data["name"],
                    "country": data["sys"]["country"],
                    "temperature": data["main"]["temp"],
                    "feels_like": data["main"]["feels_like"],
                    "humidity": data["main"]["humidity"],
                    "pressure": data["main"]["pressure"],
                    "description": data["weather"][0]["description"],
                    "wind_speed": data["wind"]["speed"],
                    "success": True
                }
            else:
                return {
                    "error": f"API请求失败: {response.status_code}",
                    "success": False
                }

        except Exception as e:
            return {
                "error": self._redact(str(e)),
                "success": False
            }
股票API集成
def define_stock_tool():
    """Return the tool definition for the ``get_stock_price`` tool.

    ``symbol`` is required. ``interval`` is one of 1d/1wk/1mo and is declared
    with default "1d".
    """

    intervals = ["1d", "1wk", "1mo"]
    symbol_field = {"type": "string", "description": "股票代码,如'AAPL'、'MSFT'"}
    interval_field = {
        "type": "string",
        "description": "数据间隔",
        "enum": intervals,
        "default": "1d",
    }

    tool = {"name": "get_stock_price", "description": "获取股票价格和基本信息"}
    tool["input_schema"] = {
        "type": "object",
        "properties": {"symbol": symbol_field, "interval": interval_field},
        "required": ["symbol"],
    }
    return tool

class StockAPI:
    """Minimal client for Yahoo Finance's v8 ``chart`` endpoint."""

    def __init__(self):
        self.base_url = "https://query1.finance.yahoo.com/v8/finance/chart"

    def get_stock_price(self, symbol, interval="1d"):
        """Fetch the latest quote metadata for *symbol*.

        The symbol is percent-encoded as a single path segment. Characters
        such as ``/``, ``?`` or ``#`` therefore cannot rewrite the request
        path or inject extra query parameters.

        Returns:
            A dict of symbol/company_name/current_price/previous_close/day_high/
            day_low/volume/currency/exchange with ``success: True``, or
            ``{"error", "success": False}``. The error case covers an unknown
            symbol, an empty result set, a non-200 status, or a missing field.
        """

        import requests
        from urllib.parse import quote

        try:
            url = f"{self.base_url}/{quote(str(symbol), safe='')}"
            # NOTE(review): range is fixed at "1d", so coarser intervals (1wk/1mo)
            # may not change the result — confirm this is intended.
            params = {
                "interval": interval,
                "range": "1d"
            }

            response = requests.get(url, params=params, timeout=10)

            if response.status_code != 200:
                return {
                    "error": f"API请求失败: {response.status_code}",
                    "success": False
                }

            data = response.json()
            chart = data.get("chart") or {}
            results = chart.get("result") or []

            # An API-reported error and an empty result list both mean "no data".
            if chart.get("error") or not results:
                return {
                    "error": "股票代码不存在或数据获取失败",
                    "success": False
                }

            meta = results[0]["meta"]

            return {
                "symbol": symbol,
                "company_name": meta.get("longName", symbol),
                "current_price": meta["regularMarketPrice"],
                "previous_close": meta["previousClose"],
                "day_high": meta["regularMarketDayHigh"],
                "day_low": meta["regularMarketDayLow"],
                "volume": meta["regularMarketVolume"],
                "currency": meta["currency"],
                "exchange": meta["exchangeName"],
                "success": True
            }

        except Exception as e:
            return {
                "error": str(e),
                "success": False
            }

GraphQL API集成

GitHub API集成
def define_github_tool():
    """Return the tool definition for the ``github_query`` tool.

    ``owner`` and ``repo`` are required. ``query_type`` selects
    repository/issues/commits and is declared with default "repository".
    """

    query_types = ["repository", "issues", "commits"]
    properties = {}
    properties["owner"] = {"type": "string", "description": "仓库所有者"}
    properties["repo"] = {"type": "string", "description": "仓库名称"}
    properties["query_type"] = {
        "type": "string",
        "description": "查询类型",
        "enum": query_types,
        "default": "repository",
    }

    return {
        "name": "github_query",
        "description": "查询GitHub仓库信息",
        "input_schema": {
            "type": "object",
            "properties": properties,
            "required": ["owner", "repo"],
        },
    }

class GitHubAPI:
    """Read-only client for a few GitHub REST endpoints used by ``github_query``."""

    def __init__(self, token=None):
        self.token = token
        self.base_url = "https://api.github.com"
        # Kept for callers that read it; the methods below use only the REST API.
        self.graphql_url = "https://api.github.com/graphql"

    def github_query(self, owner, repo, query_type="repository"):
        """Dispatch on *query_type* ("repository", "issues" or "commits").

        Returns the handler's result dict. An unsupported *query_type* yields
        an error dict instead of None.
        """

        handlers = {
            "repository": self.get_repository_info,
            "issues": self.get_issues,
            "commits": self.get_commits,
        }
        handler = handlers.get(query_type)
        if handler is None:
            return {
                "error": f"不支持的查询类型: {query_type}",
                "success": False
            }
        return handler(owner, repo)

    def _headers(self):
        """Request headers; adds token auth when a token was supplied."""
        headers = {}
        if self.token:
            headers["Authorization"] = f"token {self.token}"
        return headers

    def _repo_url(self, owner, repo, suffix=""):
        """Build ``/repos/{owner}/{repo}{suffix}`` with owner/repo percent-encoded.

        Encoding each name as a single path segment stops values such as
        ``"a/../../users"`` from reaching other API endpoints.
        """
        from urllib.parse import quote

        return (f"{self.base_url}/repos/"
                f"{quote(str(owner), safe='')}/{quote(str(repo), safe='')}{suffix}")

    def _get(self, url, params=None):
        """GET *url* with auth headers and a 10 s timeout."""
        import requests

        return requests.get(url, headers=self._headers(), params=params, timeout=10)

    def get_repository_info(self, owner, repo):
        """Return basic repository metadata, or an error dict."""

        try:
            response = self._get(self._repo_url(owner, repo))

            if response.status_code != 200:
                return {
                    "error": f"仓库不存在或无权访问: {response.status_code}",
                    "success": False
                }

            data = response.json()

            return {
                "name": data["name"],
                "full_name": data["full_name"],
                "description": data["description"],
                "language": data["language"],
                "stars": data["stargazers_count"],
                "forks": data["forks_count"],
                "issues": data["open_issues_count"],
                "created_at": data["created_at"],
                "updated_at": data["updated_at"],
                "license": data["license"]["name"] if data["license"] else None,
                "success": True
            }

        except Exception as e:
            return {
                "error": str(e),
                "success": False
            }

    def get_issues(self, owner, repo, state="open", limit=10):
        """Return up to *limit* issues in *state*, excluding pull requests."""

        try:
            response = self._get(
                self._repo_url(owner, repo, "/issues"),
                params={"state": state, "per_page": limit},
            )

            if response.status_code != 200:
                return {
                    "error": f"仓库不存在或无权访问: {response.status_code}",
                    "success": False
                }

            issues = [
                {
                    "number": item["number"],
                    "title": item["title"],
                    "state": item["state"],
                    "author": (item.get("user") or {}).get("login"),
                    "comments": item.get("comments", 0),
                    "created_at": item["created_at"],
                    "url": item.get("html_url"),
                }
                for item in response.json()
                # The issues endpoint also lists pull requests; skip them.
                if "pull_request" not in item
            ]

            return {"issues": issues, "count": len(issues), "success": True}

        except Exception as e:
            return {
                "error": str(e),
                "success": False
            }

    def get_commits(self, owner, repo, limit=10):
        """Return the most recent *limit* commits (sha, subject line, author, date, url)."""

        try:
            response = self._get(
                self._repo_url(owner, repo, "/commits"),
                params={"per_page": limit},
            )

            if response.status_code != 200:
                return {
                    "error": f"仓库不存在或无权访问: {response.status_code}",
                    "success": False
                }

            commits = []
            for item in response.json():
                commit = item.get("commit") or {}
                author = commit.get("author") or {}
                message = commit.get("message") or ""
                commits.append({
                    "sha": item.get("sha"),
                    "message": message.splitlines()[0] if message else "",
                    "author": author.get("name"),
                    "date": author.get("date"),
                    "url": item.get("html_url"),
                })

            return {"commits": commits, "count": len(commits), "success": True}

        except Exception as e:
            return {
                "error": str(e),
                "success": False
            }

工具执行流程

工具调用管理器

核心调用管理器
class ToolCallManager:
    """Registry of tool definitions and executors, with schema checks and call logging."""

    def __init__(self):
        self.tools = {}         # tool name -> tool definition (name/description/input_schema)
        self.executors = {}     # tool name -> callable invoked with the tool input as kwargs
        self.call_history = []  # one record per executed call; validation failures are not logged

    def register_tool(self, tool_definition, executor):
        """Register *executor* under ``tool_definition["name"]``, replacing any previous entry."""

        tool_name = tool_definition["name"]
        self.tools[tool_name] = tool_definition
        self.executors[tool_name] = executor

        print(f"工具 '{tool_name}' 已注册")

    def execute_tool_call(self, tool_name, **kwargs):
        """Validate *kwargs* against the tool's schema, run the executor and log the call.

        Returns the executor's result unchanged. An error dict is returned for
        an unknown tool, invalid parameters, or an exception raised by the
        executor. Non-dict results count as successful.
        """

        import time

        if tool_name not in self.tools:
            return {
                "error": f"未知工具: {tool_name}",
                "success": False
            }

        # Validate parameters
        validation_result = self.validate_parameters(tool_name, kwargs)
        if not validation_result["valid"]:
            return {
                "error": f"参数验证失败: {validation_result['error']}",
                "success": False
            }

        call_record = {
            "tool_name": tool_name,
            "parameters": kwargs,
            "timestamp": time.time()
        }

        try:
            result = self.executors[tool_name](**kwargs)
        except Exception as e:
            call_record["error"] = str(e)
            call_record["success"] = False
            self.call_history.append(call_record)

            return {
                "error": str(e),
                "success": False
            }

        call_record["result"] = result
        # Executors conventionally return a dict with "success"; tolerate other return types.
        call_record["success"] = result.get("success", True) if isinstance(result, dict) else True
        self.call_history.append(call_record)

        return result

    def validate_parameters(self, tool_name, parameters):
        """Check required parameters, declared JSON types and ``enum`` values.

        Returns ``{"valid": True}`` or ``{"valid": False, "error": ...}``.
        Parameters not declared in the schema are not checked.
        """

        schema = self.tools[tool_name]["input_schema"]

        # Required parameters
        for param in schema.get("required", []):
            if param not in parameters:
                return {
                    "valid": False,
                    "error": f"缺少必需参数: {param}"
                }

        # Types and allowed values
        properties = schema.get("properties", {})
        for param_name, param_value in parameters.items():
            spec = properties.get(param_name)
            if spec is None:
                continue

            expected_type = spec.get("type")
            if not self.check_parameter_type(param_value, expected_type):
                return {
                    "valid": False,
                    "error": f"参数 {param_name} 类型错误,期望 {expected_type}"
                }

            allowed_values = spec.get("enum")
            if allowed_values is not None and param_value not in allowed_values:
                return {
                    "valid": False,
                    "error": f"参数 {param_name} 的取值必须是 {allowed_values} 之一"
                }

        return {"valid": True}

    def check_parameter_type(self, value, expected_type):
        """Return True if *value* matches JSON-schema *expected_type*.

        Unknown or missing types are accepted. ``bool`` is rejected for
        "integer"/"number", even though Python's bool is an int subclass.
        """

        type_mapping = {
            "string": str,
            "number": (int, float),
            "integer": int,
            "boolean": bool,
            "object": dict,
            "array": list
        }

        if expected_type in ("integer", "number") and isinstance(value, bool):
            return False

        expected_python_type = type_mapping.get(expected_type)
        if expected_python_type:
            return isinstance(value, expected_python_type)

        return True

与Claude集成

完整的工具调用流程
def create_tool_enabled_conversation():
    """Build a ToolCallManager with the calculator and weather tools registered.

    Returns:
        A ``(tools, tool_manager)`` pair. ``tools`` is the list of registered
        tool definitions, suitable for the Messages API ``tools=`` argument.
        ``tool_manager`` is the manager that executes them.
    """

    manager = ToolCallManager()

    registrations = [
        (define_calculator_tool(), execute_calculator),
        (define_weather_tool(), WeatherAPI("your_api_key").get_weather),
    ]
    for definition, executor in registrations:
        manager.register_tool(definition, executor)

    return list(manager.tools.values()), manager

def handle_tool_conversation(user_message):
    """Answer *user_message* with Claude, executing any tool calls it requests.

    The API key is read from ``ANTHROPIC_API_KEY``; the original placeholder
    is used when that variable is unset.
    """

    import os

    tools, tool_manager = create_tool_enabled_conversation()

    client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY", "your-key"))

    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=2048,
        tools=tools,
        messages=[
            {
                "role": "user",
                "content": user_message
            }
        ]
    )

    # Run the tool loop, replaying the real user message as the first turn.
    return process_tool_calls(response, tool_manager, client, tools, user_message=user_message)


def _response_text(response):
    """Concatenate the text of every ``text`` block in a Messages API response."""
    return "".join(block.text for block in response.content if block.type == "text")


def process_tool_calls(response, tool_manager, client, tools,
                       user_message="请帮我处理这个请求", max_rounds=5):
    """Drive the tool-use loop until Claude answers without requesting tools.

    In each round:
    1. Every ``tool_use`` block in the latest response is executed via
       ``tool_manager``.
    2. All results are returned in ONE user message of ``tool_result``
       blocks, as the API expects. A result dict with ``success: False`` is
       flagged with ``is_error``.
    3. Claude is called again with the accumulated conversation.

    Args:
        response: The Messages API response to the initial user turn.
        tool_manager: Object exposing ``execute_tool_call(name, **input)``.
        client: Anthropic client used for follow-up requests.
        tools: Tool definitions passed on every request.
        user_message: The original user turn. The default keeps the old
            placeholder for callers that do not pass it.
        max_rounds: Upper bound on tool-execution rounds.

    Returns:
        The joined text of the final response. If ``max_rounds`` is exhausted,
        this is the text of the last response received.
    """

    messages = [{"role": "user", "content": user_message}]

    for _ in range(max_rounds):
        tool_calls = [block for block in response.content if block.type == "tool_use"]
        if not tool_calls:
            return _response_text(response)

        messages.append({"role": "assistant", "content": response.content})

        tool_results = []
        for tool_call in tool_calls:
            tool_result = tool_manager.execute_tool_call(tool_call.name, **tool_call.input)
            result_block = {
                "type": "tool_result",
                "tool_use_id": tool_call.id,
                "content": json.dumps(tool_result, ensure_ascii=False)
            }
            if isinstance(tool_result, dict) and tool_result.get("success") is False:
                result_block["is_error"] = True
            tool_results.append(result_block)

        messages.append({"role": "user", "content": tool_results})

        response = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=2048,
            tools=tools,
            messages=messages
        )

    return _response_text(response)

错误处理机制

错误分类和处理

错误类型定义
class ToolError(Exception):
    """Base class for tool failures.

    Subclasses ``Exception`` so instances can actually be raised and caught;
    ``str(err)`` returns the message.
    """

    def __init__(self, message, error_type="general", tool_name=None):
        import time

        super().__init__(message)
        self.message = message
        self.error_type = error_type  # "general", "validation", "execution", "network" or "security"
        self.tool_name = tool_name
        self.timestamp = time.time()  # creation time, seconds since the epoch

class ValidationError(ToolError):
    """Invalid tool parameters; *parameter* names the offending field, if known."""

    def __init__(self, message, parameter=None, tool_name=None):
        super().__init__(message, "validation", tool_name)
        self.parameter = parameter

class ExecutionError(ToolError):
    """A tool raised while running; *exception* keeps the original exception, if any."""

    def __init__(self, message, exception=None, tool_name=None):
        super().__init__(message, "execution", tool_name)
        self.exception = exception

class NetworkError(ToolError):
    """An HTTP/network failure; *status_code* is the HTTP status, when one was received."""

    def __init__(self, message, status_code=None, tool_name=None):
        super().__init__(message, "network", tool_name)
        self.status_code = status_code

class SecurityError(ToolError):
    """A request blocked by a security check; *security_issue* identifies which one."""

    def __init__(self, message, security_issue=None, tool_name=None):
        super().__init__(message, "security", tool_name)
        self.security_issue = security_issue
错误处理策略
class ErrorHandler:
    """Map tool errors to structured, model-readable error dicts.

    Errors are dispatched on their ``error_type`` attribute. Anything without
    a dedicated strategy, including plain exceptions with no ``error_type``,
    falls back to ``handle_general_error``.
    """

    def __init__(self):
        self.error_strategies = {
            "validation": self.handle_validation_error,
            "execution": self.handle_execution_error,
            "network": self.handle_network_error,
            "security": self.handle_security_error
        }

        # Retry budget per error category. Only "max_retries" is consulted in this
        # class; "backoff" is carried along for callers.
        self.retry_settings = {
            "network": {"max_retries": 3, "backoff": 2},
            "execution": {"max_retries": 1, "backoff": 1}
        }

    def handle_error(self, error):
        """Dispatch *error* to the strategy for its ``error_type``."""

        error_type = getattr(error, "error_type", "general")
        strategy = self.error_strategies.get(error_type, self.handle_general_error)

        return strategy(error)

    def handle_general_error(self, error):
        """Fallback for unknown error types and for plain exceptions."""

        return {
            "error": getattr(error, "message", str(error)),
            "error_type": getattr(error, "error_type", "general"),
            "suggestion": "发生未知错误,请检查工具输入或稍后重试",
            "recoverable": False,
            "success": False
        }

    def handle_validation_error(self, error):
        """Validation errors are always recoverable by fixing the input."""

        return {
            "error": error.message,
            "error_type": "validation",
            "suggestion": "请检查参数格式和必需字段",
            "recoverable": True,
            "success": False
        }

    def handle_execution_error(self, error):
        """Recommend a retry when the "execution" retry budget allows one."""

        # The old check compared error.tool_name against the settings' keys
        # ("max_retries"/"backoff"), which could never match a real tool name.
        max_retries = self.retry_settings.get("execution", {}).get("max_retries", 0)
        if max_retries > 0:
            return {
                "error": error.message,
                "error_type": "execution",
                "suggestion": "工具执行失败,可以尝试重新执行",
                "recoverable": True,
                "retry_recommended": True,
                "success": False
            }

        return {
            "error": error.message,
            "error_type": "execution",
            "suggestion": "工具执行失败,请检查输入参数",
            "recoverable": False,
            "success": False
        }

    def handle_network_error(self, error):
        """Classify by HTTP status: 5xx and 429 are retryable, 401/403 are not."""

        status_code = getattr(error, "status_code", None)

        if status_code in [500, 502, 503, 504]:
            # Server-side error: retryable
            return {
                "error": error.message,
                "error_type": "network",
                "suggestion": "网络服务暂时不可用,请稍后重试",
                "recoverable": True,
                "retry_recommended": True,
                "retry_delay": 5,
                "success": False
            }
        elif status_code == 429:
            # Rate limited: retryable after a longer pause
            return {
                "error": error.message,
                "error_type": "network",
                "suggestion": "请求过于频繁,请稍后重试",
                "recoverable": True,
                "retry_recommended": True,
                "retry_delay": 60,
                "success": False
            }
        elif status_code in [401, 403]:
            # Authentication error: not retryable
            return {
                "error": error.message,
                "error_type": "network",
                "suggestion": "API认证失败,请检查API密钥",
                "recoverable": False,
                "success": False
            }
        else:
            return {
                "error": error.message,
                "error_type": "network",
                "suggestion": "网络请求失败,请检查网络连接",
                "recoverable": True,
                "success": False
            }

    def handle_security_error(self, error):
        """Security failures are never recoverable; the details are deliberately not echoed."""

        return {
            "error": "安全检查失败",
            "error_type": "security",
            "suggestion": "请求被安全策略阻止,请检查输入内容",
            "recoverable": False,
            "success": False
        }

重试机制

智能重试策略
class RetryManager:
    """Run callables under named retry policies with optional backoff and jitter."""

    def __init__(self):
        # Policy fields:
        #   max_retries          extra attempts after the first one
        #   base_delay/max_delay seconds; the delay is capped at max_delay
        #   exponential_backoff  double the delay on each attempt when True
        #   jitter               scale the delay by a random factor in [0.5, 1.0)
        self.retry_policies = {
            "default": dict(
                max_retries=3,
                base_delay=1,
                max_delay=60,
                exponential_backoff=True,
                jitter=True,
            ),
            "network": dict(
                max_retries=5,
                base_delay=2,
                max_delay=120,
                exponential_backoff=True,
                jitter=True,
            ),
            "rate_limit": dict(
                max_retries=10,
                base_delay=60,
                max_delay=600,
                exponential_backoff=False,
                jitter=False,
            ),
        }

    def execute_with_retry(self, func, *args, policy="default", **kwargs):
        """Call ``func(*args, **kwargs)``, retrying per the named *policy*.

        Retries happen when:
        - *func* raises, or
        - *func* returns a dict with ``error_type == "network"`` and a truthy
          ``recoverable``.

        Any other result is returned immediately: non-dict values, and dicts
        whose ``success`` is truthy or absent. On the final attempt a
        recoverable network result is returned as-is. If every attempt
        raised, a summary dict whose ``last_error`` describes the last
        exception is returned. Unknown policy names fall back to "default".
        """

        settings = self.retry_policies.get(policy, self.retry_policies["default"])
        limit = settings["max_retries"]
        last_error = None

        for attempt in range(limit + 1):
            more_left = attempt < limit
            try:
                outcome = func(*args, **kwargs)
                is_dict = isinstance(outcome, dict)

                if is_dict and outcome.get("success", True):
                    return outcome

                retryable = (is_dict
                             and outcome.get("error_type") == "network"
                             and outcome.get("recoverable", False))
                if not retryable:
                    return outcome

                last_error = outcome
                if not more_left:
                    return outcome
                time.sleep(self.calculate_delay(attempt, settings))

            except Exception as exc:
                last_error = {
                    "error": str(exc),
                    "error_type": "execution",
                    "success": False
                }
                if not more_left:
                    break
                time.sleep(self.calculate_delay(attempt, settings))

        # Every attempt raised.
        return {
            "error": f"重试{settings['max_retries']}次后仍然失败",
            "last_error": last_error,
            "success": False
        }

    def calculate_delay(self, attempt, policy):
        """Return the sleep (seconds) before the retry that follows 0-based *attempt*."""

        base = policy["base_delay"]
        delay = base * (2 ** attempt) if policy["exponential_backoff"] else base
        delay = min(delay, policy["max_delay"])

        if not policy["jitter"]:
            return delay

        import random
        return delay * (0.5 + random.random() * 0.5)

安全性考虑

输入验证和清理

安全验证器
class SecurityValidator:
    def __init__(self):
        self.dangerous_patterns = [
            r'(?i)(drop|delete|truncate)\s+table',
            r'(?i)exec(ute)?\s*\(',
            r'(?i)script\s*>',
            r'(?i)<\s*script',
            r'(?i)javascript:',
            r'(?i)on\w+\s*=',
            r'\.\./|\.\\\.',
            r'(?i)file:///',
            r'(?i)http://localhost',
            r'(?i)127\.0\.0\.1'
        ]
        
        self.sql_injection_patterns = [
            r"('|(\\+\+);|(--(\\+\+);)",
            r"((\%27)|(\'))((\%6F)|o|(\%4F))((\%72)|r|(\%52))",
            r"((\%27)|(\'))union",
            r"exec(\s|\+)+(s|x)p\w+",
            r"union\s+select",
            r"insert\s+into",
            r"delete\s+from"
        ]
    
    def validate_input(self, input_data, input_type="general"):
        """验证输入数据"""
        
        if isinstance(input_data, str):
            return self.validate_string_input(input_data, input_type)
        elif isinstance(input_data, dict):
            return self.validate_dict_input(input_data)
        elif isinstance(input_data, list):
            return self.validate_list_input(input_data)
        
        return {"valid": True}
    
    def validate_string_input(self, text, input_type):
        """验证字符串输入"""
        
        # 检查危险模式
        for pattern in self.dangerous_patterns:
            if re.search(pattern, text):
                return {
                    "valid": False,
                    "error": "输入包含潜在危险内容",
                    "security_issue": "dangerous_pattern"
                }
        
        # 对SQL查询进行特殊检查
        if input_type == "sql":
            for pattern in self.sql_injection_patterns:
                if re.search(pattern, text, re.IGNORECASE):
                    return {
                        "valid": False,
                        "error": "检测到潜在的SQL注入",
                        "security_issue": "sql_injection"
                    }
        
        return {"valid": True}
    
    def validate_dict_input(self, data):
        """验证字典输入"""
        
        for key, value in data.items():
            if isinstance(value, str):
                result = self.validate_string_input(value)
                if not result["valid"]:
                    return result
            elif isinstance(value, (dict, list)):
                result = self.validate_input(value)
                if not result["valid"]:
                    return result
        
        return {"valid": True}
    
    def sanitize_input(self, input_data):
        """清理输入数据"""
        
        if isinstance(input_data, str):
            return self.sanitize_string(input_data)
        elif isinstance(input_data, dict):
            return {k: self.sanitize_input(v) for k, v in input_data.items()}
        elif isinstance(input_data, list):
            return [self.sanitize_input(item) for item in input_data]
        
        return input_data
    
    def sanitize_string(self, text):
        """清理字符串"""
        
        # 移除潜在危险字符
        import html
        
        # HTML编码
        sanitized = html.escape(text)
        
        # 移除控制字符
        sanitized = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', sanitized)
        
        return sanitized

权限控制

工具权限管理
class ToolPermissionManager:
    """Role-based access control for tools.

    Users map to a role, roles map to the permissions they grant, and tools
    map to the permissions they require. A user with no assigned role is
    treated as the "guest" role.
    """

    def __init__(self):
        self.permissions = {}        # role -> permissions granted to that role
        self.user_roles = {}         # user_id -> role
        self.tool_requirements = {}  # tool_name -> permissions the tool needs

    def define_tool_permissions(self, tool_name, required_permissions):
        """Record the permissions a user needs in order to call ``tool_name``."""
        self.tool_requirements[tool_name] = required_permissions

    def assign_user_role(self, user_id, role):
        """Assign (or replace) the role of ``user_id``."""
        self.user_roles[user_id] = role

    def define_role_permissions(self, role, permissions):
        """Record the permissions granted to ``role``."""
        self.permissions[role] = permissions

    def _role_and_grants(self, user_id):
        """Return ``(role, granted_permissions)`` for ``user_id`` ("guest" if unassigned)."""
        role = self.user_roles.get(user_id, "guest")
        return role, self.permissions.get(role, [])

    def check_permission(self, user_id, tool_name):
        """Decide whether ``user_id`` may use ``tool_name``.

        Tools with no recorded requirements are allowed.

        Returns:
            ``{"allowed": True}``, or ``{"allowed": False, "missing_permission":
            <first missing permission>, "user_role": <role>}``.
        """
        role, granted = self._role_and_grants(user_id)
        needed = self.tool_requirements.get(tool_name, [])

        absent = [perm for perm in needed if perm not in granted]
        if not absent:
            return {"allowed": True}

        return {
            "allowed": False,
            "missing_permission": absent[0],
            "user_role": role
        }

    def get_allowed_tools(self, user_id):
        """List the tools whose every required permission ``user_id`` holds.

        Tools appear in the order they were defined.
        """
        _, granted = self._role_and_grants(user_id)
        return [
            tool
            for tool, needed in self.tool_requirements.items()
            if all(perm in granted for perm in needed)
        ]

# Usage example
def setup_permissions():
    """Build a ToolPermissionManager loaded with an example policy.

    Roles: "admin" (all five permissions), "user" (read_files and
    network_access) and "guest" (nothing). Each example tool requires exactly
    one permission.
    """
    role_grants = {
        "admin": [
            "read_files", "write_files", "execute_code",
            "network_access", "database_access"
        ],
        "user": ["read_files", "network_access"],
        "guest": [],
    }
    tool_needs = {
        "read_file": ["read_files"],
        "write_file": ["write_files"],
        "database_query": ["database_access"],
        "web_api_call": ["network_access"],
    }

    manager = ToolPermissionManager()
    for role, grants in role_grants.items():
        manager.define_role_permissions(role, grants)
    for tool, needs in tool_needs.items():
        manager.define_tool_permissions(tool, needs)

    return manager

最佳实践

工具设计原则

1. 单一职责原则
# Good example: a tool with one focused responsibility
def define_temperature_converter():
    """Return the tool definition for a C/F/K temperature-unit converter."""

    def unit_field():
        # A fresh dict and list per call, so the two fields share no objects.
        return {"type": "string", "enum": ["C", "F", "K"]}

    schema = {
        "type": "object",
        "properties": {
            "value": {"type": "number"},
            "from_unit": unit_field(),
            "to_unit": unit_field(),
        },
        "required": ["value", "from_unit", "to_unit"],
    }

    return {
        "name": "convert_temperature",
        "description": "转换温度单位",
        "input_schema": schema,
    }

# Example to avoid: a single tool that tries to do too much
def define_complex_tool():
    """Anti-pattern example: a catch-all tool definition. Do not copy this.

    The schema is intentionally omitted; in practice it would be
    overly complex because the tool has no single clear purpose.
    """
    definition = {"name": "do_everything"}  # poor design: no single clear purpose
    definition["description"] = "执行各种操作:计算、网络请求、文件操作等"
    # an over-complicated schema would go here...
    return definition
2. 错误信息清晰化
def execute_with_clear_errors(func, **kwargs):
    """Call ``func(**kwargs)`` and convert failures into structured error dicts.

    Returns whatever ``func`` returns on success. On failure it returns a dict
    with ``error``, ``error_type``, ``suggestion`` and ``success`` (False):

    * ValueError       -> "parameter_error"
    * KeyError         -> "missing_parameter"
    * other Exception  -> "execution_error"
    """
    try:
        return func(**kwargs)
    except ValueError as exc:
        detail = str(exc)
        message = "参数值错误: " + detail
        category = "parameter_error"
        hint = "请检查输入参数的格式和范围"
    except KeyError as exc:
        detail = str(exc)
        message = "缺少必需参数: " + detail
        category = "missing_parameter"
        hint = "请提供参数 " + detail
    except Exception as exc:
        detail = str(exc)
        message = "执行失败: " + detail
        category = "execution_error"
        hint = "请检查输入或稍后重试"

    return {
        "error": message,
        "error_type": category,
        "suggestion": hint,
        "success": False
    }
3. 性能优化
class OptimizedToolExecutor:
    """Tool executor with result caching and optional per-tool rate limiting.

    Subclasses must implement ``execute_tool(tool_name, **kwargs)``, returning
    a dict. Only results whose ``success`` is truthy are cached.
    """

    def __init__(self, *, cache_ttl=300, max_calls_per_window=None, window_seconds=60.0):
        """
        Args:
            cache_ttl: Seconds a successful result stays cached (default 5 minutes).
            max_calls_per_window: Maximum executions per tool within any sliding
                window of ``window_seconds``. ``None`` disables rate limiting.
            window_seconds: Length of the rate-limit window, in seconds.
        """
        self.cache = {}          # cache_key -> {"result", "timestamp", "ttl"}
        self.rate_limiter = {}   # tool_name -> deque of recent call times (time.monotonic)
        self.cache_ttl = cache_ttl
        self.max_calls_per_window = max_calls_per_window
        self.window_seconds = window_seconds

    def execute_with_optimization(self, tool_name, **kwargs):
        """Run ``tool_name`` with ``kwargs``, serving from cache when possible.

        Returns:
            A cached result if a fresh one exists. Otherwise it returns a
            rate-limit error dict (``error_type`` "rate_limit") when the limit
            is hit, or else the result of ``execute_tool``.
        """
        import time

        # Check the cache first; a hit does not count against the rate limit.
        cache_key = self.generate_cache_key(tool_name, kwargs)
        cached = self.cache.get(cache_key)
        if cached is not None:
            if not self.is_cache_expired(cached):
                return cached["result"]
            # Drop the stale entry now instead of keeping it until overwritten.
            del self.cache[cache_key]

        # Enforce the rate limit.
        if not self.check_rate_limit(tool_name):
            return {
                "error": "请求频率过高,请稍后重试",
                "error_type": "rate_limit",
                "success": False
            }

        # Execute the tool.
        result = self.execute_tool(tool_name, **kwargs)

        # Cache successful results only.
        if isinstance(result, dict) and result.get("success", False):
            self.cache[cache_key] = {
                "result": result,
                "timestamp": time.time(),
                "ttl": self.cache_ttl
            }

        return result

    def check_rate_limit(self, tool_name):
        """Return True and record the call if ``tool_name`` may run now.

        Uses a sliding window of ``window_seconds`` per tool. Always returns
        True when ``max_calls_per_window`` is None.
        """
        import time
        from collections import deque

        if self.max_calls_per_window is None:
            return True

        now = time.monotonic()
        recent = self.rate_limiter.setdefault(tool_name, deque())
        # Forget calls that have slid out of the window.
        while recent and now - recent[0] >= self.window_seconds:
            recent.popleft()

        if len(recent) >= self.max_calls_per_window:
            return False

        recent.append(now)
        return True

    def execute_tool(self, tool_name, **kwargs):
        """Run the actual tool. Subclasses must override this."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement execute_tool() to run {tool_name!r}"
        )

    def generate_cache_key(self, tool_name, kwargs):
        """Return a stable hex digest identifying ``tool_name`` + ``kwargs``.

        Keys are sorted, so argument order does not matter. ``kwargs`` must be
        JSON-serialisable, otherwise ``json.dumps`` raises TypeError.
        """
        import hashlib
        import json

        cache_data = {
            "tool": tool_name,
            "params": kwargs
        }

        cache_str = json.dumps(cache_data, sort_keys=True)
        # MD5 serves only as a cache-key fingerprint here, not for security.
        return hashlib.md5(cache_str.encode()).hexdigest()

    def is_cache_expired(self, cache_entry):
        """Return True once more than ``cache_entry["ttl"]`` seconds have elapsed."""
        import time

        return time.time() - cache_entry["timestamp"] > cache_entry["ttl"]
4. 监控和日志
class ToolMonitor:
    """Track tool-call counts, per-error-type failure counts, mean duration and a bounded log."""

    def __init__(self, *, max_logs=1000):
        """
        Args:
            max_logs: Maximum number of entries kept in ``call_logs``. The oldest
                entries are dropped first.
        """
        self.metrics = {
            "total_calls": 0,
            "successful_calls": 0,
            "failed_calls": 0,
            "average_duration": 0,
            "error_rates": {}  # error_type -> number of failures (a count, not a rate)
        }

        self.max_logs = max_logs
        self.call_logs = []

    def log_tool_call(self, tool_name, parameters, result, duration):
        """Record one tool call and update the aggregate metrics.

        Args:
            tool_name: Name of the tool that was called.
            parameters: Arguments passed to the tool. Stored by reference in the log.
            result: Result dict. ``success`` is read (default False), and on
                failure so is ``error_type`` (default "unknown").
            duration: Elapsed time of the call, in seconds.
        """
        import time

        metrics = self.metrics
        metrics["total_calls"] += 1

        if result.get("success", False):
            metrics["successful_calls"] += 1
        else:
            metrics["failed_calls"] += 1
            error_type = result.get("error_type", "unknown")
            metrics["error_rates"][error_type] = (
                metrics["error_rates"].get(error_type, 0) + 1
            )

        # Running mean: rebuild the previous total from the old mean, add this call.
        total_duration = (
            metrics["average_duration"] * (metrics["total_calls"] - 1) +
            duration
        )
        metrics["average_duration"] = total_duration / metrics["total_calls"]

        # Record the detailed log entry.
        self.call_logs.append({
            "timestamp": time.time(),
            "tool_name": tool_name,
            "parameters": parameters,
            "result": result,
            "duration": duration,
            "success": result.get("success", False)
        })

        # Trim in place (oldest first) so existing references to call_logs stay valid.
        excess = len(self.call_logs) - self.max_logs
        if excess > 0:
            del self.call_logs[:excess]

    def get_performance_report(self):
        """Return a summary of the metrics collected so far.

        Returns ``{"message": "暂无调用记录"}`` if nothing has been logged.
        Otherwise it returns total calls, the formatted success rate and mean
        duration, a copy of the per-error-type counts, and the top five error types.
        """
        if self.metrics["total_calls"] == 0:
            return {"message": "暂无调用记录"}

        success_rate = (
            self.metrics["successful_calls"] / self.metrics["total_calls"] * 100
        )

        return {
            "total_calls": self.metrics["total_calls"],
            "success_rate": f"{success_rate:.2f}%",
            "average_duration": f"{self.metrics['average_duration']:.3f}s",
            # Copy, so callers cannot mutate the live metrics through the report.
            "error_breakdown": dict(self.metrics["error_rates"]),
            "most_common_errors": self.get_most_common_errors()
        }

    def get_most_common_errors(self):
        """Return up to five ``(error_type, count)`` pairs, most frequent first.

        Ties keep their first-seen order.
        """
        error_rates = self.metrics["error_rates"]

        if not error_rates:
            return []

        sorted_errors = sorted(
            error_rates.items(),
            key=lambda x: x[1],
            reverse=True
        )

        return sorted_errors[:5]

通过合理的工具集成和外部API调用,可以显著扩展Claude的能力边界,创建更加强大和实用的AI应用系统。


网站公告

今日签到

点亮在社区的每一天
去签到