Paper: "Neural Lyapunov Function Approximation with Self-Supervised Reinforcement Learning" (SACLA)
🧠 Core principles
1. The Lyapunov function and its derivative
A Lyapunov function $V(x)$ satisfies:
- $V(x) > 0$ for all $x \neq x^*$
- $V(x^*) = 0$
If, in addition, the Lie derivative satisfies $\dot{V}(x) < 0$ for all $x \neq x^*$, the equilibrium $x^*$ is asymptotically stable.
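As a worked instance of these conditions, consider a linear system $\dot{x} = Ax$ with equilibrium $x^* = 0$ and the quadratic candidate $V(x) = x^\top P x$, where $P$ is symmetric positive definite. Both $A$ and $P$ are illustrative assumptions, not taken from the paper. The chain rule gives:

```latex
\dot{V}(x) = \nabla V(x)^\top \dot{x}
           = \dot{x}^\top P x + x^\top P \dot{x}
           = x^\top \left( A^\top P + P A \right) x
```

So $\dot{V}(x) < 0$ for all $x \neq 0$ exactly when $A^\top P + P A$ is negative definite.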
In reinforcement learning, the Lie derivative is approximated with a learned world model $f_\xi$, as the expected one-step change in $V$ under the predicted next state $\hat{x}'$:
$$\dot{V}(x) \approx \mathbb{E}_{\hat{x}' \sim f_\xi(x,u)}\left[ V(\hat{x}') \right] - V(x)$$
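The expectation above can be estimated by sampling. The following is a minimal sketch in plain Python, assuming a hypothetical noisy linear model `sample_next_state` in place of the learned world model $f_\xi$ and the quadratic candidate $V(x) = \lVert x \rVert^2$; neither comes from the paper. The action $u$ is omitted, which amounts to assuming a fixed policy folded into the dynamics.

```python
import random

def V(x):
    """Quadratic Lyapunov candidate V(x) = x1^2 + x2^2 (an assumption for this sketch)."""
    return x[0] ** 2 + x[1] ** 2

def sample_next_state(x, rng, noise_std=0.01):
    """Hypothetical stand-in for the world model: x' = A x + Gaussian noise."""
    a11, a12, a21, a22 = 0.9, 0.1, 0.0, 0.8  # a contracting linear map
    return (
        a11 * x[0] + a12 * x[1] + rng.gauss(0.0, noise_std),
        a21 * x[0] + a22 * x[1] + rng.gauss(0.0, noise_std),
    )

def lyapunov_difference(x, n_samples=1000, seed=0):
    """Monte Carlo estimate of E[V(x')] - V(x)."""
    rng = random.Random(seed)
    total = sum(V(sample_next_state(x, rng)) for _ in range(n_samples))
    return total / n_samples - V(x)

v_dot = lyapunov_difference((1.0, -0.5))
print(f"estimated V_dot: {v_dot:.3f}")  # negative: V decreases along the model
```

For this state the noise-free value is $V(Ax) - V(x) = 0.8825 - 1.25 = -0.3675$, and the sampled estimate should land close to it.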
2. Lyapunov loss (Lyapunov risk)
The loss is defined as:
$$J_V(\psi) = \lambda_1 \cdot \mathbb{E}_x\left[ V_\psi(x)^2 \cdot \mathbb{1}_{x = x^*} \right] + \lambda_2 \cdot \mathbb{E}_{(x,u)}\left[ \max\left(\dot{V}(x,u),\, 0\right) \right]$$
The first term pushes $V_\psi(x^*)$ toward zero, and the second penalizes every state-action pair for which $V$ fails to decrease. The loss is used to train the network parameters $\psi$.
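A small numeric example may make the two terms concrete. The sketch below evaluates the risk in plain Python on a hand-picked batch, assuming scalar-gain dynamics $x' = a x$ and the candidate $V(x) = \lVert x \rVert^2$ (both hypothetical), and using the one-step difference $V(x') - V(x)$ as the estimate of $\dot{V}$.

```python
def V(x):
    """Candidate V(x) = ||x||^2; satisfies V(0) = 0 by construction."""
    return sum(xi * xi for xi in x)

def lyapunov_risk(states, gain, x_star=(0.0, 0.0), lambda_1=1.0, lambda_2=1.0):
    """Equilibrium term plus mean hinge penalty on the one-step change in V."""
    equilibrium_term = V(x_star) ** 2
    violations = []
    for x in states:
        x_next = tuple(gain * xi for xi in x)  # hypothetical dynamics x' = gain * x
        v_dot = V(x_next) - V(x)
        violations.append(max(v_dot, 0.0))     # only increases in V are penalized
    return lambda_1 * equilibrium_term + lambda_2 * sum(violations) / len(violations)

batch = [(1.0, 0.0), (0.0, 2.0)]
stable_risk = lyapunov_risk(batch, gain=0.9)    # V shrinks each step, so no penalty
unstable_risk = lyapunov_risk(batch, gain=1.1)  # V grows by a factor of 1.21 each step
print(stable_risk, unstable_risk)
```

With `gain=0.9` the hinge term vanishes; with `gain=1.1` the mean violation is $0.21 \cdot (1 + 4) / 2 = 0.525$.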
3. Policy objective
The reinforcement learning objective is augmented with a regularization term:
$$J_\pi = \mathbb{E}[R] + \beta \cdot J_V(\psi)$$
It is used to train the policy parameters $\phi$.
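How the regularizer changes the selected policy can be seen on a toy problem. The sketch below is an illustration only, not the paper's training procedure: it assumes a scalar system $x' = 1.2x + u$, a linear policy $u = -kx$, a reward $-u^2$ that favors low control effort, the candidate $V(x) = x^2$, and a sign convention in which the policy minimizes $-\mathbb{E}[R] + \beta J_V$. The formula above does not state whether $J_\pi$ is maximized or minimized, so that convention is an assumption.

```python
def policy_loss(k, states, beta):
    """Loss for the linear policy u = -k x: negative mean reward plus beta * Lyapunov risk."""
    rewards, violations = [], []
    for x in states:
        u = -k * x
        x_next = 1.2 * x + u                # hypothetical open-loop-unstable dynamics
        rewards.append(-u * u)              # reward penalizes control effort only
        v_dot = x_next ** 2 - x ** 2        # one-step change of V(x) = x^2
        violations.append(max(v_dot, 0.0))  # hinge: only increases in V count
    mean_reward = sum(rewards) / len(rewards)
    risk = sum(violations) / len(violations)
    return -mean_reward + beta * risk

states = [-1.0, -0.5, 0.5, 1.0]
gains = [i / 10 for i in range(26)]         # candidate gains 0.0, 0.1, ..., 2.5

# Grid search over the gain, without and with the Lyapunov regularizer.
best_without = min(gains, key=lambda k: policy_loss(k, states, beta=0.0))
best_with = min(gains, key=lambda k: policy_loss(k, states, beta=10.0))
print(best_without, best_with)
```

Without the regularizer the search settles on $k = 0$, which leaves the unstable system uncontrolled. With $\beta = 10$ it moves to $k = 0.2$, the smallest gain on the grid for which $V$ no longer increases. Because the hinge only penalizes $\dot{V} > 0$, the boundary case $\dot{V} = 0$ incurs no penalty.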
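The code below covers:
- Building the Lyapunov function network
- Estimating the system dynamics with a world model
- Lyapunov loss definition and policy optimization
- Visualizing the Lyapunov surface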
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Set the random seed for reproducibility
torch.manual_seed(0)
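1. Define the neural Lyapunov function network
class LyapunovNet(nn.Module):
    """Small MLP mapping a state x to a strictly positive scalar V(x)."""

    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        # abs() plus a small offset keeps V(x) > 0 everywhere. Note that this
        # also floors V(x*) at 1e-4, so V(x*) = 0 is only met approximately.
        return torch.abs(self.net(x)) + 1e-4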
2. Define the world model
class WorldModel(nn.Module):
    """MLP that predicts the next state from the current state and action."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 64),
            nn.ReLU(),
            nn.Linear(64, state_dim)
        )

    def forward(self, x, u):
        # Concatenate state and action along the feature dimension.
        xu = torch.cat([x, u], dim=-1)
        return self.net(xu)
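3. Define the Lyapunov loss
def lyapunov_loss(V, f_model, x, u, x_star, lambda_1=1.0, lambda_2=1.0):
    # Predicted next state from the world model.
    x_pred = f_model(x, u)
    V_x = V(x)
    V_x_pred = V(x_pred)
    # Equilibrium term: push V(x*) toward zero.
    equilibrium_loss = (V(x_star) ** 2).mean()
    # One-step change in V, used as the Lie-derivative estimate.
    V_dot = V_x_pred - V_x
    # Hinge term: penalize only the states where V fails to decrease.
    lyapunov_violation = torch.clamp(V_dot, min=0.0).mean()
    return lambda_1 * equilibrium_loss + lambda_2 * lyapunov_violation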
4. Visualize the Lyapunov surface
def plot_lyapunov_surface(V, grid_range=1.0, steps=30):
    # Evaluate V on a steps x steps grid over [-grid_range, grid_range]^2.
    x = torch.linspace(-grid_range, grid_range, steps)
    X, Y = torch.meshgrid(x, x, indexing='ij')
    states = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=1)
    Z = V(states).view(steps, steps).detach().numpy()

    # Draw V as a 3D surface over the two state dimensions.
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X.numpy(), Y.numpy(), Z, cmap='viridis')
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    ax.set_zlabel('V(x)')
    ax.set_title('Lyapunov Function Surface')
    plt.show()
5. Example: visualize the Lyapunov function
# Initialize the (untrained) network and plot its surface
state_dim = 2
lyapunov_net = LyapunovNet(state_dim)
plot_lyapunov_surface(lyapunov_net)