【深度强化学习】(2) Double DQN 模型解析,附Pytorch完整代码
创始人
2024-06-01 08:32:21
0

大家好,今天和大家分享一个深度强化学习算法 DQN 的改进版 Double DQN,并基于 OpenAI 的 gym 环境库完成一个小游戏,完整代码可以从我的 GitHub 中获得:

https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model


1. 算法原理

1.1 DQN 原理回顾

DQN 算法的原理是指导机器人不断与环境交互,理解最佳的行为方式,最终学习到最优的行为策略,机器人与环境的交互过程如下图所示。 

机器人与环境的交互过程是机器人在 t 时刻,采取动作 a_t 并作用于环境,然后环境从 t 时刻状态 s_t 转变到 t+1 时刻状态 s_{t+1},同时奖励函数对 a_t 进行评价得到奖励值 r_t。机器人根据 r_t 不断优化行为轨迹,最终学习到最优的行为策略。

整个强化学习过程可简化为马尔可夫决策过程(MDP)。其中所有状态均具备马尔可夫性。MDP 可用一个五元组 \left \langle S,A,P,R,\gamma \right \rangle 表示,各元素意义如下: 

(1)S 是有限状态集合,即机器人在环境中探索到的所有可能状态。s_t 表示机器人在 t 时刻的状态。

(2)A 是有限动作集合,即智能体根据 s_t 采取的所有可能动作集合。a_t 表示机器人在t 时刻状态采取的行为

(3)P 是状态转移概率,定义如下:P_{s_{t+1}}^a = P[S_{t+1}=s_{t+1}|S_t=s, A_t=a]

(4)R 是奖励函数,即机器人基于 s_t 采取 a_t 后获得的期望奖励,定义如下:R_s^a = E[R_{t+1}|s_{t+1}|S_t=s, A_t=a]

(5)γ 代表折扣因子,值域 [0,1],即未来的期望奖励在当前时刻的价值比例。

MDP 中,价值函数包括状态价值函数和动作价值函数。动作价值函数(Q 函数)表示在策略 \pi 的指导下,根据状态 s,采取行为 a 所获得的期望回报策略 \pi 表示状态到行为的映射,相当于机器人的决策策略,并根据不同的状态选择不同的行为,即 a = \pi (s)。机器人在策略 \pi 的指导下,Q 函数的定义如下:

Q_\pi (s,a)=E_\pi [G_t|S_t=s,A_t=a]

式中 G_t 表示折扣奖励,定义如下:

G_t=R_{t+1} + \gamma R_{t+2} + ... = \sum_{k=t}^{T}\gamma ^{k-t}R(s_k,a_k)

式中 \gamma ^t 随训练过程迭代减小,\gamma ^t 越小表示未来的奖励对当前时刻的奖励影响越小。

Q 函数的贝尔曼方程表示如下:

Q_\pi (s_t,a_t)=E[R(s_t,a_t)+\gamma E_{a_{t+1\sim \pi }}Q_\pi (s_{t+1}, a_{t+1})]

式中,R(s_t,a_t) 表示机器人在 s_t 时采取 a_t 所获得的即时奖励,等式右侧第二项表示机
器人执行策略 \pi 产生的未来累计奖励的期望

通过选取最大动作价值函数求解最优行为策略的公式表示如下: 

a = argmax_{a\in A}Q(s,a)

DQN 算法通过行为的奖励值构造算法训练的标签,并且其中的经验回放(Experience Replay)和目标网络有效的解决了数据相关性和非静态分布的问题。DQN 算法的结构示意图如下。

