Python 自动注册模式

发布于:2025-03-04 ⋅ 阅读:(17) ⋅ 点赞:(0)

自动注册模式

通过自动注册模式减少 if-else 分支，使逻辑更清晰：每个处理函数用装饰器标记自己负责的消息类型，程序启动时扫描模块，自动建立"消息类型 → 处理函数"的映射，新增消息类型只需再写一个带装饰器的函数。

代码

import inspect
import sys
from typing import Any, Dict, Callable,List
from datetime import datetime


def message_handler(msg_type: str, priority: int = 0):
    """Return a decorator that tags a function as a handler for *msg_type*.

    The decorated function itself is returned (no wrapper). Two attributes
    are attached to it: ``handle_type`` (the message type) and ``priority``
    (used as a descending sort key by ``auto_register_handlers``).
    """

    def register(fn: Callable) -> Callable:
        for attr, value in (("handle_type", msg_type), ("priority", priority)):
            setattr(fn, attr, value)
        return fn

    return register


def auto_register_handlers(module):
    """Collect the functions in *module* that carry a ``handle_type`` tag.

    Args:
        module: The module object to scan, e.g. ``sys.modules[__name__]``.

    Returns:
        A plain dict mapping each message type to a list of handler
        functions, ordered by their ``priority`` attribute (default 0),
        highest first. Handlers with equal priority keep discovery order;
        ``inspect.getmembers`` yields members sorted by name.
    """
    handlers: Dict[str, List[Callable]] = {}
    for _name, obj in inspect.getmembers(module, inspect.isfunction):
        if hasattr(obj, "handle_type"):
            handlers.setdefault(obj.handle_type, []).append(obj)
    # Sort each bucket once after collection instead of after every append.
    # list.sort is stable, so the resulting order is the same as before.
    for bucket in handlers.values():
        bucket.sort(key=lambda fn: getattr(fn, "priority", 0), reverse=True)
    return handlers


class MessageContext:
    """Bundle one incoming message as it flows through middlewares and handlers.

    Attributes:
        msg_type: Key used to look up the handlers for this message.
        content: The message payload, stored as given.
        sender: Name of the sender.
        timestamp: Time the message was received.
    """

    def __init__(self, msg_type: str, content: Any, sender: str, timestamp: datetime):
        self.msg_type, self.content, self.sender, self.timestamp = (
            msg_type,
            content,
            sender,
            timestamp,
        )


# Message handler definitions; each is registered via @message_handler.
@message_handler("text")
def handle_text(ctx: MessageContext) -> str:
    """Render a text message as ``[文本消息] <sender>: <content>``."""
    who, body = ctx.sender, ctx.content
    return f"[文本消息] {who}: {body}"

@message_handler("image", priority=1)
def handle_image(ctx: MessageContext) -> str:
    """Describe a shared image; priority 1 places it before default-priority image handlers."""
    who, picture = ctx.sender, ctx.content
    return f"[图片消息] {who} 分享了图片: {picture}"

@message_handler("image")
def handle_image_like(ctx: MessageContext) -> str:
    """Append a 'got a like' note stamped with the message time (minute precision)."""
    stamp = ctx.timestamp.strftime('%Y-%m-%d %H:%M')
    return f"并且在{stamp}获得了一个点赞。"

@message_handler("location")
def handle_location(ctx: MessageContext) -> str:
    """Describe a shared location as ``[位置消息] <sender> 分享了位置: <content>``."""
    who, place = ctx.sender, ctx.content
    return f"[位置消息] {who} 分享了位置: {place}"


