Reinforcement Learning in Practice | Tabular Q-Learning for Tic-Tac-Toe (3): Optimize, Optimize
In Reinforcement Learning in Practice | Tabular Q-Learning for Tic-Tac-Toe (2): Start Training!, we got the agent training in a rather crude way. After a time-consuming 100,000 games the results were mediocre, and the values learned for the initial state in particular fell well short of expectations. The main reason, I believe, is that we did not update all equivalent positions together, so the data was used inefficiently. Every position has 7 equivalents: rotate 90°, rotate 180°, rotate 270°, horizontal flip, vertical flip, rotate 90° + horizontal flip, and rotate 90° + vertical flip, as shown in the figure below. When generating the equivalent positions we must also generate the corresponding equivalent actions; only then can the Q-value update be applied in full.
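As a side note, the eight equivalent views (the original plus the seven transforms above) are exactly the symmetries of the square, so they can also be enumerated with numpy's built-in np.rot90 and np.fliplr. This is only an illustrative sketch of that idea; the post itself uses the hand-written rotate and flip helpers from Step 1:

import numpy as np

def all_symmetries(state):
    # Return the 8 symmetric views of a 3x3 board (identity included).
    views = []
    for k in range(4):                    # rotate by 0/90/180/270 degrees
        r = np.rot90(state, k)
        views.append(r)                   # pure rotation
        views.append(np.fliplr(r))        # rotation followed by a horizontal flip
    return views

board = np.array([[1,  0, 0],
                  [0, -1, 0],
                  [0,  0, 0]])
print(len(all_symmetries(board)))         # 8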
Step 1: Write the rotation and flip functions
import numpy as np

def rotate(array): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    list_ = list(array)
    list_[:] = map(list, zip(*list_[::-1]))
    return np.array(list_) # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]

def flip(array_, direction): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array)/2))
    if direction == 'vertical': # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon': # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):
            temp = array[:,i].copy()
            array[:,i] = array[:,-i-1]
            array[:,-i-1] = temp
    return array
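A quick sanity check, purely for illustration: these two helpers agree with numpy's built-in flips and with a clockwise np.rot90 (the comparison lines below are my own addition, not part of the original post):

a = np.arange(1, 10).reshape(3, 3)                          # [[1,2,3],[4,5,6],[7,8,9]]
print(rotate(a))                                            # [[7,4,1],[8,5,2],[9,6,3]]
print(np.array_equal(rotate(a), np.rot90(a, k=-1)))         # True, clockwise rotation
print(np.array_equal(flip(a, 'vertical'), np.flipud(a)))    # True, rows mirrored
print(np.array_equal(flip(a, 'horizon'), np.fliplr(a)))     # True, columns mirrored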
Step 2: Write the function that generates equivalent states and actions
The function is called genEqualStateAndAction(state, action) and is defined inside the Agent() class.
def genEqualStateAndAction(self, state_, action_): # Input: np.array, tuple(x,y)
    state, action = state_.copy(), action_
    equalStates, equalActions = [], []
    # original position
    equalStates.append(state)
    equalActions.append(action)
    # horizontal flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)   # mark the action on a one-hot board so it can be transformed together with the state
    action_state_tf[action] = 1
    state_tf = flip(state_tf, 'horizon')
    action_state_tf = flip(action_state_tf, 'horizon')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # vertical flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    state_tf = flip(state_tf, 'vertical')
    action_state_tf = flip(action_state_tf, 'vertical')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # rotate 90°
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(1):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # rotate 180°
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(2):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # rotate 270°
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(3):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # rotate 90° + horizontal flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(1):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    state_tf = flip(state_tf, 'horizon')
    action_state_tf = flip(action_state_tf, 'horizon')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # rotate 90° + vertical flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(1):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    state_tf = flip(state_tf, 'vertical')
    action_state_tf = flip(action_state_tf, 'vertical')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    return equalStates, equalActions
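For illustration, once the full program from Step 3 is loaded, calling the method on a sample position yields eight state/action pairs (the sample board and action below are my own example):

agent = Agent()
state = np.array([[1,  0, 0],
                  [0, -1, 0],
                  [0,  0, 0]])
eqStates, eqActions = agent.genEqualStateAndAction(state, (0, 2))
for s, a in zip(eqStates, eqActions):
    print(a)                 # each action lands on the transformed copy of cell (0, 2)
print(len(eqStates))         # 8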
A careful reader may ask: aren't you going to de-duplicate the generated equivalent positions? No, we are not. First, de-duplication would require comparing many np.array objects, which is fiddly to implement and could add noticeable runtime. Second, updating a duplicated position more than once is merely redundant, not harmful: given enough data, every entry in the Q table converges anyway, and positions that appear more often simply converge a little faster.
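For completeness, if de-duplication were wanted after all, one possible approach (not used in this post) is to key each (state, action) pair by the state's byte representation; a hypothetical sketch:

def dedupe(states, actions):
    # Illustrative only: drop repeated (state, action) pairs by hashing the state's bytes.
    seen, uniqStates, uniqActions = set(), [], []
    for s, a in zip(states, actions):
        key = (s.tobytes(), a)
        if key not in seen:
            seen.add(key)
            uniqStates.append(s)
            uniqActions.append(a)
    return uniqStates, uniqActions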
Step 3: Modify the relevant code in Agent()
The methods addNewState(self, env_, currentMove) and updateQtable(self, env_, currentMove, done_) need to change. The complete program is listed below:


import gym
import random
import time
import numpy as np

# To list all registered environments:
# from gym import envs
# print(envs.registry.all())

def str2tuple(string): # Input: '(1, 1)'
    string2list = list(string)
    return ( int(string2list[1]), int(string2list[4]) ) # Output: (1, 1)

def rotate(array): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    list_ = list(array)
    list_[:] = map(list, zip(*list_[::-1]))
    return np.array(list_) # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]

def flip(array_, direction): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array)/2))
    if direction == 'vertical': # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon': # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):
            temp = array[:,i].copy()
            array[:,i] = array[:,-i-1]
            array[:,-i-1] = temp
    return array

class Game():
    def __init__(self, env):
        self.INTERVAL = 0 # delay between moves
        self.RENDER = False # whether to render the game
        self.first = 'blue' if random.random() > 0.5 else 'red' # randomly choose who moves first
        self.currentMove = self.first
        self.env = env
        self.agent = Agent()

    def switchMove(self): # switch the player to move
        move = self.currentMove
        if move == 'blue': self.currentMove = 'red'
        elif move == 'red': self.currentMove = 'blue'

    def newGame(self): # start a new game
        self.first = 'blue' if random.random() > 0.5 else 'red'
        self.currentMove = self.first
        self.env.reset()
        self.agent.reset()

    def run(self): # play one game
        self.env.reset() # the environment must be reset before the first step, otherwise it raises an error
        while True:
            print(f'--currentMove: {self.currentMove}--')
            self.agent.updateQtable(self.env, self.currentMove, False)
            if self.currentMove == 'blue':
                self.agent.lastState_blue = self.env.state.copy()
            elif self.currentMove == 'red':
                self.agent.lastState_red = self.agent.overTurn(self.env.state) # the red player sees the board with the marks inverted
            action = self.agent.epsilon_greedy(self.env, self.currentMove)
            if self.currentMove == 'blue':
                self.agent.lastAction_blue = action['pos']
            elif self.currentMove == 'red':
                self.agent.lastAction_red = action['pos']
            state, reward, done, info = self.env.step(action)
            if done:
                self.agent.lastReward_blue = reward
                self.agent.lastReward_red = -1 * reward
                self.agent.updateQtable(self.env, self.currentMove, True)
            else:
                if self.currentMove == 'blue':
                    self.agent.lastReward_blue = reward
                elif self.currentMove == 'red':
                    self.agent.lastReward_red = -1 * reward
            if self.RENDER: self.env.render()
            self.switchMove()
            time.sleep(self.INTERVAL)
            if done:
                self.newGame()
                if self.RENDER: self.env.render()
                time.sleep(self.INTERVAL)
                break

class Agent():
    def __init__(self):
        self.Q_table = {}
        self.EPSILON = 0.05
        self.ALPHA = 0.5
        self.GAMMA = 1 # discount factor
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None

    def reset(self):
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None

    def getEmptyPos(self, state): # return the coordinates of the empty cells
        action_space = []
        for i, row in enumerate(state):
            for j, one in enumerate(row):
                if one == 0: action_space.append((i,j))
        return action_space

    def randomAction(self, env_, mark): # pick a random empty cell
        actions = self.getEmptyPos(env_)
        action_pos = random.choice(actions)
        action = {'mark':mark, 'pos':action_pos}
        return action

    def overTurn(self, state): # invert the marks on the board
        state_ = state.copy()
        for i, row in enumerate(state_):
            for j, one in enumerate(row):
                if one != 0: state_[i][j] *= -1
        return state_

    def genEqualStateAndAction(self, state_, action_): # Input: np.array, tuple(x,y)
        state, action = state_.copy(), action_
        equalStates, equalActions = [], []
        # original position
        equalStates.append(state)
        equalActions.append(action)
        # horizontal flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)   # mark the action on a one-hot board so it can be transformed together with the state
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # vertical flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # rotate 90°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # rotate 180°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(2):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # rotate 270°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(3):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # rotate 90° + horizontal flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # rotate 90° + vertical flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        return equalStates, equalActions

    def addNewState(self, env_, currentMove): # add the current state to the Q table if it is not there yet
        state = env_.state if currentMove == 'blue' else self.overTurn(env_.state) # invert the board when it is red's move
        eqStates, eqActions = self.genEqualStateAndAction(state, (0,0))
        for one in eqStates:
            if str(one) not in self.Q_table:
                self.Q_table[str(one)] = {}
                actions = self.getEmptyPos(one)
                for action in actions:
                    self.Q_table[str(one)][str(action)] = 0

    def epsilon_greedy(self, env_, currentMove): # ε-greedy policy
        state = env_.state if currentMove == 'blue' else self.overTurn(env_.state) # invert the board when it is red's move
        Q_Sa = self.Q_table[str(state)]
        maxAction, maxValue, otherAction = [], -100, []
        for one in Q_Sa:
            if Q_Sa[one] > maxValue:
                maxValue = Q_Sa[one]
        for one in Q_Sa:
            if Q_Sa[one] == maxValue:
                maxAction.append(str2tuple(one))
            else:
                otherAction.append(str2tuple(one))
        try:
            action_pos = random.choice(maxAction) if random.random() > self.EPSILON else random.choice(otherAction)
        except: # otherAction may be empty, in which case fall back to the greedy actions
            action_pos = random.choice(maxAction)
        action = {'mark':currentMove, 'pos':action_pos}
        return action

    def updateQtable(self, env_, currentMove, done_):
        judge = (currentMove == 'blue' and self.lastState_blue is None) or \
                (currentMove == 'red' and self.lastState_red is None)
        if judge: # edge case 1: the agent has no previous state, i.e. this is its first move of the game, so just add the state without updating any Q value
            self.addNewState(env_, currentMove)
            return
        if done_: # edge case 2: the current state S_ is terminal, so it is not added to the Q table; set maxQ_S_a = 0 and update both players' Q values
            for one in ['blue', 'red']:
                S = self.lastState_blue if one == 'blue' else self.lastState_red
                a = self.lastAction_blue if one == 'blue' else self.lastAction_red
                eqStates, eqActions = self.genEqualStateAndAction(S, a)
                R = self.lastReward_blue if one == 'blue' else self.lastReward_red
                # print('lastState S:\n', S)
                # print('lastAction a: ', a)
                # print('lastReward R: ', R)
                # print('\n')
                maxQ_S_a = 0
                for S, a in zip(eqStates, eqActions):
                    self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                        + self.ALPHA * (R + self.GAMMA * maxQ_S_a)
            return
        # otherwise: add the current state to the Q table if it is new, then update the Q value
        self.addNewState(env_, currentMove)
        S_ = env_.state if currentMove == 'blue' else self.overTurn(env_.state)
        S = self.lastState_blue if currentMove == 'blue' else self.lastState_red
        a = self.lastAction_blue if currentMove == 'blue' else self.lastAction_red
        eqStates, eqActions = self.genEqualStateAndAction(S, a)
        R = self.lastReward_blue if currentMove == 'blue' else self.lastReward_red
        # print('lastState S:\n', S)
        # print('State S_:\n', S_)
        # print('lastAction a: ', a)
        # print('lastReward R: ', R)
        # print('\n')
        Q_S_a = self.Q_table[str(S_)]
        maxQ_S_a = -100
        for one in Q_S_a:
            if Q_S_a[one] > maxQ_S_a:
                maxQ_S_a = Q_S_a[one]
        for S, a in zip(eqStates, eqActions):
            self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                + self.ALPHA * (R + self.GAMMA * maxQ_S_a)

env = gym.make('TicTacToeEnv-v0')
game = Game(env)
for i in range(10000):
    print('episode', i)
    game.run()
Q_table = game.agent.Q_table
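Since the next post reuses the trained Q table, it may also be worth persisting it once training finishes. A minimal sketch using pickle (the filename is my own example, not from the original post):

import pickle

with open('Q_table.pkl', 'wb') as f:      # filename is arbitrary
    pickle.dump(Q_table, f)

# Later, e.g. in the next post's pygame program:
# with open('Q_table.pkl', 'rb') as f:
#     Q_table = pickle.load(f)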
Testing
With the optimization above, the agent updates 16 Q values per round instead of the 2 per round in Reinforcement Learning in Practice | Tabular Q-Learning for Tic-Tac-Toe (2): Start Training!, eight times as many. So let's play only 10,000 games and see whether that matches what 80,000 games achieved before.
Item 1: Check the number of states in the Q table
So-so: there are still states that were never visited.
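The count itself is a one-liner on the Q_table variable from the listing above (the second line, counting state-action entries, is my own addition):

print(len(Q_table))                                        # number of distinct states
print(sum(len(actions) for actions in Q_table.values()))   # number of state-action entries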
Item 2: Inspect the initial state
Opening as the first player:
This looks remarkably good! Not only is there perfect symmetry, but the win/loss picture is clear-cut: according to this table, playing an edge on the first move is safe, while playing a corner or the center mostly leads to losses. After the optimization, the overall variance of the Q values is clearly much better behaved.
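For reference, one way to print these initial-state values yourself is sketched below; it assumes, as in the listings above, that env.reset() leaves env.state as the empty board, so str(env.state) is exactly the key used by addNewState:

env.reset()                                   # assumed to restore the empty board
for pos, value in sorted(Q_table[str(env.state)].items(), key=lambda kv: -kv[1]):
    print(pos, round(value, 3))               # actions sorted from best to worst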
And here is the opening as the second player:
Item 3: Measure the running time
The more elaborate trick does buy us real gains, but a single game now inevitably takes longer. How much longer? Let's run 2,000 games with the previous post's algorithm and with this one, and record the time (the CPU used here is an Intel(R) Core(TM) i7-9750H).
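A minimal way to take this measurement, reusing the Game object from the listing above (time.perf_counter is standard-library; the 2,000-game loop mirrors the text):

import time

start = time.perf_counter()
for i in range(2000):
    game.run()                                # one complete game per call
elapsed = time.perf_counter() - start
print(f'2000 games took {elapsed:.1f} seconds')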
Two-sided update + equivalent-position update:
Two-sided update only:
At the cost of a less-than-twofold increase in time, we get roughly eight times as many updates and lower variance, so this optimization is clearly worth it.
Summary
With the optimized algorithm in hand, we can now confidently scale up the training time. In the next post we will take the fully trained Q table and use pygame to build a tic-tac-toe game with human-vs-AI and AI-vs-AI modes plus a cheat feature, and run some match statistics, for example the win rate of AI self-play and of the AI against a random policy. See you in the next post!
目录 需求 解决 方法一 方法二 需求 客户随手丢来一个基因型文件,类似于hapmap格式,只是少了中间多余的那几列,像这种类hapmap格式文件,往往是芯片数据. 这样的数据因为缺乏等位基因:参考碱 ...