多线程环境下的线程安全资源与缓存池设计:ThreadSafeObject 与 CachePool 实例解析

发布于:2025-07-01 ⋅ 阅读:(15) ⋅ 点赞:(0)

ThreadSafeObject 和 CachePool 的作用

✅ ThreadSafeObject

  • 定义:一个带有锁的资源封装容器。

  • 作用

    • 为某个对象加上线程锁(RLock),确保多线程下安全访问。
    • 支持通过 with obj.acquire(): 的方式对资源进行锁保护。
    • 可记录加载状态,防止重复加载。
  • 典型用途

    • 缓存中的模型、数据库连接、会话对象等资源。

✅ CachePool

  • 定义:一个带有线程锁和 LRU 管理机制的缓存池。

  • 作用

    • 管理多个 ThreadSafeObject 实例。
    • 支持 LRU 淘汰策略,限制缓存数量。
    • 提供线程安全的 get/set/pop/acquire 接口。
  • 典型用途

    • 多线程环境下管理多个模型或资源的共享缓存。
    • 避免重复加载大模型 / 数据集 / 知识库等。

🧠 总结对比

项目 作用 特点
ThreadSafeObject 封装一个资源并加锁 控制单个对象的线程访问
CachePool 管理多个加锁的资源 支持 LRU 缓存和线程安全访问管理

案例说明:共享字典的多线程更新器

import threading
import time
from collections import OrderedDict

# 简化版的 ThreadSafeObject
class ThreadSafeObject:
    def __init__(self, key, obj):
        self._key = key
        self._obj = obj
        self._lock = threading.RLock()

    def acquire(self):
        return self._lock

    def get(self):
        with self._lock:
            return self._obj

    def set(self, value):
        with self._lock:
            self._obj = value

# 简化版的 CachePool
class CachePool:
    def __init__(self, max_size=3):
        self._cache = OrderedDict()
        self._max_size = max_size
        self._lock = threading.RLock()

    def set(self, key, obj):
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
            self._cache[key] = obj
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)  # LRU 淘汰
            print(f"[CachePool] Cache keys: {list(self._cache.keys())}")

    def get(self, key):
        with self._lock:
            return self._cache.get(key)

# 示例:共享计数器,多个线程安全递增
def increment(pool: CachePool, key: str):
    for _ in range(1000):
        item = pool.get(key)
        if item:
            with item.acquire():
                current = item.get()
                item.set(current + 1)

if __name__ == "__main__":
    pool = CachePool()
    pool.set("counter", ThreadSafeObject("counter", 0))

    threads = [threading.Thread(target=increment, args=(pool, "counter")) for _ in range(5)]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    print(f"Final counter value: {pool.get('counter').get()}")

💡 解读:

•	ThreadSafeObject 保证一个线程在访问这个“计数器”的时候,其他线程不能干扰。
•	CachePool 管理多个这样的对象(这里只用了一个),并支持 LRU 清除老的缓存。
•	即便多个线程同时访问 counter,最终值仍然是精确的 5000,说明线程安全生效。

网站公告

今日签到

点亮在社区的每一天
去签到