DQN(Deep Q-learning)入门教程(四)之Q-learning Play Flappy Bird中,我们使用q-learning算法去对Flappy Bird进行强化学习,而在这篇博客中我们将使用神经网络模型来代替Q-table,关于DQN的介绍,可以参考我前一篇博客:DQN(Deep Q-learning)入门教程(五)之DQN介绍

在这篇博客中将使用DQN做如下操作:

  • Flappy Bird
  • MountainCar-v0

再回顾一下DQN的算法流程:

项目地址:Github

MountainCar-v0

MountainCar的训练好的Gif示意图如下所示,汽车起始位置位于山的底部,最终目标是驶向右边山的插旗的地方,其中,汽车的引擎不能够直接驶向终点,必须借助左边的山体的重力加速度才能够驶向终点。

MountainCar-v0由OpenAI提供,python包为gym,官网网站为https://gym.openai.com/envs/MountainCar-v0/。在Gym包中,提供了很多可以用于强化学习的环境(env):

在MountainCar-v0中,状态有2个变量,car position(汽车的位置),car vel(汽车的速度),action一共有3种 Accelerate to the Left Don't accelerateAccelerate to the Right,然后当车达到旗帜的地方(position = 0.5)会得到\(reward = 1\)的奖励,如果没有达到则为\(-1\)。但是如果当你运行步骤超过200次的时候,游戏就会结束。详情可以参考源代码(ps:官方文档中没有这些说明)。

下面介绍一下gym中几个常用的函数:

    1. env = gym.make("MountainCar-v0")

    这个就是创建一个MountainCar-v0的游戏环境。

    1. state = env.reset()

    重置环境,返回重置后的state

    1. env.render()

    将运行画面展示在屏幕上面,当我们在训练的时候可以不使用这个来提升速度。

    1. next_state, reward, done, _ = env.step(action)

    执行action动作,返回下一个状态,奖励,是否完成,info。

初始化Agent

初始化Agent直接使用代码说明吧,这个还是比较简单的:

  1. import keras
  2. import random
  3. from collections import deque
  4. import gym
  5. import numpy as np
  6. from keras.layers import Dense
  7. from keras.models import Sequential
  8. class Agent():
  9. def __init__(self, action_set, observation_space):
  10. """
  11. 初始化
  12. :param action_set: 动作集合
  13. :param observation_space: 环境属性,我们需要使用它得到state的shape
  14. """
  15. # 奖励衰减
  16. self.gamma = 1.0
  17. # 从经验池中取出数据的数量
  18. self.batch_size = 50
  19. # 经验池
  20. self.memory = deque(maxlen=2000000)
  21. # 探索率
  22. self.greedy = 1.0
  23. # 动作集合
  24. self.action_set = action_set
  25. # 环境的属性
  26. self.observation_space = observation_space
  27. # 神经网路模型
  28. self.model = self.init_netWork()
  29. def init_netWork(self):
  30. """
  31. 构建模型
  32. :return: 模型
  33. """
  34. model = Sequential()
  35. # self.observation_space.shape[0],state的变量的数量
  36. model.add(Dense(64 * 4, activation="tanh", input_dim=self.observation_space.shape[0]))
  37. model.add(Dense(64 * 4, activation="tanh"))
  38. # self.action_set.n 动作的数量
  39. model.add(Dense(self.action_set.n, activation="linear"))
  40. model.compile(loss=keras.losses.mean_squared_error,
  41. optimizer=keras.optimizers.RMSprop(lr=0.001))
  42. return model

我们使用队列来保存经验,这样的话新的数据就会覆盖远古的数据。此时我们定义一个函数,专门用来将数据保存到经验池中,然后定义一个函数用来更新\(\epsilon\)探索率。

  1. def add_memory(self, sample):
  2. self.memory.append(sample)
  3. def update_greedy(self):
  4. # 小于最小探索率的时候就不进行更新了。
  5. if self.greedy > 0.01:
  6. self.greedy *= 0.995

