DQN 强化学习
pytorch比tenserflow简单。
所以我们模仿用tensorflow写的强化学习。
学习资料:
- 本节的全部代码
- Tensorflow 的 100行 DQN 代码
- 我制作的 DQN 动画简介
- 我的 DQN Tensorflow 教程
- 我的 强化学习 教程
- PyTorch 官网
- 论文 Playing Atari with Deep Reinforcement Learning
要点
Torch 是神经网络库, 那么也可以拿来做强化学习, 之前我用另一个强大神经网络库 Tensorflow 来制作了这一个 从浅入深强化学习教程, 你同样也可以用 PyTorch 来实现, 这次我们就举 DQN 的例子, 我对比了我的 Tensorflow DQN 的代码, 发现 PyTorch 写的要简单很多. 如果对 DQN 或者强化学习还没有太多概念, 强烈推荐我的这个DQN动画短片, 让你秒懂DQN. 还有强推这套花了我几个月来制作的强化学习教程!
模块导入和参数设置
这次除了 Torch 自家模块, 我们还要导入 Gym 环境库模块, 如何安装 gym 模块请看这节教程.
- import torch
- import torch.nn as nn
- from torch.autograd import Variable
- import torch.nn.functional as F
- import numpy as np
- import gym
- # 超参数
- BATCH_SIZE = 32
- LR = 0.01 # learning rate
- EPSILON = 0.9 # 最优选择动作百分比
- GAMMA = 0.9 # 奖励递减参数
- TARGET_REPLACE_ITER = 100 # Q 现实网络的更新频率
- MEMORY_CAPACITY = 2000 # 记忆库大小
- env = gym.make('CartPole-v0') # 立杆子游戏
- env = env.unwrapped
- N_ACTIONS = env.action_space.n # 杆子能做的动作
- N_STATES = env.observation_space.shape[0] # 杆子能获取的环境信息数
神经网络
DQN 当中的神经网络模式, 我们将依据这个模式建立两个神经网络, 一个是现实网络 (Target Net), 一个是估计网络 (Eval Net).
- class Net(nn.Module):
- def __init__(self, ):
- super(Net, self).__init__()
- self.fc1 = nn.Linear(N_STATES, 10)
- self.fc1.weight.data.normal_(0, 0.1) # initialization
- self.out = nn.Linear(10, N_ACTIONS)
- self.out.weight.data.normal_(0, 0.1) # initialization
- def forward(self, x):
- x = self.fc1(x)
- x = F.relu(x)
- actions_value = self.out(x)
- return actions_value
DQN体系
简化的 DQN 体系是这样, 我们有两个 net, 有选动作机制, 有存经历机制, 有学习机制.
- class DQN(object):
- def __init__(self):
- # 建立 target net 和 eval net 还有 memory
- def choose_action(self, x):
- # 根据环境观测值选择动作的机制
- return action
- def store_transition(self, s, a, r, s_):
- # 存储记忆
- def learn(self):
- # target 网络更新
- # 学习记忆库中的记忆
- class DQN(object):
- def __init__(self):
- self.eval_net, self.target_net = Net(), Net()
- self.learn_step_counter = 0 # 用于 target 更新计时
- self.memory_counter = 0 # 记忆库记数
- self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # 初始化记忆库
- self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR) # torch 的优化器
- self.loss_func = nn.MSELoss() # 误差公式
- def choose_action(self, x):
- x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))
- # 这里只输入一个 sample
- if np.random.uniform() < EPSILON: # 选最优动作
- actions_value = self.eval_net.forward(x)
- action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax
- else: # 选随机动作
- action = np.random.randint(0, N_ACTIONS)
- return action
- def store_transition(self, s, a, r, s_):
- transition = np.hstack((s, [a, r], s_))
- # 如果记忆库满了, 就覆盖老数据
- index = self.memory_counter % MEMORY_CAPACITY
- self.memory[index, :] = transition
- self.memory_counter += 1
- def learn(self):
- # target net 参数更新
- if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
- self.target_net.load_state_dict(self.eval_net.state_dict())
- self.learn_step_counter += 1
- # 抽取记忆库中的批数据
- sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
- b_memory = self.memory[sample_index, :]
- b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))
- b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int)))
- b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]))
- b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))
- # 针对做过的动作b_a, 来选 q_eval 的值, (q_eval 原本有所有动作的值)
- q_eval = self.eval_net(b_s).gather(1, b_a) # shape (batch, 1)
- q_next = self.target_net(b_s_).detach() # q_next 不进行反向传递误差, 所以 detach
- q_target = b_r + GAMMA * q_next.max(1)[0] # shape (batch, 1)
- loss = self.loss_func(q_eval, q_target)
- # 计算, 更新 eval net
- self.optimizer.zero_grad()
- loss.backward()
- self.optimizer.step()
训练
按照 Qlearning 的形式进行 off-policy 的更新. 我们进行回合制更行, 一个回合完了, 进入下一回合. 一直到他们将杆子立起来很久.
- dqn = DQN() # 定义 DQN 系统
- for i_episode in range(400):
- s = env.reset()
- while True:
- env.render() # 显示实验动画
- a = dqn.choose_action(s)
- # 选动作, 得到环境反馈
- s_, r, done, info = env.step(a)
- # 修改 reward, 使 DQN 快速学习
- x, x_dot, theta, theta_dot = s_
- r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
- r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
- r = r1 + r2
- # 存记忆
- dqn.store_transition(s, a, r, s_)
- if dqn.memory_counter > MEMORY_CAPACITY:
- dqn.learn() # 记忆库满了就进行学习
- if done: # 如果回合结束, 进入下回合
- break
- s = s_
DQN 强化学习的更多相关文章
- 强化学习 - Q-learning Sarsa 和 DQN 的理解
本文用于基本入门理解. 强化学习的基本理论 : R, S, A 这些就不说了. 先设想两个场景: 一. 1个 5x5 的 格子图, 里面有一个目标点, 2个死亡点二. 一个迷宫, 一个出发点, ...
- 强化学习(十二) Dueling DQN
在强化学习(十一) Prioritized Replay DQN中,我们讨论了对DQN的经验回放池按权重采样来优化DQN算法的方法,本文讨论另一种优化方法,Dueling DQN.本章内容主要参考了I ...
- 强化学习(十)Double DQN (DDQN)
在强化学习(九)Deep Q-Learning进阶之Nature DQN中,我们讨论了Nature DQN的算法流程,它通过使用两个相同的神经网络,以解决数据样本和网络训练之前的相关性.但是还是有其他 ...
- 强化学习(十一) Prioritized Replay DQN
在强化学习(十)Double DQN (DDQN)中,我们讲到了DDQN使用两个Q网络,用当前Q网络计算最大Q值对应的动作,用目标Q网络计算这个最大动作对应的目标Q值,进而消除贪婪法带来的偏差.今天我 ...
- 强化学习(九)Deep Q-Learning进阶之Nature DQN
在强化学习(八)价值函数的近似表示与Deep Q-Learning中,我们讲到了Deep Q-Learning(NIPS 2013)的算法和代码,在这个算法基础上,有很多Deep Q-Learning ...
- 强化学习(四)—— DQN系列(DQN, Nature DQN, DDQN, Dueling DQN等)
1 概述 在之前介绍的几种方法,我们对值函数一直有一个很大的限制,那就是它们需要用表格的形式表示.虽说表格形式对于求解有很大的帮助,但它也有自己的缺点.如果问题的状态和行动的空间非常大,使用表格表示难 ...
- 【转】【强化学习】Deep Q Network(DQN)算法详解
原文地址:https://blog.csdn.net/qq_30615903/article/details/80744083 DQN(Deep Q-Learning)是将深度学习deeplearni ...
- 【转载】 强化学习(十一) Prioritized Replay DQN
原文地址: https://www.cnblogs.com/pinard/p/9797695.html ------------------------------------------------ ...
- 【转载】 强化学习(十)Double DQN (DDQN)
原文地址: https://www.cnblogs.com/pinard/p/9778063.html ------------------------------------------------ ...
随机推荐
- 提供SaaS Launchkit,快速定制,一云多端等能力,一云多端将通过小程序云实现
摘要: SaaS加速器的技术中心能力中,将提供SaaS Launchkit,快速定制,一云多端等能力,加速应用上云迁移.降低应用开发和定制的门槛,提升效率.其中非常关键的一云多端能力将通过小程序云实现 ...
- spring boot初步
spring boot介绍 Spring Boot 是由 Pivotal 团队提供的全新框架,其设计目的是用来简化新 Spring 应用的初始搭建以及开发过程. 该框架使用了特定的方式来进行配置,从而 ...
- 从零学React Native之05混合开发
本篇文章,我们主要讨论如何实现Android平台的混合开发. RN给Android端发送消息 首先打开Android Studio, Open工程, 在React Native项目目录下选择andro ...
- laravel 的路由中间件
简介# Laravel 中间件提供了一种方便的机制来过滤进入应用的HTTP请求.例如,Laravel 内置了一个中间件来验证用户的身份认证 , 如果没有通过身份认证,中间件会将用户重定向到登陆界面,但 ...
- Python中进制转换函数的使用
Python中进制转换函数的使用 关于Python中几个进制转换的函数使用方法,做一个简单的使用方法的介绍,我们常用的进制转换函数常用的就是int()(其他进制转换到十进制).bin()(十进制转换到 ...
- python常量和变量
1.1 常量 常量是内存中用于保存固定值的单元,在程序中常量的值不能发生改变:python并没有命名常量,也就是说不能像C语言那样给常量起一个名字. python常量包括:数字.字符串.布尔值.空值: ...
- selenium webdriver学习(八)------------如何操作select下拉框(转)
selenium webdriver学习(八)------------如何操作select下拉框 博客分类: Selenium-webdriver 下面我们来看一下selenium webdriv ...
- mysql 字段名和关键字冲突
用"(`)"将有冲突的字段框起来,,键盘上1边上那个键. 例: SELECT * FROM yun_roleright WHERE right LIKE '%{13}%'; 上面s ...
- iptables 伪装(Masquerading)
「偽装」是一种特殊的SNAT操作:将来自其它电脑的包的来源位址改成自己的位址:请注意,由於入替的来源位址是自动決定的(执行SNAT的主机的IP位址).所以,如果它改变了,仍在持续中的旧连線将会失效.「 ...
- 创建JAVASCRIPT对象3种方法
创建JAVASCRIPT对象3种方法 方法一:直接定义并创建对象实例 var obj = new Object(); //创建对象实例 //添加属性obj.num = 5; //添加属性 o ...