将例二改写成面向对象模式,并加了环境!

不过更新环境的过程中,用到了清屏命令,play()的时候,会有点问题。learn()的时候可以勉强看到:P

0.效果图

1.完整代码

相对于例一,修改的地方:

Agent 五处:states, actions, rewards, get_valid_actions(), get_next_state()

Env    两处:__init__(), update()

  1. import pandas as pd
  2. import random
  3. import time
  4. import pickle
  5. import pathlib
  6. import os
  7.  
  8. '''
  9. 四格迷宫:
    ---------------
    | 入口 |      |
    ---------------
    | 陷阱 | 出口 |
    ---------------
  10. '''
  11.  
  12. class Env(object):
  13. '''环境类'''
  14. def __init__(self):
  15. '''初始化'''
  16. self.env = list('--\n#-')
  17.  
  18. def update(self, state, delay=0.1):
  19. '''更新环境,并打印'''
  20. env = self.env[:]
  21. if state > 1:
  22. state += 1
  23. env[state] = 'o' # 更新环境
  24. print('\r{}'.format(''.join(env)), end='')
  25. time.sleep(delay)
  26. os.system('cls')
  27.  
  28. class Agent(object):
  29. '''个体类'''
  30. def __init__(self, alpha=0.01, gamma=0.9):
  31. '''初始化'''
  32. self.states = range(4) # 状态集。0, 1, 2, 3 四个状态
  33. self.actions = list('udlr') # 动作集。上下左右 4个动作
  34. self.rewards = [0,0,-10,10] # 奖励集。到达位置3(出口)奖励10,位置2(陷阱)奖励-10,其他皆为0
  35.  
  36. self.alpha = alpha
  37. self.gamma = gamma
  38.  
  39. self.q_table = pd.DataFrame(data=[[0 for _ in self.actions] for _ in self.states],
  40. index=self.states,
  41. columns=self.actions)
  42.  
  43. def save_policy(self):
  44. '''保存Q table'''
  45. with open('q_table.pickle', 'wb') as f:
  46. # Pickle the 'data' dictionary using the highest protocol available.
  47. pickle.dump(self.q_table, f, pickle.HIGHEST_PROTOCOL)
  48.  
  49. def load_policy(self):
  50. '''导入Q table'''
  51. with open('q_table.pickle', 'rb') as f:
  52. self.q_table = pickle.load(f)
  53.  
  54. def choose_action(self, state, epsilon=0.8):
  55. '''选择相应的动作。根据当前状态,随机或贪婪,按照参数epsilon'''
  56. if (random.uniform(0,1) > epsilon) or ((self.q_table.ix[state] == 0).all()): # 探索
  57. action = random.choice(self.get_valid_actions(state))
  58. else:
  59. action = self.q_table.ix[state].idxmax() # 利用(贪婪)
  60. return action
  61.  
  62. def get_q_values(self, state):
  63. '''取状态state的所有Q value'''
  64. q_values = self.q_table.ix[state, self.get_valid_actions(state)]
  65. return q_values
  66.  
  67. def update_q_value(self, state, action, next_state_reward, next_state_q_values):
  68. '''更新Q value,根据贝尔曼方程'''
  69. self.q_table.ix[state, action] += self.alpha * (next_state_reward + self.gamma * next_state_q_values.max() - self.q_table.ix[state, action])
  70.  
  71. def get_valid_actions(self, state):
  72. '''取当前状态下所有的合法动作'''
  73. valid_actions = set(self.actions)
  74. if state % 2 == 1: # 最后一列,则
  75. valid_actions -= set(['r']) # 无向右的动作
  76. if state % 2 == 0: # 最前一列,则
  77. valid_actions -= set(['l']) # 无向左
  78. if state // 2 == 1: # 最后一行,则
  79. valid_actions -= set(['d']) # 无向下
  80. if state // 2 == 0: # 最前一行,则
  81. valid_actions -= set(['u']) # 无向上
  82. return list(valid_actions)
  83.  
  84. def get_next_state(self, state, action):
  85. '''对状态执行动作后,得到下一状态'''
  86. #u,d,l,r,n = -2,+2,-1,+1,0
  87. if state % 2 != 1 and action == 'r': # 除最后一列,皆可向右(+1)
  88. next_state = state + 1
  89. elif state % 2 != 0 and action == 'l': # 除最前一列,皆可向左(-1)
  90. next_state = state -1
  91. elif state // 2 != 1 and action == 'd': # 除最后一行,皆可向下(+2)
  92. next_state = state + 2
  93. elif state // 2 != 0 and action == 'u': # 除最前一行,皆可向上(-2)
  94. next_state = state - 2
  95. else:
  96. next_state = state
  97. return next_state
  98.  
  99. def learn(self, env=None, episode=1000, epsilon=0.8):
  100. '''q-learning算法'''
  101. print('Agent is learning...')
  102. for _ in range(episode):
  103. current_state = self.states[0]
  104.  
  105. if env is not None: # 若提供了环境,则更新之!
  106. env.update(current_state)
  107.  
  108. while current_state != self.states[-1]:
  109. current_action = self.choose_action(current_state, epsilon) # 按一定概率,随机或贪婪地选择
  110. next_state = self.get_next_state(current_state, current_action)
  111. next_state_reward = self.rewards[next_state]
  112. next_state_q_values = self.get_q_values(next_state)
  113. self.update_q_value(current_state, current_action, next_state_reward, next_state_q_values)
  114. current_state = next_state
  115.  
  116. if env is not None: # 若提供了环境,则更新之!
  117. env.update(current_state)
  118. print('\nok')
  119.  
  120. def play(self, env=None, delay=0.5):
  121. '''玩游戏,使用策略'''
  122. assert env != None, 'Env must be not None!'
  123.  
  124. if pathlib.Path("q_table.pickle").exists():
  125. self.load_policy()
  126. else:
  127. print("I need to learn before playing this game.")
  128. self.learn(env, 13)
  129. self.save_policy()
  130.  
  131. print('Agent is playing...')
  132. current_state = self.states[0]
  133. env.update(current_state, delay)
  134. while current_state != self.states[-1]:
  135. current_action = self.choose_action(current_state, 1.) # 1., 不随机
  136. next_state = self.get_next_state(current_state, current_action)
  137. current_state = next_state
  138. env.update(current_state, delay)
  139. print('\nCongratulations, Agent got it!')
  140.  
  141. if __name__ == '__main__':
  142. env = Env() # 环境
  143. agent = Agent() # 个体
  144. agent.learn(env, episode=25, epsilon=0.6) # 先学
  145. #agent.save_policy() # 保存所学
  146. #agent.load_policy() # 导入所学
  147. #agent.play(env) # 再玩