训练模型

首先先看代码:

  1. def train_model(self):
  2. # 从经验池中随机选择部分数据
  3. train_sample = random.sample(self.memory, k=self.batch_size)
  4. train_states = []
  5. next_states = []
  6. for sample in train_sample:
  7. cur_state, action, r, next_state, done = sample
  8. next_states.append(next_state)
  9. train_states.append(cur_state)
  10. # 转成np数组
  11. next_states = np.array(next_states)
  12. train_states = np.array(train_states)
  13. # 得到next_state的q值
  14. next_states_q = self.model.predict(next_states)
  15. # 得到state的预测值
  16. state_q = self.model.predict_on_batch(train_states)
  17. # 计算Q现实
  18. for index, sample in enumerate(train_sample):
  19. cur_state, action, r, next_state, done = sample
  20. if not done:
  21. state_q[index][action] = r + self.gamma * np.max(next_states_q[index])
  22. else:
  23. state_q[index][action] = r
  24. self.model.train_on_batch(train_states, state_q)

大家肯定从上面的代码发现一些问题,使用了两个for循环,why?首先先说一下两个for循环分别的作用:

  • 第一个for循环:得到train_statesnext_states,其中next_states是为了计算Q现实。
  • 第二个for循环:计算Q现实

可能有人会有一个疑问,为什么我不写成一个for循环呢?实际上写成一个for循环是完全没有问题的,很,但是写成一个for循环意味着我们要多次调用model.predict_on_batch,这样会耗费一定的时间(亲身试验过,这样会比较慢),因此,我们写成了两个for循环,然后只需要调用一次predict

执行动作与选择最佳动作

执行动作的代码如下所示:

  1. def act(self, env, action):
  2. """
  3. 执行动作
  4. :param env: 执行环境
  5. :param action: 执行的动作
  6. :return: ext_state, reward, done
  7. """
  8. next_state, reward, done, _ = env.step(action)
  9. if done:
  10. if reward < 0:
  11. reward = -100
  12. else:
  13. reward = 10
  14. else:
  15. if next_state[0] >= 0.4:
  16. reward += 1
  17. return next_state, reward, done

其中,我们可以修改奖励以加快网络收敛。

选择最好的动作的动作如下所示,会以一定的探索率随机选择动作。

  1. def get_best_action(self, state):
  2. if random.random() < self.greedy:
  3. return self.action_set.sample()
  4. else:
  5. return np.argmax(self.model.predict(state.reshape(-1, 2)))

开始训练

关于具体的解释,在注释中已经详细的说明了:

  1. if __name__ == "__main__":
  2. # 训练次数
  3. episodes = 10000
  4. # 实例化游戏环境
  5. env = gym.make("MountainCar-v0")
  6. # 实例化Agent
  7. agent = Agent(env.action_space, env.observation_space)
  8. # 游戏中动作执行的次数(最大为200)
  9. counts = deque(maxlen=10)
  10. for episode in range(episodes):
  11. count = 0
  12. # 重置游戏
  13. state = env.reset()
  14. # 刚开始不立即更新探索率
  15. if episode >= 5:
  16. agent.update_greedy()
  17. while True:
  18. count += 1
  19. # 获得最佳动作
  20. action = agent.get_best_action(state)
  21. next_state, reward, done = agent.act(env, action)
  22. agent.add_memory((state, action, reward, next_state, done))
  23. # 刚开始不立即训练模型,先填充经验池
  24. if episode >= 5:
  25. agent.train_model()
  26. state = next_state
  27. if done:
  28. # 将执行的次数添加到counts中
  29. counts.append(count)
  30. print("在{}轮中,agent执行了{}次".format(episode + 1, count))
  31. # 如果近10次,动作执行的平均次数少于160,则保存模型并退出
  32. if len(counts) == 10 and np.mean(counts) < 160:
  33. agent.model.save("car_model.h5")
  34. exit(0)
  35. break

训练一定的次数后,我们就可以得到模型了。然后进行测试。

