观察者模式与hook机制的联系

发布于:2024-09-17 ⋅ 阅读:(7) ⋅ 点赞:(0)

设计模式之观察者模式[C++版本]

感谢大佬文章的详细讲解

https://blog.csdn.net/leonardohaig/article/details/120187956

  1. 观察者模式需要注意的点
  • 观察者模式内部包含一个观察者类和一个被观察
  • 观察者类内部包含被观察者类对象,被观察者类内部包含观察者对象的列表
  • 观察者模式中的观察者一般会被主题维护一份引用列表。在设计时需要考虑观察者的生命周期,避免因为主题持有强引用导致观察者无法被垃圾回收。常见的做法是使用弱引用来存储观察者。
  • 提供取消订阅机制:观察者需要具备注册和取消注册的机制。确保观察者能够在不需要继续观察时及时退出,以避免不必要的通知和资源浪费
  1. C++代码示例
#include <bits/stdc++.h>

//
//观察者模式
//

class Observer;
//抽象被观察者
class Subject {
public:
    Subject() : m_nState(0) {}

    virtual ~Subject() = default;

    virtual void Attach(const std::shared_ptr<Observer> pObserver) = 0;  //在被观察者中注册观察者对象

    virtual void Detach(const std::shared_ptr<Observer> pObserver) = 0;  //从被观察者中移除观察者对象

    virtual void Notify() = 0;                                           //负责将被观察者的状态传递给目前的观察者列表中的观察者

    virtual int GetState() { return m_nState; }

    void SetState(int state)
    {
      std::cout << "Subject updated !" << std::endl;
      m_nState = state;
    }

protected:
    std::list<std::shared_ptr<Observer>> m_pObserver_list;               //被观察者列表,可以通过attach函数进行注册
    int m_nState;
};

//抽象观察者
class Observer
{
public:
    virtual ~Observer() = default;

    Observer(const std::shared_ptr<Subject> pSubject, const std::string &name = "unknown") : m_pSubject(pSubject), m_strName(name) {}

    virtual void Update() = 0;

    virtual const std::string &name() { return m_strName; }

protected:
    std::shared_ptr<Subject> m_pSubject;
    std::string m_strName;
};

//具体被观察者
class ConcreteSubject : public Subject
{
public:
    void Attach(const std::shared_ptr<Observer> pObserver) override
    {
      auto iter = std::find(m_pObserver_list.begin(), m_pObserver_list.end(), pObserver);
      if (iter == m_pObserver_list.end())
      {
        std::cout << "Attach observer" << pObserver->name() << std::endl;
        m_pObserver_list.emplace_back(pObserver);
      }

    }

    void Detach(const std::shared_ptr<Observer> pObserver) override
    {
      std::cout << "Detach observer" << pObserver->name() << std::endl;
      m_pObserver_list.remove(pObserver);
    }

    //循环通知所有观察者
    void Notify() override
    {
      auto it = m_pObserver_list.begin();
      while (it != m_pObserver_list.end())
      {
        (*it++)->Update();
      }
    }
};


//具体观察者1
class Observer1 : public Observer
{
public:
    Observer1(const std::shared_ptr<Subject> pSubject, const std::string &name = "unknown"): Observer(pSubject, name) {}

    void Update() override
    {
      std::cout << "Observer1_" << m_strName << " get the update.New state is: "
                << m_pSubject->GetState() << std::endl;
    }
};

//具体观察者2
class Observer2 : public Observer
{
public:
    Observer2(const std::shared_ptr<Subject> pSubject, const std::string &name = "unknown") : Observer(pSubject, name) {}

    void Update() override
    {
      std::cout << "Observer2_" << m_strName << " get the update.New state is: "
                << m_pSubject->GetState() << std::endl;
    }
};


int main() {
  // 创建被观察者
  std::shared_ptr<Subject> SubjectObject = std::make_shared<ConcreteSubject>();

  // 创建观察者
  std::shared_ptr<Observer> ObserverObject1_1 = std::make_shared<Observer1>(SubjectObject, "1");
  std::shared_ptr<Observer> ObserverObject1_2 = std::make_shared<Observer1>(SubjectObject, "2");
  std::shared_ptr<Observer> ObserverObject1_3 = std::make_shared<Observer1>(SubjectObject, "3");

  std::shared_ptr<Observer> ObserverObject2_1 = std::make_shared<Observer2>(SubjectObject, "4");
  std::shared_ptr<Observer> ObserverObject2_2 = std::make_shared<Observer2>(SubjectObject, "5");
  std::shared_ptr<Observer> ObserverObject2_3 = std::make_shared<Observer2>(SubjectObject, "6");

  // 注册观察者
  SubjectObject->Attach(ObserverObject1_1);
  SubjectObject->Attach(ObserverObject1_2);
  SubjectObject->Attach(ObserverObject1_3);
  SubjectObject->Attach(ObserverObject2_1);
  SubjectObject->Attach(ObserverObject2_2);
  SubjectObject->Attach(ObserverObject2_3);

  SubjectObject->SetState(2);// 改变状态
  SubjectObject->Notify();

  std::cout << std::string(50, '-') << std::endl;

  // 注销观察者
  SubjectObject->Detach(ObserverObject1_1);
  SubjectObject->Detach(ObserverObject2_1);

  SubjectObject->SetState(3);
  SubjectObject->Notify();

  return 0;
}