【强化学习】python 实现 q-learning 例四(例二改写)的更多相关文章

  1. 深度强化学习(DQN-Deep Q Network)之应用-Flappy Bird

    深度强化学习(DQN-Deep Q Network)之应用-Flappy Bird 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-fu ...

  2. web前端学习python之第一章_基础语法(二)

    web前端学习python之第一章_基础语法(二) 前言:最近新做了一个管理系统,前端已经基本完成, 但是后端人手不足没人给我写接口,自力更生丰衣足食, 所以决定自学python自己给自己写接口哈哈哈 ...

  3. 机器学习之强化学习概览(Machine Learning for Humans: Reinforcement Learning)

    声明:本文翻译自Vishal Maini在Medium平台上发布的<Machine Learning for Humans>的教程的<Part 5: Reinforcement Le ...

  4. 深度强化学习(Deep Reinforcement Learning)入门:RL base & DQN-DDPG-A3C introduction

    转自https://zhuanlan.zhihu.com/p/25239682 过去的一段时间在深度强化学习领域投入了不少精力,工作中也在应用DRL解决业务问题.子曰:温故而知新,在进一步深入研究和应 ...

  5. 【转】【强化学习】Deep Q Network(DQN)算法详解

    原文地址:https://blog.csdn.net/qq_30615903/article/details/80744083 DQN(Deep Q-Learning)是将深度学习deeplearni ...

  6. Deep learning:四十二(Denoise Autoencoder简单理解)

    前言: 当采用无监督的方法分层预训练深度网络的权值时,为了学习到较鲁棒的特征,可以在网络的可视层(即数据的输入层)引入随机噪声,这种方法称为Denoise Autoencoder(简称dAE),由Be ...

  7. 廖雪峰网站:学习python基础知识—循环(四)

    一.循环 1.for names = ['Michal', 'Bob', 'tracy'] for name in names: print(name) sum = 0 for x in [1, 2, ...

  8. 廖雪峰网站:学习python基础知识—list和tuple(二)

    1.list """ Python内置的一种数据类型是列表:list. list是一种有序的集合,可以随时添加和删除其中的元素. """ c ...

  9. [Reinforcement Learning] 强化学习介绍

    随着AlphaGo和AlphaZero的出现,强化学习相关算法在这几年引起了学术界和工业界的重视.最近也翻了很多强化学习的资料,有时间了还是得自己动脑筋整理一下. 强化学习定义 先借用维基百科上对强化 ...

  10. Deep Learning专栏--强化学习之从 Policy Gradient 到 A3C(3)

    在之前的强化学习文章里,我们讲到了经典的MDP模型来描述强化学习,其解法包括value iteration和policy iteration,这类经典解法基于已知的转移概率矩阵P,而在实际应用中,我们 ...

随机推荐

  1. tomcat 7.0 最大连接数和线程设置

    部署项目时需要根据服务器配置调整连接数 <Connector port="8080" protocol="HTTP/1.1"URIEncoding=&qu ...

  2. [cb]扩展Hierarchy 添加二级菜单

    目地 这篇博客教大家如何扩展Hierarchy 默认的Hierarchy 在Unity的Edior编辑器中,默认的Hierarchy如下 扩展的Hierarchy 扩展示例 MyInitOnLoad脚 ...

  3. django —— MVT模型

    转载----

  4. 17秋 软件工程 第六次作业 Beta冲刺 Scrum5

    17秋 软件工程 第六次作业 Beta冲刺 Scrum5 各个成员冲刺期间完成的任务 世强:完成APP端相册.部员管理.手势签到模块: 陈翔:完成Scrum博客.总结博客,完成超级管理员前后端对接: ...

  5. asp.net core中使用HttpClient实现Post和Get的同步异步方法

     准备工作 1.visual studio 2015 update3开发环境 2.net core 1.0.1 及以上版本  目录 1.HttpGet方法 2.HttpPost方法 3.使用示例 4. ...

  6. 关于HashMap自定义key重写hashCode和equals的问题

     使用HashMap,如果key是自定义的类,就必须重写hashcode()和equals() hashcode()和equals()都继承于object,在Object类中的定义为: equals( ...

  7. 如何取得select结果数据集的前10条记录。postgresql

    我用的是POSTGRESQL.select name from t_personal order by personal_id desc 我想取得上面结果数据的,前10条记录.SQL语句怎么改. 我记 ...

  8. P1019 单词接龙

    单词接龙是一个与我们经常玩的成语接龙相类似的游戏,现在我们已知一组单词,且给定一个开头的字母,要求出以这个字母开头的最长的“龙”(每个单词都最多在“龙”中出现两次),在两个单词相连时,其重合部分合为一 ...

  9. read_csv报错Initializing from file failed

    Python版本:Python 3.6 pandas.read_csv() 报错 OSError: Initializing from file failed,一般由两种情况引起:一种是函数参数为路径 ...

  10. 转://tcpdump抓包实例

    基本语法 ========过滤主机--------- 抓取所有经过 eth1,目的或源地址是 192.168.1.1 的网络数据# tcpdump -i eth1 host 192.168.1.1- ...