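In Reinforcement Learning in Practice | Tabular Q-Learning for Tic-Tac-Toe (II): Start Training!, we got the agent training in a rough-and-ready way. After a time-consuming 100,000 games the results were mediocre; in particular, the values of the initial state were well off from what we expected. I think the main reason is that we did not update equivalent positions together, so the data from each game was used inefficiently. Every position has 7 equivalent positions (8 counting the original): rotated 90°, rotated 180°, rotated 270°, flipped horizontally, flipped vertically, rotated 90° then flipped horizontally, and rotated 90° then flipped vertically, as shown in the figure below. Whenever we generate an equivalent position we must also generate the equivalent action; only then can the Q values be updated completely.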
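A quick way to check that these seven transformations plus the identity are exactly the eight symmetries of a square board is to apply them to a board whose cells are all different and count the distinct results. The sketch below is not part of the original code: it uses NumPy's `np.rot90`, `np.fliplr` and `np.flipud` instead of the custom functions written in Step 1, and it assumes, as the Step 1 `rotate` does, that "rotate 90°" means clockwise.

```python
import numpy as np

board = np.arange(1, 10).reshape(3, 3)  # all cells distinct, so different maps give different boards

def rot_cw(a, times=1):
    # np.rot90 with a negative k rotates clockwise
    return np.rot90(a, k=-times)

transforms = {
    'identity': lambda a: a,
    'rotate 90': lambda a: rot_cw(a, 1),
    'rotate 180': lambda a: rot_cw(a, 2),
    'rotate 270': lambda a: rot_cw(a, 3),
    'horizontal flip': np.fliplr,
    'vertical flip': np.flipud,
    'rotate 90 + horizontal flip': lambda a: np.fliplr(rot_cw(a)),
    'rotate 90 + vertical flip': lambda a: np.flipud(rot_cw(a)),
}

images = {name: f(board) for name, f in transforms.items()}
distinct = {img.tobytes() for img in images.values()}
print(len(distinct))  # 8: no two transformations coincide

# The two composite maps are the reflections across the two diagonals
print(np.array_equal(images['rotate 90 + horizontal flip'], board.T))                   # True
print(np.array_equal(images['rotate 90 + vertical flip'], np.rot90(board, 2).T))         # True
```

The last two checks explain why the list needs the two composite entries: a square also has two diagonal mirror axes, and rotations plus horizontal/vertical flips alone would miss them.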

Step 1: Write the rotation and flip functions

```python
import numpy as np

def rotate(array):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    # Rotate 90° clockwise: reverse the row order, then transpose
    list_ = list(array)
    list_[:] = map(list, zip(*list_[::-1]))
    return np.array(list_)  # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]

def flip(array_, direction):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array) / 2))
    if direction == 'vertical':  # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):  # swap row i with its mirror row
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon':  # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):  # swap column i with its mirror column
            temp = array[:, i].copy()
            array[:, i] = array[:, -i-1]
            array[:, -i-1] = temp
    return array
```
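If you would rather lean on NumPy, the two helpers have built-in counterparts. The sketch below is an alternative, not the article's code; the names `rotate_np` and `flip_np` are hypothetical, and it reproduces the example outputs given in the comments above.

```python
import numpy as np

def rotate_np(array):
    # Same result as rotate(): a 90° clockwise turn (negative k means clockwise)
    return np.rot90(array, k=-1).copy()

def flip_np(array, direction):
    # Same direction names as flip(): 'vertical' swaps rows, 'horizon' swaps columns
    if direction == 'vertical':
        return np.flipud(array).copy()
    if direction == 'horizon':
        return np.fliplr(array).copy()
    raise ValueError(f'unknown direction: {direction!r}')

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(rotate_np(a).tolist())            # [[7, 4, 1], [8, 5, 2], [9, 6, 3]]
print(flip_np(a, 'vertical').tolist())  # [[7, 8, 9], [4, 5, 6], [1, 2, 3]]
print(flip_np(a, 'horizon').tolist())   # [[3, 2, 1], [6, 5, 4], [9, 8, 7]]
```

`np.rot90` and the flip functions return views, so the `.copy()` calls keep the "returns a new array" behaviour of the originals. Raising on an unknown direction is a deliberate difference: the original `flip` silently returns the board unchanged if the direction string is misspelled.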

Step 2: Write a function that generates the equivalent positions and equivalent actions

The function is named genEqualStateAndAction(state, action) and is defined in the Agent() class.

```python
def genEqualStateAndAction(self, state_, action_):  # Input: np.array, tuple(x, y)
    state, action = state_.copy(), action_
    # The 7 transformations, in the order listed above. Each one is applied to the board
    # and, separately, to a marker board holding a 1 at the action's cell.
    transforms = [
        lambda a: flip(a, 'horizon'),           # horizontal flip
        lambda a: flip(a, 'vertical'),          # vertical flip
        lambda a: rotate(a),                    # rotate 90°
        lambda a: rotate(rotate(a)),            # rotate 180°
        lambda a: rotate(rotate(rotate(a))),    # rotate 270°
        lambda a: flip(rotate(a), 'horizon'),   # rotate 90° + horizontal flip
        lambda a: flip(rotate(a), 'vertical'),  # rotate 90° + vertical flip
    ]
    equalStates, equalActions = [state], [action]  # the original position
    for tf in transforms:
        action_state = np.zeros(state.shape)
        action_state[action] = 1                          # mark the action's cell
        row, col = np.argwhere(tf(action_state) == 1)[0]  # where the mark ends up
        equalStates.append(tf(state))
        equalActions.append((int(row), int(col)))
    return equalStates, equalActions
```
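The marker-board trick (put a 1 on the action's cell, transform that board, look for the 1 again) works, but each action can also be mapped with a closed-form coordinate rule. The sketch below is an alternative, not the author's code; `ACTION_MAPS` and `equal_actions` are hypothetical names, the rules assume an N×N board and the clockwise rotation from Step 1, and the order matches the list above. It cross-checks every rule against the marker-board method on every cell.

```python
import numpy as np

N = 3  # board size

# (i, j) -> where that cell lands after each transformation, same order as genEqualStateAndAction
ACTION_MAPS = [
    lambda i, j: (i, N - 1 - j),          # horizontal flip
    lambda i, j: (N - 1 - i, j),          # vertical flip
    lambda i, j: (j, N - 1 - i),          # rotate 90° clockwise
    lambda i, j: (N - 1 - i, N - 1 - j),  # rotate 180°
    lambda i, j: (N - 1 - j, i),          # rotate 270° clockwise
    lambda i, j: (j, i),                  # rotate 90° + horizontal flip (main diagonal)
    lambda i, j: (N - 1 - j, N - 1 - i),  # rotate 90° + vertical flip (anti-diagonal)
]

# The same transformations applied to whole boards, used only for the cross-check
BOARD_MAPS = [
    np.fliplr,
    np.flipud,
    lambda a: np.rot90(a, -1),
    lambda a: np.rot90(a, 2),
    lambda a: np.rot90(a, 1),
    lambda a: np.fliplr(np.rot90(a, -1)),
    lambda a: np.flipud(np.rot90(a, -1)),
]

def equal_actions(action):
    """The original action first, then its image under each transformation."""
    return [action] + [f(*action) for f in ACTION_MAPS]

for i in range(N):
    for j in range(N):
        marker = np.zeros((N, N))
        marker[i, j] = 1
        for f_act, f_board in zip(ACTION_MAPS, BOARD_MAPS):
            r, c = np.argwhere(f_board(marker) == 1)[0]
            assert f_act(i, j) == (int(r), int(c))

print(equal_actions((0, 1)))  # [(0, 1), (0, 1), (2, 1), (1, 2), (2, 1), (1, 0), (1, 0), (1, 2)]
```

The printed list also shows that an edge move has only 4 distinct images among the 8, which is where the duplicates discussed next come from.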

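Careful readers may ask: don't you deduplicate the generated equivalent positions? No, I don't. First, deduplicating means comparing lots of np.arrays, which is fiddly to implement and could add a fair amount of run time. Second, updating a repeated position several times is not strictly logical, but it should have no harmful side effect: given enough data, every value in the Q table converges anyway, and positions that repeat more often simply converge faster.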
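The "converges faster" effect can be made concrete. If the same (state, action) entry appears k times among the 8 copies, it gets k consecutive updates toward the same target, which equals a single update with a larger step size α_eff = 1 − (1 − α)^k. With this article's ALPHA = 0.5 and k = 8 that is about 0.996, so such an entry is almost overwritten by the newest target each time. The sketch below checks the arithmetic and shows one cheap way to deduplicate, keying each pair on the same strings the Q table already uses; `repeated_update` and `dedup_pairs` are hypothetical helpers, not part of the original code.

```python
import numpy as np

def repeated_update(q, target, alpha, k):
    # k consecutive Q-learning updates of one entry toward the same target
    for _ in range(k):
        q = (1 - alpha) * q + alpha * target
    return q

def dedup_pairs(states, actions):
    # keep the first occurrence of each (state, action) pair, using the Q-table key format
    seen, out = set(), []
    for s, a in zip(states, actions):
        key = (str(s), str(a))
        if key not in seen:
            seen.add(key)
            out.append((s, a))
    return out

alpha, k = 0.5, 8
print(repeated_update(0.0, 1.0, alpha, k))  # 0.99609375
print(1 - (1 - alpha) ** k)                 # 0.99609375, the equivalent single-step size

# On the empty board all 8 symmetric boards are identical, and the centre move maps to itself
empty = np.zeros((3, 3), dtype=int)
pairs = dedup_pairs([empty] * 8, [(1, 1)] * 8)
print(len(pairs))  # 1
```

Comparing string keys in a set is linear in the number of copies, so the run-time cost of deduplicating is small; whether the changed effective step size is desirable is a separate modelling choice.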

Step 3: Modify the relevant code in Agent()

We need to modify the methods addNewState(self, env_, currentMove) and updateQtable(self, env_, currentMove, done_). The complete code is as follows:

```python
import gym
import random
import time
import numpy as np

# List all registered environments
# from gym import envs
# print(envs.registry.all())


def str2tuple(string):  # Input: '(1, 1)', i.e. str() of a tuple
    # Reads single-digit coordinates only, which is enough for a 3x3 board
    string2list = list(string)
    return (int(string2list[1]), int(string2list[4]))  # Output: (1, 1)


def rotate(array):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    # Rotate 90° clockwise: reverse the row order, then transpose
    list_ = list(array)
    list_[:] = map(list, zip(*list_[::-1]))
    return np.array(list_)  # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]


def flip(array_, direction):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array) / 2))
    if direction == 'vertical':  # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):  # swap row i with its mirror row
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon':  # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):  # swap column i with its mirror column
            temp = array[:, i].copy()
            array[:, i] = array[:, -i-1]
            array[:, -i-1] = temp
    return array


class Game():
    def __init__(self, env):
        self.INTERVAL = 0  # delay between moves
        self.RENDER = False  # whether to render the game
        self.first = 'blue' if random.random() > 0.5 else 'red'  # random first player
        self.currentMove = self.first
        self.env = env
        self.agent = Agent()

    def switchMove(self):  # hand the move to the other player
        move = self.currentMove
        if move == 'blue': self.currentMove = 'red'
        elif move == 'red': self.currentMove = 'blue'

    def newGame(self):  # start a new game
        self.first = 'blue' if random.random() > 0.5 else 'red'
        self.currentMove = self.first
        self.env.reset()
        self.agent.reset()

    def run(self):  # play one game
        self.env.reset()  # the environment must be reset before the first step, otherwise it raises an error
        while True:
            print(f'--currentMove: {self.currentMove}--')
            self.agent.updateQtable(self.env, self.currentMove, False)

            if self.currentMove == 'blue':
                self.agent.lastState_blue = self.env.state.copy()
            elif self.currentMove == 'red':
                self.agent.lastState_red = self.agent.overTurn(self.env.state)  # red sees the board with signs flipped

            action = self.agent.epsilon_greedy(self.env, self.currentMove)
            if self.currentMove == 'blue':
                self.agent.lastAction_blue = action['pos']
            elif self.currentMove == 'red':
                self.agent.lastAction_red = action['pos']

            state, reward, done, info = self.env.step(action)
            if done:
                self.agent.lastReward_blue = reward
                self.agent.lastReward_red = -1 * reward
                self.agent.updateQtable(self.env, self.currentMove, True)
            else:
                if self.currentMove == 'blue':
                    self.agent.lastReward_blue = reward
                elif self.currentMove == 'red':
                    self.agent.lastReward_red = -1 * reward

            if self.RENDER: self.env.render()
            self.switchMove()
            time.sleep(self.INTERVAL)
            if done:
                self.newGame()
                if self.RENDER: self.env.render()
                time.sleep(self.INTERVAL)
                break


class Agent():
    def __init__(self):
        self.Q_table = {}
        self.EPSILON = 0.05
        self.ALPHA = 0.5
        self.GAMMA = 1  # discount factor
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None

    def reset(self):
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None

    def getEmptyPos(self, state):  # return the coordinates of the empty cells
        action_space = []
        for i, row in enumerate(state):
            for j, one in enumerate(row):
                if one == 0: action_space.append((i, j))
        return action_space

    def randomAction(self, env_, mark):  # pick a random empty cell
        actions = self.getEmptyPos(env_.state)  # getEmptyPos expects the board, not the env
        action_pos = random.choice(actions)
        action = {'mark': mark, 'pos': action_pos}
        return action

    def overTurn(self, state):  # flip the sign of every piece
        state_ = state.copy()
        for i, row in enumerate(state_):
            for j, one in enumerate(row):
                if one != 0: state_[i][j] *= -1
        return state_

    def genEqualStateAndAction(self, state_, action_):  # Input: np.array, tuple(x, y)
        state, action = state_.copy(), action_
        # The 7 transformations, each applied to the board and to a marker board for the action
        transforms = [
            lambda a: flip(a, 'horizon'),           # horizontal flip
            lambda a: flip(a, 'vertical'),          # vertical flip
            lambda a: rotate(a),                    # rotate 90°
            lambda a: rotate(rotate(a)),            # rotate 180°
            lambda a: rotate(rotate(rotate(a))),    # rotate 270°
            lambda a: flip(rotate(a), 'horizon'),   # rotate 90° + horizontal flip
            lambda a: flip(rotate(a), 'vertical'),  # rotate 90° + vertical flip
        ]
        equalStates, equalActions = [state], [action]  # the original position
        for tf in transforms:
            action_state = np.zeros(state.shape)
            action_state[action] = 1                          # mark the action's cell
            row, col = np.argwhere(tf(action_state) == 1)[0]  # where the mark ends up
            equalStates.append(tf(state))
            equalActions.append((int(row), int(col)))
        return equalStates, equalActions

    def addNewState(self, env_, currentMove):  # add the current state (and its equivalents) to the Q table if missing
        state = env_.state if currentMove == 'blue' else self.overTurn(env_.state)  # flip the board on red's turn
        eqStates, eqActions = self.genEqualStateAndAction(state, (0, 0))  # dummy action: only the states are used

        for one in eqStates:
            if str(one) not in self.Q_table:
                self.Q_table[str(one)] = {}
                actions = self.getEmptyPos(one)
                for action in actions:
                    self.Q_table[str(one)][str(action)] = 0

    def epsilon_greedy(self, env_, currentMove):  # ε-greedy policy
        state = env_.state if currentMove == 'blue' else self.overTurn(env_.state)  # flip the board on red's turn
        Q_Sa = self.Q_table[str(state)]
        maxAction, maxValue, otherAction = [], -100, []
        for one in Q_Sa:
            if Q_Sa[one] > maxValue:
                maxValue = Q_Sa[one]
        for one in Q_Sa:
            if Q_Sa[one] == maxValue:
                maxAction.append(str2tuple(one))
            else:
                otherAction.append(str2tuple(one))

        try:
            action_pos = random.choice(maxAction) if random.random() > self.EPSILON else random.choice(otherAction)
        except IndexError:  # otherAction is empty (every action ties), so fall back to a greedy action
            action_pos = random.choice(maxAction)
        action = {'mark': currentMove, 'pos': action_pos}
        return action

    def updateQtable(self, env_, currentMove, done_):
        judge = (currentMove == 'blue' and self.lastState_blue is None) or \
                (currentMove == 'red' and self.lastState_red is None)
        if judge:  # edge case 1: no previous state means this is the player's first move of the game,
                   # so just add the state; there is no Q value to update yet
            self.addNewState(env_, currentMove)
            return

        if done_:  # edge case 2: the new state S_ is terminal, so it is not added to the Q table;
                   # set maxQ_S_a = 0 and update the Q values of both players
            for one in ['blue', 'red']:
                S = self.lastState_blue if one == 'blue' else self.lastState_red
                a = self.lastAction_blue if one == 'blue' else self.lastAction_red
                eqStates, eqActions = self.genEqualStateAndAction(S, a)
                R = self.lastReward_blue if one == 'blue' else self.lastReward_red
                # print('lastState S:\n', S)
                # print('lastAction a: ', a)
                # print('lastReward R: ', R)
                # print('\n')
                maxQ_S_a = 0
                for S, a in zip(eqStates, eqActions):
                    self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                        + self.ALPHA * (R + self.GAMMA * maxQ_S_a)
            return

        # general case: add the current state to the Q table if needed, then update the Q values
        self.addNewState(env_, currentMove)
        S_ = env_.state if currentMove == 'blue' else self.overTurn(env_.state)
        S = self.lastState_blue if currentMove == 'blue' else self.lastState_red
        a = self.lastAction_blue if currentMove == 'blue' else self.lastAction_red
        eqStates, eqActions = self.genEqualStateAndAction(S, a)
        R = self.lastReward_blue if currentMove == 'blue' else self.lastReward_red
        # print('lastState S:\n', S)
        # print('State S_:\n', S_)
        # print('lastAction a: ', a)
        # print('lastReward R: ', R)
        # print('\n')
        Q_S_a = self.Q_table[str(S_)]
        maxQ_S_a = -100
        for one in Q_S_a:
            if Q_S_a[one] > maxQ_S_a:
                maxQ_S_a = Q_S_a[one]
        for S, a in zip(eqStates, eqActions):
            self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                + self.ALPHA * (R + self.GAMMA * maxQ_S_a)


env = gym.make('TicTacToeEnv-v0')
game = Game(env)
for i in range(10000):
    print('episode', i)
    game.run()
Q_table = game.agent.Q_table
```
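For reference, the update that updateQtable applies to every symmetric copy can be written as below. The notation is ours, not the original article's: g runs over the identity and the seven transformations from Step 2, α is `ALPHA`, γ is `GAMMA`, R is the player's last reward, and S' is the state that player sees on its next turn. In edge case 2 (the game is over) the max term is replaced by 0.

```latex
Q\big(g(S),\, g(a)\big) \leftarrow (1-\alpha)\, Q\big(g(S),\, g(a)\big)
  + \alpha \Big( R + \gamma \max_{a'} Q(S', a') \Big),
  \qquad g \in \{\mathrm{id}, g_1, \dots, g_7\}
```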

Testing

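With the optimization above, the agent updates 16 Q values per round of play (8 equivalent positions × 2 players), compared with 2 in the previous part, Reinforcement Learning in Practice | Tabular Q-Learning for Tic-Tac-Toe (II): Start Training!, which is 8 times as many. So let's play just 10,000 games and see whether we get the results that previously took 80,000 games.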

Item 1: Check the number of states in the Q table

So-so: some states are still not covered.
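To put "not covered" into numbers, one can count how many states the Q table could hold at most and compare that with `len(Q_table)`. The sketch below enumerates every reachable non-terminal position seen from the player to move. It assumes, as `overTurn` and the reward signs suggest, that blue's pieces are 1 and red's are -1, so each player sees its own pieces as 1; terminal positions are skipped because `updateQtable` never adds them. `has_line` and `count_q_states` are hypothetical helpers, not part of the original code.

```python
import numpy as np

def has_line(board, mark):
    """True if `mark` fills any row, column or diagonal of the 3x3 board."""
    lines = list(board) + list(board.T) + [board.diagonal(), np.fliplr(board).diagonal()]
    return any((line == mark).all() for line in lines)

def count_q_states():
    """Count reachable non-terminal positions from the mover's point of view (own pieces = 1)."""
    start = np.zeros((3, 3), dtype=int)
    seen = {start.tobytes()}
    stack = [start]
    while stack:
        state = stack.pop()
        for i, j in zip(*np.where(state == 0)):
            nxt = state.copy()
            nxt[i, j] = 1                                # the mover places a piece
            if has_line(nxt, 1) or not (nxt == 0).any():
                continue                                 # terminal: never stored in the Q table
            nxt = -nxt                                   # the next player's view, like overTurn()
            key = nxt.tobytes()
            if key not in seen:
                seen.add(key)
                stack.append(nxt)
    return len(seen)

print(count_q_states())
```

Because every equivalent of a reachable position is itself reachable, adding the symmetric copies cannot push the Q table above this count; it only helps it get there sooner.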

Item 2: Check the initial state

Moving first:

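This works remarkably well! Not only are the values perfectly symmetric, the verdict is also clear-cut: taking an edge on the first move is safe, while taking a corner or the center leaves a higher chance of losing. After the optimization, the overall variance of the Q values looks very good.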
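A view like the one above can be produced by laying a state's Q values out on the board. The sketch below is a hypothetical helper, not the author's plotting code; it assumes the Q-table layout used in this article (outer keys are `str(state)`, inner keys are `str((i, j))`) and also checks the symmetry numerically. The toy Q values are made up purely to exercise the helpers.

```python
import numpy as np

def q_grid(Q_table, state):
    """Lay the Q values of `state` out as a 3x3 array; occupied cells are NaN."""
    grid = np.full(state.shape, np.nan)
    for key, value in Q_table[str(state)].items():
        i, j = (int(x) for x in key.strip('()').split(','))  # '(1, 2)' -> 1, 2
        grid[i, j] = value
    return grid

def is_symmetric(grid):
    """True if the grid looks the same under all 8 rotations and reflections."""
    turns = [np.rot90(grid, k) for k in range(4)]
    variants = turns + [np.fliplr(t) for t in turns]
    return all(np.allclose(v, grid, equal_nan=True) for v in variants)

# Toy Q table for the empty board (made-up values, only to exercise the helpers)
empty = np.zeros((3, 3), dtype=int)
values = {(i, j): (0.3 if (i, j) == (1, 1) else 0.1 if (i + j) % 2 == 0 else 0.5)
          for i in range(3) for j in range(3)}
Q_table = {str(empty): {str(pos): v for pos, v in values.items()}}

grid = q_grid(Q_table, empty)
print(grid)
print(is_symmetric(grid))  # True
```

With the trained table, `q_grid(Q_table, state)` would be called with the same board array the agent used as a key, so the array's dtype must match the environment's `state`.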

And here is the case of moving second:

Item 3: Measure the run time

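Adding a more complex trick clearly paid off, but the time per game certainly went up. By how much? Let's run 2,000 games each with the old algorithm from the previous part and with the algorithm from this part, and record the time (my CPU is an Intel(R) Core(TM) i7-9750H).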

Bidirectional update + synchronized updates of equivalent positions:

Bidirectional update only:

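The run time grew by less than a factor of two, in exchange for roughly 8 times as many updates and lower variance, so the optimization pays off.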
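For anyone who wants to repeat the comparison, a minimal timing harness could look like the sketch below. It only assumes an object with a `run()` method that plays one game, like `Game` above; `time_games` is a hypothetical name, and `DummyGame` is a stand-in that exists only to make the sketch runnable.

```python
import time

def time_games(game, n_games):
    """Play n_games games and return the total wall-clock time in seconds."""
    start = time.perf_counter()
    for _ in range(n_games):
        game.run()
    return time.perf_counter() - start

class DummyGame:
    # stand-in for Game(env): counts calls instead of playing
    def __init__(self):
        self.played = 0

    def run(self):
        self.played += 1

dummy = DummyGame()
elapsed = time_games(dummy, 2000)
print(f'{dummy.played} games in {elapsed:.4f} s')
```

With the real `Game`, the per-move `print` calls inside `run()` are part of the measured time, so both algorithms should be timed with the same output settings for the comparison to be fair.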

Summary

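With the optimized algorithm in hand, I feel confident enough to increase the training time. In the next part, we will take the fully trained Q table and use pygame to build a tic-tac-toe game with human-vs-AI play, AI-vs-AI play, and a cheat function. We can also analyze some matches: how often does the AI win against itself? How often does it win against a random policy? See you next time!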


    Java集合框架 Java 集合框架主要包括两种类型的容器,一种是集合(Collection),存储一个元素集合,另一种是图(Map),存储键/值对映射. Collection接口又有3种子类型,Li ...