输出

Attach observer1
Attach observer2
Attach observer3
Attach observer4
Attach observer5
Attach observer6
Subject updated !
Observer1_1 get the update.New state is: 2
Observer1_2 get the update.New state is: 2
Observer1_3 get the update.New state is: 2
Observer2_4 get the update.New state is: 2
Observer2_5 get the update.New state is: 2
Observer2_6 get the update.New state is: 2
--------------------------------------------------
Detach observer1
Detach observer4
Subject updated !
Observer1_2 get the update.New state is: 3
Observer1_3 get the update.New state is: 3
Observer2_5 get the update.New state is: 3
Observer2_6 get the update.New state is: 3

创建观察者和被观察者。然后将观察者注册进被观察者的观察者列表。被观察者内部有一个"当前状态"为所有观察者想要得到的消息。因此观察者内部需要一个函数用来告诉每一个观察者目前的状态。

pytorch和mmdetection中的hook机制正是设计者模式的体现

pytorch中的hook机制介绍

在 PyTorch 中,hook 机制提供了一种在模型的前向传播和反向传播过程中插入自定义操作的方式。这可以帮助开发者在不改变模型本身的情况下,检查或修改模型的输入、输出和梯度信息。主要有两类 hook:前向传播 hook 和 反向传播 hook。

  1. 前向传播 hook (Forward Hook)
    前向传播 hook 可以在模块的前向传播阶段插入自定义操作,允许开发者访问和修改该模块的输入和输出。它通常用于调试和检查模型的行为。
  • 使用场景:检查每层的输入、输出,或者在中间层提取特征。
  • 用法:register_forward_hook(hook_fn):在模块的前向传播阶段注册 hook。每次调用该模块时,hook 会被触发。
  • hook 函数格式:
def hook_fn(module, input, output):
    # module: 当前层的模块
    # input: 输入张量(元组形式)
    # output: 输出张量
    # 可以在此函数中对输入或输出进行修改
  • 示例
import torch
import torch.nn as nn

# 定义一个简单的模型
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 2)
)

# 定义一个 hook 函数
def forward_hook(module, input, output):
    print(f"Layer: {module}")
    print(f"Input: {input}")
    print(f"Output: {output}")

# 给第一层注册 forward hook
handle = model[0].register_forward_hook(forward_hook)

# 前向传播
x = torch.randn(1, 10)
out = model(x)

# 移除 hook
handle.remove()
  1. 反向传播 hook (Backward Hook)
    反向传播 hook 允许在反向传播过程中插入自定义操作。主要用于检查或修改梯度(gradients)。
  • 使用场景:分析每层的梯度变化,或在反向传播时对梯度进行操作(如梯度裁剪、正则化等)。
  • 用法:register_backward_hook(hook_fn):在模块的反向传播阶段注册 hook。
  • 注意:从 PyTorch 0.4 版本开始,register_backward_hook 已经不推荐使用,取而代之的是 register_full_backward_hook。
    hook 函数格式:
def hook_fn(module, grad_input, grad_output):
    # module: 当前层的模块
    # grad_input: 该层输入的梯度
    # grad_output: 该层输出的梯度
    # 可以在此函数中对梯度进行修改
  • 示例
def backward_hook(module, grad_input, grad_output):
    print(f"Layer: {module}")
    print(f"Grad Input: {grad_input}")
    print(f"Grad Output: {grad_output}")

# 给第一层注册 backward hook
handle = model[0].register_full_backward_hook(backward_hook)

# 前向传播与反向传播
x = torch.randn(1, 10, requires_grad=True)
out = model(x)
out.sum().backward()

# 移除 hook
handle.remove()
  1. 梯度 hook (Tensor Hook)
    除了在模块上注册 hook 之外,也可以在 张量 上注册 hook 来直接操作梯度。这种方式更灵活,主要用于反向传播时获取或修改某个张量的梯度。
  • 使用场景:在反向传播时监控或调整特定张量的梯度。
  • 用法:register_hook(hook_fn):为一个张量注册 hook。
  • hook 函数格式:
def hook_fn(grad):
    # grad: 当前张量的梯度
    # 可以对梯度进行操作或监控
  • 示例
# 定义一个张量
x = torch.randn(3, 3, requires_grad=True)

# 注册 hook 来监控梯度
def tensor_hook(grad):
    print(f"Gradient: {grad}")

handle = x.register_hook(tensor_hook)

# 前向传播与反向传播
y = x * 2
y.sum().backward()
  • 总结
    前向传播 hook:用于监控和修改模型每层的输入输出。
    反向传播 hook:用于监控和修改每层的梯度。
    张量 hook:用于在反向传播过程中直接监控或修改张量的梯度。
    这些 hook 机制非常适合用于调试、特征提取、模型解释或自定义优化策略。

mmcv hook机制

参考下文

https://zhuanlan.zhihu.com/p/355272220


网站公告

今日签到

点亮在社区的每一天
去签到