模型测试

测试的代码没什么好说的,如下所示:

  1. import gym
  2. from keras.models import load_model
  3. import numpy as np
  4. model = load_model("car_model.h5")
  5. env = gym.make("MountainCar-v0")
  6. for i in range(100):
  7. state = env.reset()
  8. count = 0
  9. while True:
  10. env.render()
  11. count += 1
  12. action = np.argmax(model.predict(state.reshape(-1, 2)))
  13. next_state, reward, done, _ = env.step(action)
  14. state = next_state
  15. if done:
  16. print("游戏的次数:", count)
  17. break

部分的结果如下:

Flappy Bird

FlappyBird的代码我就不过多赘述了,里面的一些函数介绍可以参照这个来看:DQN(Deep Q-learning)入门教程(四)之Q-learning Play Flappy Bird,代码思想与训练Mountain-Car基本是一致的。

  1. import random
  2. from collections import deque
  3. import keras
  4. import numpy as np
  5. from keras.layers import Dense
  6. from keras.models import Sequential
  7. from ple import PLE
  8. from ple.games import FlappyBird
  9. class Agent():
  10. def __init__(self, action_set):
  11. self.gamma = 1
  12. self.model = self.init_netWork()
  13. self.batch_size = 128
  14. self.memory = deque(maxlen=2000000)
  15. self.greedy = 1
  16. self.action_set = action_set
  17. def get_state(self, state):
  18. """
  19. 提取游戏state中我们需要的数据
  20. :param state: 游戏state
  21. :return: 返回提取好的数据
  22. """
  23. return_state = np.zeros((3,))
  24. dist_to_pipe_horz = state["next_pipe_dist_to_player"]
  25. dist_to_pipe_bottom = state["player_y"] - state["next_pipe_top_y"]
  26. velocity = state['player_vel']
  27. return_state[0] = dist_to_pipe_horz
  28. return_state[1] = dist_to_pipe_bottom
  29. return_state[2] = velocity
  30. return return_state
  31. def init_netWork(self):
  32. """
  33. 构建模型
  34. :return:
  35. """
  36. model = Sequential()
  37. model.add(Dense(64 * 4, activation="tanh", input_shape=(3,)))
  38. model.add(Dense(64 * 4, activation="tanh"))
  39. model.add(Dense(2, activation="linear"))
  40. model.compile(loss=keras.losses.mean_squared_error,
  41. optimizer=keras.optimizers.RMSprop(lr=0.001))
  42. return model
  43. def train_model(self):
  44. if len(self.memory) < 2500:
  45. return
  46. train_sample = random.sample(self.memory, k=self.batch_size)
  47. train_states = []
  48. next_states = []
  49. for sample in train_sample:
  50. cur_state, action, r, next_state, done = sample
  51. next_states.append(next_state)
  52. train_states.append(cur_state)
  53. # 转成np数组
  54. next_states = np.array(next_states)
  55. train_states = np.array(train_states)
  56. # 得到下一个state的q值
  57. next_states_q = self.model.predict(next_states)
  58. # 得到预测值
  59. state_q = self.model.predict_on_batch(train_states)
  60. for index, sample in enumerate(train_sample):
  61. cur_state, action, r, next_state, done = sample
  62. # 计算Q现实
  63. if not done:
  64. state_q[index][action] = r + self.gamma * np.max(next_states_q[index])
  65. else:
  66. state_q[index][action] = r
  67. self.model.train_on_batch(train_states, state_q)
  68. def add_memory(self, sample):
  69. self.memory.append(sample)
  70. def update_greedy(self):
  71. if self.greedy > 0.01:
  72. self.greedy *= 0.995
  73. def get_best_action(self, state):
  74. if random.random() < self.greedy:
  75. return random.randint(0, 1)
  76. else:
  77. return np.argmax(self.model.predict(state.reshape(-1, 3)))
  78. def act(self, p, action):
  79. """
  80. 执行动作
  81. :param p: 通过p来向游戏发出动作命令
  82. :param action: 动作
  83. :return: 奖励
  84. """
  85. r = p.act(self.action_set[action])
  86. if r == 0:
  87. r = 1
  88. if r == 1:
  89. r = 100
  90. else:
  91. r = -1000
  92. return r
  93. if __name__ == "__main__":
  94. # 训练次数
  95. episodes = 20000
  96. # 实例化游戏对象
  97. game = FlappyBird()
  98. # 类似游戏的一个接口,可以为我们提供一些功能
  99. p = PLE(game, fps=30, display_screen=False)
  100. # 初始化
  101. p.init()
  102. # 实例化Agent,将动作集传进去
  103. agent = Agent(p.getActionSet())
  104. max_score = 0
  105. scores = deque(maxlen=10)
  106. for episode in range(episodes):
  107. # 重置游戏
  108. p.reset_game()
  109. # 获得状态
  110. state = agent.get_state(game.getGameState())
  111. if episode > 150:
  112. agent.update_greedy()
  113. while True:
  114. # 获得最佳动作
  115. action = agent.get_best_action(state)
  116. # 然后执行动作获得奖励
  117. reward = agent.act(p, action)
  118. # 获得执行动作之后的状态
  119. next_state = agent.get_state(game.getGameState())
  120. agent.add_memory((state, action, reward, next_state, p.game_over()))
  121. agent.train_model()
  122. state = next_state
  123. if p.game_over():
  124. # 获得当前分数
  125. current_score = p.score()
  126. max_score = max(max_score, current_score)
  127. scores.append(current_score)
  128. print('第%s次游戏,得分为: %s,最大得分为: %s' % (episode, current_score, max_score))
  129. if len(scores) == 10 and np.mean(scores) > 150:
  130. agent.model.save("bird_model.h5")
  131. exit(0)
  132. break