DQN 的网络结构由目标网络和估计网络组成,这两个网络的结构相同但参数不同估计网络具有最新的网络参数,计算当前状态-动作对的价值,并定期更新目标网络的参数,使其计算目标 Q 值。双网络结构打破了数据之间的相关性,使DQN 学习不同的数据分布。经验回放部分储存了智能体(机器人)的历史行为信息,其中包括多组行为序列对 (s,a,r,s'),即当前时刻状态 s ,行为 a ,奖励 r 以及下一时刻状态 s'。当 DQN 算法更新时,随机从经验池中抽取部分行为序列对进行经验回放,这种方式解决了经验池中数据相关性强和非静态分布导致的模型泛化能力差的问题。

DQN 算法通过贪婪法直接获得目标 Q 值贪婪法通过最大化方式使 Q 值快速向可能的优化目标收敛但易导致过估计Q 值的问题,使模型具有较大的偏差。DQN 算法过估计 Q 值的问题不适合机器人操作行为的研究,采用 Double DQN 算法解耦动作的选择和目标 Q 值的计算,以解决过估计 Q 值的问题。 


1.2 Double DQN 原理

Double  DQN 算法是 DQN 算法的改进版本,解决了 DQN 算法过估计行为价值的问题。DQN 算法中,某一时刻状态为非终止状态时,目标 Q 值的计算公式如下所示:

y_j = r_j + \gamma max_{a'}Q(s_{j+1},a';\theta ')

Double  DQN 算法不直接通过最大化的方式选取目标网络计算的所有可能 Q 值,而是首先通过估计网络选取最大 Q 值对应的动作,公式表示如下: 

a_{max} = argmax_a Q (s_{t+1}, a ; \theta )

然后目标网络根据 a_{max} 计算目标 Q 值,公式表示如下:

y_j = r_j+\gamma Q(s_{j+1},a_{max};\theta ')

最后将上面两个公式结合,目标 Q 值的最终表示形式如下:

y_j = r_j + \gamma Q(s_{j+1},argmax_aQ(s_{t+1,a;\theta });\theta ')

目标是最小化目标函数,即最小化估计 Q 值和目标 Q 值的差值,公式如下:

\delta = |Q(s_t,a_t)-y_t|=|Q(s_t,a_t;\theta ) - (r_t + \gamma Q(S_{t+1},argmax_aQ(s_{t+1},a;\theta );\theta ')) |

结合目标函数,损失函数定义如下:

loss=\begin{Bmatrix} \frac{1}{2}\delta ^2 & for |\delta | \leqslant 1 \\ |\delta |-\frac{1}{2} & otherwize \end{Bmatrix}

Double DQN 的伪代码:

Double DQN 算法结构如下。在 Double DQN 框架中存在两个神经网络模型,分别是训练网络与目标网络。这两个神经网络模型的结构完全相同,但是权重参数不同;每训练一段之间后,训练网络的权重参数才会复制给目标网络。训练时,训练网络用于估计当前的 Q(s_t,a),而目标网络用于估计 max_aQ(s_{t+1},a),这样就能保证真实值 Q_{target}(s_t,a) 的估计不会随着训练网络的不断自更新而变化过快。此外,DQN 还是一种支持离线学习的框架,即通过构建经验池的方式离线学习过去的经验。将均方误差 MSE(Q_{train}, Q_{target}) 作为训练模型的损失函数,通过梯度下降法进行反向传播,对训练模型进行更新;若干轮经验池采样后,再将训练模型的权重赋给目标模型,以此进行 Double DQN 框架下的模型自学习。 


2. 代码实现

模型构建部分的代码如下:

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import collections  # 队列
import random# ----------------------------------- #
#(1)经验回放池
# ----------------------------------- #class ReplayBuffer:def __init__(self, capacity):# 创建一个队列,先进先出,队列长度不变self.buffer = collections.deque(maxlen=capacity)# 填充经验池def add(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))# 随机采样batch组样本数据def sample(self, batch_size):transitions = random.sample(self.buffer, batch_size)# 分别取出这些数据,*获取list中的所有值state, action, reward, next_state, done = zip(*transitions)# 将state变成数组,后面方便计算return np.array(state), action, reward, np.array(next_state), done# 队列的长度def size(self):return len(self.buffer)# ----------------------------------- #
#(2)构造网络,训练网络和目标网络共用该结构
# ----------------------------------- #class Net(nn.Module):def __init__(self, n_states, n_hiddens, n_actions):super(Net, self).__init__()# 只有一个隐含层self.fc1 = nn.Linear(n_states, n_hiddens)self.fc2 = nn.Linear(n_hiddens, n_actions)# 前向传播def forward(self, x):x = self.fc1(x)  # [b,n_states]-->[b,n_hiddens]x = self.fc2(x)  # [b,n_hiddens]-->[b,n_actions]return x# ----------------------------------- #
#(3)模型构建
# ----------------------------------- #class Double_DQN:#(1)初始化def __init__(self, n_states, n_hiddens, n_actions,learning_rate, gamma, epsilon,target_update, device):# 属性分配self.n_states = n_statesself.n_hiddens = n_hiddensself.n_actions = n_actionsself.learning_rate = learning_rateself.gamma = gammaself.epsilon = epsilonself.target_update = target_updateself.device = device# 记录迭代次数self.count = 0# 实例化训练网络self.q_net = Net(self.n_states, self.n_hiddens, self.n_actions)# 实例化目标网络self.target_q_net = Net(self.n_states, self.n_hiddens, self.n_actions)# 优化器,更新训练网络的参数self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.learning_rate)#(2)动作选择def take_action(self, state):# numpy[n_states]-->[1, n_states]-->Tensorstate = torch.Tensor(state[np.newaxis, :])# print('--------------------------')# print(state.shape)# 如果小于贪婪系数就取最大值reward最大的动作if np.random.random() < self.epsilon:# 获取当前状态下采取各动作的rewardactions_value = self.q_net(state)# 获取reward最大值对应的动作索引action = actions_value.argmax().item()# 如果大于贪婪系数就随即探索else:action = np.random.randint(self.n_actions)return action#(3)获取每个状态对应的最大的state_valuedef max_q_value(self, state):# list-->tensor[3]-->[1,3]state = torch.tensor(state, dtype=torch.float).view(1,-1)# 当前状态对应的每个动作的reward的最大值 [1,3]-->[1,11]-->intmax_q = self.q_net(state).max().item()return max_q#(4)网络训练def update(self, transitions_dict):# 当前状态,array_shape=[b,4]states = torch.tensor(transitions_dict['states'], dtype=torch.float)# 当前状态的动作,tuple_shape=[b]==>[b,1]actions = torch.tensor(transitions_dict['actions'], dtype=torch.int64).view(-1,1)# 选择当前动作的奖励, tuple_shape=[b]==>[b,1]rewards = torch.tensor(transitions_dict['rewards'], dtype=torch.float).view(-1,1)# 下一个时刻的状态array_shape=[b,4]next_states = torch.tensor(transitions_dict['next_states'], dtype=torch.float)# 是否到达目标 tuple_shape=[b,1]dones = torch.tensor(transitions_dict['dones'], dtype=torch.float).view(-1,1)# 当前状态[b,4]-->当前状态采取的动作及其奖励[b,2]-->actions中是每个状态下的动作索引# -->当前状态s下采取动作a得到的state_valueq_values = self.q_net(states).gather(1, actions)# 获取动作索引# .max(1)输出tuple每个特征的最大state_value及其索引,[1]获取的每个特征的动作索引shape=[b]max_action = self.q_net(next_states).max(1)[1].view(-1,1)# 下个状态的state_value。下一时刻的状态输入到目标网络,得到每个动作对应的奖励,使用训练出来的action索引选取最优动作max_next_q_values = self.target_q_net(next_states).gather(1, max_action)# 目标网络计算出的,当前状态的state_valueq_targets = rewards + self.gamma * max_next_q_values * (1-dones)# 预测值和目标值的均方误差损失dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))# 梯度清零self.optimizer.zero_grad()# 梯度反传dqn_loss.backward()# 更新训练网络的参数self.optimizer.step()# 更新目标网络参数if self.count % self.target_update == 0:self.target_q_net.load_state_dict(self.q_net.state_dict())  # 更新目标网络# 迭代计数+1self.count += 1

3. 案例实现

我们使用 OpenAI 中的重力摆来验证模型,动作是连续型,代表力矩;状态包含三个;目的是让杆子竖直。

 环境交互和训练代码如下:

import torch
import numpy as np
import gym
from tqdm import tqdm
import matplotlib.pyplot as plt
from parsers import args
from RL_brain import ReplayBuffer, Double_DQN# GPU运算
device = torch.device("cuda") if torch.cuda.is_available() \else torch.device("cpu")# ------------------------------- #
#(1)加载环境
# ------------------------------- #env = gym.make("Pendulum-v1", render_mode="human")
n_states = env.observation_space.shape[0]  # 状态数 3
act_low = env.action_space.low  # 最小动作力矩 -2
act_high = env.action_space.high  # 最大动作力矩 +2
n_actions = 11  # 动作是连续的[-2,2],将其离散成11个动作# 确定离散动作区间后,确定其连续动作
def dis_to_con(discrete_action, n_actions):# discrete_action代表动作索引return act_low + (act_high-act_low) * (discrete_action/(n_actions-1))# 实例化经验池
replay_buffer = ReplayBuffer(args.capacity)# 实例化 Double-DQN
agent = Double_DQN(n_states,args.n_hiddens,n_actions,args.lr,args.gamma,args.epsilon,args.target_update,device)# ------------------------------- #
#(2)模型训练
# ------------------------------- #return_list = []  # 记录每次迭代的return,即链上的reward之和
max_q_value = 0  # 最大state_value
max_q_value_list = []  # 保存所有最大的state_valuefor i in range(10):  # 训练几个回合done = False  # 初始,未到达终点state = env.reset()[0]  # 重置环境episode_return = 0  # 记录每回合的returnwith tqdm(total=10, desc='Iteration %d' % i) as pbar:while True:# 状态state时做动作选择,返回索引action = agent.take_action(state)# 平滑处理最大state_valuemax_q_value = agent.max_q_value(state) * 0.005 + \max_q_value * 0.995# 保存每次迭代的最大state_valuemax_q_value_list.append(max_q_value)# 将action的离散索引连续化action_continuous = dis_to_con(action, n_actions)# 环境更新next_state, reward, done, _, _ = env.step(action_continuous)# 添加经验池replay_buffer.add(state, action, reward, next_state, done)# 更新状态state = next_state# 更新每回合的回报episode_return += reward# 如果经验池超数量过阈值时开始训练if replay_buffer.size() > args.min_size:# 在经验池中随机抽样batch组数据s, a, r, ns, d = replay_buffer.sample(args.batch_size)# 构造训练集transitions_dict = {'states': s,'actions': a,'next_states': ns,'rewards': r,'dones': d,}# 模型训练agent.update(transitions_dict)# 到达终点就停止if done is True: break# 保存每回合的returnreturn_list.append(episode_return)pbar.set_postfix({'step':agent.count,'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)# ------------------------------- #
#(3)绘图
# ------------------------------- #plt.subplot(121)
plt.plot(return_list)
plt.title('return')
plt.subplot(122)
plt.plot(max_q_value_list)
plt.title('max_q_value')
plt.show()

相关内容

热门资讯

122.(leaflet篇)l... 听老人家说:多看美女会长寿 地图之家总目录(订阅之前建议先查看该博客) 文章末尾处提供保证可运行...
育碧GDC2018程序化大世界... 1.传统手动绘制森林的问题 采用手动绘制的方法的话,每次迭代地形都要手动再绘制森林。这...
育碧GDC2018程序化大世界... 1.传统手动绘制森林的问题 采用手动绘制的方法的话,每次迭代地形都要手动再绘制森林。这...
Vue使用pdf-lib为文件... 之前也写过两篇预览pdf的,但是没有加水印,这是链接:Vu...
PyQt5数据库开发1 4.1... 文章目录 前言 步骤/方法 1 使用windows身份登录 2 启用混合登录模式 3 允许远程连接服...
Android studio ... 解决 Android studio 出现“The emulator process for AVD ...
Linux基础命令大全(上) ♥️作者:小刘在C站 ♥️个人主页:小刘主页 ♥️每天分享云计算网络运维...
再谈解决“因为文件包含病毒或潜... 前面出了一篇博文专门来解决“因为文件包含病毒或潜在的垃圾软件”的问题,其中第二种方法有...
南京邮电大学通达学院2023c... 题目展示 一.问题描述 实验题目1 定义一个学生类,其中包括如下内容: (1)私有数据成员 ①年龄 ...
PageObject 六大原则 PageObject六大原则: 1.封装服务的方法 2.不要暴露页面的细节 3.通过r...
【Linux网络编程】01:S... Socket多进程 OVERVIEWSocket多进程1.Server2.Client3.bug&...
数据结构刷题(二十五):122... 1.122. 买卖股票的最佳时机 II思路:贪心。把利润分解为每天为单位的维度,然后收...
浏览器事件循环 事件循环 浏览器的进程模型 何为进程? 程序运行需要有它自己专属的内存空间࿰...
8个免费图片/照片压缩工具帮您... 继续查看一些最好的图像压缩工具,以提升用户体验和存储空间以及网站使用支持。 无数图像压...
计算机二级Python备考(2... 目录  一、选择题 1.在Python语言中: 2.知识点 二、基本操作题 1. j...
端电压 相电压 线电压 记得刚接触矢量控制的时候,拿到板子,就赶紧去测各种波形,结...
如何使用Python检测和识别... 车牌检测与识别技术用途广泛,可以用于道路系统、无票停车场、车辆门禁等。这项技术结合了计...
带环链表详解 目录 一、什么是环形链表 二、判断是否为环形链表 2.1 具体题目 2.2 具体思路 2.3 思路的...
【C语言进阶:刨根究底字符串函... 本节重点内容: 深入理解strcpy函数的使用学会strcpy函数的模拟实现⚡strc...
Django web开发(一)... 文章目录前端开发1.快速开发网站2.标签2.1 编码2.2 title2.3 标题2.4 div和s...