class ChatSystem:
    """Route messages through a middleware chain to auto-registered handlers.

    A middleware is any object with ``before_process(context)`` and
    ``after_process(result, context)`` methods. The ``before_process`` hooks
    run in registration order. If one returns a falsy value, the message is
    rejected at once and no handler or ``after_process`` hook runs.
    Otherwise the handler output is threaded through every ``after_process``
    hook in reverse registration order.
    """

    def __init__(self, module=None):
        """Build the handler table by scanning a module.

        Args:
            module: Module object to scan for ``@message_handler`` functions.
                Defaults to the module this class is defined in, which is
                the original hard-wired behavior. Pass another module to
                register handlers defined there.
        """
        if module is None:
            module = sys.modules[__name__]
        self.handlers: Dict[str, List[Callable]] = auto_register_handlers(module)
        self.middlewares = []
        # (msg_type, context) for each message that at least one handler processed.
        self.message_log = []

    def add_middleware(self, middleware):
        """Append *middleware* to the end of the chain."""
        self.middlewares.append(middleware)

    def process_message(self, msg_type: str, content: Any, sender: str) -> str:
        """Wrap the arguments in a MessageContext and run the full pipeline.

        Returns:
            If a middleware rejects the message, the interception notice.
            Otherwise, the value left after all ``after_process`` hooks.
            That starts as the joined handler output, or as the
            unsupported-type notice when no handler matches.
        """
        context = MessageContext(msg_type, content, sender, datetime.now())

        for middleware in self.middlewares:
            if not middleware.before_process(context):
                return f"{context.content}消息被中间件拦截"

        # Dispatch on context.msg_type rather than the raw argument, so a
        # before_process hook that rewrites the message type is honoured.
        result = self._process_message(context.msg_type, context)

        for middleware in reversed(self.middlewares):
            result = middleware.after_process(result, context)

        return result

    def _process_message(self, msg_type: str, context: MessageContext) -> str:
        """Run every handler registered for *msg_type* and join their output.

        Handlers run in the order stored in ``self.handlers``. Their string
        results are joined with ", ". The message is logged only when a
        non-empty handler list exists for its type.
        """
        handlers = self.handlers.get(msg_type)
        if not handlers:
            return f"不支持的消息类型:{msg_type}"
        outputs = [handler(context) for handler in handlers]
        self.message_log.append((msg_type, context))
        return ", ".join(outputs)


# Logging middleware: prints a line when a message arrives and when it is done.
class LoggingMiddleware:
    """Middleware that prints progress notices and never blocks a message."""

    def before_process(self, context):
        """Print the sender's name and let the message through."""
        notice = f"收到来自 {context.sender} 的消息"
        print(notice)
        return True

    def after_process(self, result, context):
        """Print the processing result and return it unchanged."""
        notice = f"消息处理完成: {result}"
        print(notice)
        return result


# Sensitive-word filter middleware: rejects text messages containing a listed word.
class SensitiveWordMiddleware:
    """Block string messages whose content contains any configured word.

    Args:
        sensitive_words: Iterable of substrings to reject. It is copied into
            a list once, so a one-shot iterable (such as a generator) keeps
            working for every message. Falsy entries (e.g. ``""``) are
            dropped, because ``"" in text`` is always true and would
            otherwise block every message.
    """

    def __init__(self, sensitive_words):
        self.sensitive_words = [word for word in sensitive_words if word]

    def before_process(self, context):
        """Return False (reject) if ``context.content`` is a str containing a
        sensitive word; non-str content is always let through."""
        content = context.content
        if not isinstance(content, str):
            return True
        return not any(word in content for word in self.sensitive_words)

    def after_process(self, result, context):
        """Pass the result through unchanged."""
        return result


# Usage example
if __name__ == "__main__":
    system = ChatSystem()

    # before_process hooks run in this order; after_process hooks in reverse.
    for middleware in (
        LoggingMiddleware(),
        SensitiveWordMiddleware(["敏感词1", "敏感词2"]),
    ):
        system.add_middleware(middleware)

    # (msg_type, content, sender) samples, including an unsupported type and
    # a message that the sensitive-word filter rejects.
    samples = (
        ("text", "你好,世界!", "张三"),
        ("image", "风景照片.jpg", "李四"),
        ("location", "北京市海淀区", "王五"),
        ("video", "视频文件.mp4", "赵六"),
        ("image", "敏感词1.jpg", "钱七"),
    )

    for sample in samples:
        outcome = system.process_message(*sample)
        print(f"结果: {outcome}\n")