文章目录
前言
欢迎来到深度强化学习的世界!如果你对 Q-learning 有所了解,你可能会知道它使用一个表格(Q-table)来存储每个状态-动作对的价值。然而,当状态空间变得巨大,甚至是连续的时候(比如一个小车在轨道上的位置),Q-table 就变得不切实际。这时,深度Q网络(Deep Q-Network, DQN)就闪亮登场了。
DQN 的核心思想是用一个神经网络来代替 Q-table,实现从状态到(各个动作的)Q值的映射。这使得我们能够处理具有连续或高维状态空间的环境。本文将以经典的 CartPole-v1
环境为例,通过一个完整的 PyTorch 代码实现,带你深入理解 DQN 的工作原理及其关键组成部分:神经网络近似、经验回放和目标网络。
图 1 CartPole环境示意图
在 CartPole 环境中,智能体的任务是左右移动小车,以保持车上的杆子竖直不倒。这个环境的状态是连续的(车的位置、速度、杆的角度、角速度),而动作是离散的(向左或向右)。这正是DQN大显身手的完美场景。
让我们一起通过代码,揭开DQN的神秘面纱。
完整代码:下载链接
DQN 算法核心思想
在深入代码之前,我们先回顾一下 DQN 的几个关键概念。
Q-Learning 与函数近似
传统的 Q-learning 更新规则如下:
Q ( s , a ) ← Q ( s , a ) + α [ r + γ max a ′ ∈ A Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s^{\prime},a^{\prime})-Q(s,a)\right] Q(s,a)←Q(s,a)+α[r+γa′∈AmaxQ(s′,a′)−Q(s,a)]
当状态是连续的,我们无法用表格记录所有 Q(s,a)
。因此,我们引入一个带参数 w
的神经网络,即 Q-网络 Q ω ( s , a ) Q_\omega\left(s,a\right) Qω(s,a),来近似真实的 Q-函数。我们的目标是让网络预测的Q值 Q ω ( s , a ) Q_\omega\left(s,a\right) Qω(s,a) 逼近“目标Q值” r + γ max a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s',a') r+γmaxa′∈AQ(s′,a′)。
为此,我们可以定义一个损失函数,最常见的就是均方误差(MSE Loss):
ω ∗ = arg min ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) − ( r i + γ max a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_\omega\frac{1}{2N}\sum_{i=1}^N\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a^{\prime}}Q_\omega\left(s_i^{\prime},a^{\prime}\right)\right)\right]^2 ω∗=argωmin2N1i=1∑N[Q