该部分相比较于Mountain-Car需要更长的时间,目前的我还没有训练出比较好的效果,截至写完这篇博客,最新的数据如下所示:

emm,我又不想让我的电脑一直开着,。

总结

上面的两个例子便是DQN最基本最基本的使用,我们还可以将上面的FlappyBird的问题稍微复杂化一点,比如说我们无法直接的知道环境的状态,我们则可以使用CNN网络去从游戏图片入手(关于这种做法,网络上有很多人写了相对应的博客)。

项目地址:Github

参考

DQN(Deep Q-learning)入门教程(六)之DQN Play Flappy-bird ,MountainCar的更多相关文章

  1. DQN(Deep Q-learning)入门教程(三)之蒙特卡罗法算法与Q-learning算法

    蒙特卡罗法 在介绍Q-learing算法之前,我们还是对蒙特卡罗法(MC)进行一些介绍.MC方法是一种无模型(model-free)的强化学习方法,目标是得到最优的行为价值函数\(q_*\).在前面一 ...

  2. DQN(Deep Q-learning)入门教程(五)之DQN介绍

    简介 DQN--Deep Q-learning.在上一篇博客DQN(Deep Q-learning)入门教程(四)之Q-learning Play Flappy Bird 中,我们使用Q-Table来 ...

  3. DQN(Deep Q-learning)入门教程(二)之最优选择

    在上一篇博客:DQN(Deep Q-learning)入门教程(一)之强化学习介绍中有三个很重要的函数: 策略:\(\pi(a|s) = P(A_t=a | S_t=s)\) 状态价值函数:\(v_\ ...

  4. 无废话ExtJs 入门教程六[按钮:Button]

    无废话ExtJs 入门教程六[按钮:Button] extjs技术交流,欢迎加群(201926085) 继上一节内容,我们在表单里加了个两个按钮“提交”与重置.如下所示代码区的第68行位置, butt ...

  5. PySide——Python图形化界面入门教程(六)

    PySide——Python图形化界面入门教程(六) ——QListView和QStandardItemModel 翻译自:http://pythoncentral.io/pyside-pyqt-tu ...

  6. Elasticsearch入门教程(六):Elasticsearch查询(二)

    原文:Elasticsearch入门教程(六):Elasticsearch查询(二) 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:h ...

  7. RabbitMQ入门教程(六):路由选择Routing

    原文:RabbitMQ入门教程(六):路由选择Routing 版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog. ...

  8. DQN(Deep Q-learning)入门教程(零)之教程介绍

    简介 DQN入门系列地址:https://www.cnblogs.com/xiaohuiduan/category/1770037.html 本来呢,在上一个系列数据挖掘入门系列博客中,我是准备写数据 ...

  9. DQN(Deep Q-learning)入门教程(四)之Q-learning Play Flappy Bird

    在上一篇博客中,我们详细的对Q-learning的算法流程进行了介绍.同时我们使用了\(\epsilon-贪婪法\)防止陷入局部最优. 那么我们可以想一下,最后我们得到的结果是什么样的呢?因为我们考虑 ...

随机推荐

  1. P1640 连续攻击游戏

    题目传送门 Ⅰ.二分图匹配 其实这题应该不难看出是二分图匹配(尽管我没看出来) 每个物品只能用一次,实际上就是1~n的数字对物品的最大匹配 把物品的两个属性向物品编号连边,之后就从数字1一直匹配过去 ...

  2. 在使用SSH+Spring开发webservice ,报的一些异常及处理方法

    1.No bean named 'cxf' is defined 配置文件被我分成了三份,启动时忘记将webService配置导入到主文件,修改后如下: 2.bad request 400 访问路径写 ...

  3. 使用EF Code First生成模型,如何让时间字段由数据库自动生成

    场景:保存记录时需要时间字段,该时间如果由前台通过DateTime.Now产生,存在风险,比如修改客户端的系统时间,就会伪造该记录的生成时间.因此,需要在保存记录时,由后台自动赋予具体的时间. 实现方 ...

  4. Day_08【面向对象】扩展案例3_使用多态的形式创建缉毒狗对象,调用缉毒方法和吼叫方法

    分析以下需求,并用代码实现: 1.定义动物类: 行为: 吼叫:没有具体的吼叫行为 吃饭:没有具体的吃饭行为 2.定义缉毒接口 行为: 缉毒 3.定义缉毒狗:犬的一种 行为: 吼叫:汪汪叫 吃饭:狗啃骨 ...

  5. 【hdu1006】解方程

    http://acm.hdu.edu.cn/showproblem.php?pid=1006 这题坑了我好久,发现居然是一个除法变成了整除,TAT,所以建议在写较长的运算表达式的时候出现了除法尽量加个 ...

  6. hive数据仓库入门到实战及面试

    第一章.hive入门 一.hive入门手册 1.什么是数据仓库 1.1数据仓库概念 对历史数据变化的统计,从而支撑企业的决策.比如:某个商品最近一个月的销量,预判下个月应该销售多少,从而补充多少货源. ...

  7. sqli-labs之Page-4

    第五十四关 题目给出了数据库名为challenges. 这一关是依旧字符型注入,但是尝试10次后,会强制更换表名等信息.所以尽量在认真思考后进行尝试 爆表名 ?id=-1' union select ...

  8. springmvc 校验--JSR

    1.使用JSR规范是简单使用的,如果使用hibernate校验则需要在工程中添加hibernate-validate.jar,以及其他依赖的jar包. 2,在mvc配置文件中使用<mvc:ann ...

  9. 2020网鼎杯 白虎组reverse:hero

    主函数,当bossexist的值不为0时,while循环dround()函数,循环结束输出flag outflag()函数的flag值由6段数据拼凑而成 while循环的dround()函数有三个选择 ...

  10. 什么是 Nginx?

    Nginx (engine x) 是一款轻量级的 Web 服务器 .反向代理服务器及电子邮件(IMAP/POP3)代理服务器. 什么是反向代理? 反向代理(Reverse Proxy)方式是指以代理服 ...