在强化学习(十三) 策略梯度(Policy Gradient)中,我们讲到了基于策略(Policy Based)的强化学习方法的基本思路,并讨论了蒙特卡罗策略梯度reinforce算法。但是由于该算法需要完整的状态序列,同时单独对策略函数进行迭代更新,不太容易收敛。

    在本篇我们讨论策略(Policy Based)和价值(Value Based)相结合的方法:Actor-Critic算法。

    本文主要参考了Sutton的强化学习书第13章和UCL强化学习讲义的第7讲。

1. Actor-Critic算法简介

    Actor-Critic从名字上看包括两部分,演员(Actor)和评价者(Critic)。其中Actor使用我们上一节讲到的策略函数,负责生成动作(Action)并和环境交互。而Critic使用我们之前讲到了的价值函数,负责评估Actor的表现,并指导Actor下一阶段的动作。

    回想我们上一篇的策略梯度,策略函数就是我们的Actor,但是那里是没有Critic的,我们当时使用了蒙特卡罗法来计算每一步的价值部分替代了Critic的功能,但是场景比较受限。因此现在我们使用类似DQN中用的价值函数来替代蒙特卡罗法,作为一个比较通用的Critic。

    也就是说在Actor-Critic算法中,我们需要做两组近似,第一组是策略函数的近似:$$\pi_{\theta}(s,a) = P(a|s,\theta)\approx  \pi(a|s)$$

    第二组是价值函数的近似,对于状态价值和动作价值函数分别是:$$\hat{v}(s, w) \approx v_{\pi}(s)$$$$\hat{q}(s,a,w) \approx q_{\pi}(s,a)$$

    对于我们上一节讲到的蒙特卡罗策略梯度reinforce算法,我们需要进行改造才能变成Actor-Critic算法。

    首先,在蒙特卡罗策略梯度reinforce算法中,我们的策略的参数更新公式是:$$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(s_t,a_t)  v_t$$

    梯度更新部分中,$\nabla_{\theta}log \pi_{\theta}(s_t,a_t) $是我们的分值函数,不用动,要变成Actor的话改动的是$v_t$,这块不能再使用蒙特卡罗法来得到,而应该从Critic得到。

    而对于Critic来说,这块是新的,不过我们完全可以参考之前DQN的做法,即用一个Q网络来做为Critic, 这个Q网络的输入可以是状态,而输出是每个动作的价值或者最优动作的价值。

    现在我们汇总来说,就是Critic通过Q网络计算状态的最优价值$v_t$, 而Actor利用$v_t$这个最优价值迭代更新策略函数的参数$\theta$,进而选择动作,并得到反馈和新的状态,Critic使用反馈和新的状态更新Q网络参数$w$, 在后面Critic会使用新的网络参数$w$来帮Actor计算状态的最优价值$v_t$。

2. Actor-Critic算法可选形式

    在上一节我们已经对Actor-Critic算法的流程做了一个初步的总结,不过有一个可以注意的点就是,我们对于Critic评估的点选择是和上一篇策略梯度一样的状态价值$v_t$,实际上,我们还可以选择很多其他的指标来做为Critic的评估点。而目前可以使用的Actor-Critic评估点主要有:

    a) 基于状态价值:这是我们上一节使用的评估点,这样Actor的策略函数参数更新的法公式是:$$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(s_t,a_t)  V(s,w)$$

    b) 基于动作价值:在DQN中,我们一般使用的都是动作价值函数Q来做价值评估,这样Actor的策略函数参数更新的法公式是:$$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(s_t,a_t)  Q(s,a,w)$$

    c) 基于TD误差:在强化学习(五)用时序差分法(TD)求解中,我们讲到了TD误差,它的表达式是$\delta(t) = R_{t+1} + \gamma V(S_{t+1}) -V(S_t)$或者$\delta(t) = R_{t+1} + \gamma Q(S_{t+1},A_{t+1} ) -Q(S_t,A_t)$, 这样Actor的策略函数参数更新的法公式是:$$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(s_t,a_t)\delta(t)$$

     d) 基于优势函数:在强化学习(十二) Dueling DQN中,我们讲到过优势函数A的定义:$A(S,A,w,\beta) = Q(S,A, w, \alpha, \beta) - V(S,w,\alpha) $,即动作价值函数和状态价值函数的差值。这样Actor的策略函数参数更新的法公式是:$$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(s_t,a_t)A(S,A,w,\beta)$$

    e) 基于TD($\lambda$)误差:一般都是基于后向TD($\lambda$)误差, 在强化学习(五)用时序差分法(TD)求解中也有讲到,是TD误差和效用迹E的乘积。这样Actor的策略函数参数更新的法公式是:$$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(s_t,a_t)\delta(t)E_(t)$$

    对于Critic本身的模型参数$w$,一般都是使用均方误差损失函数来做做迭代更新,类似之前DQN系列中所讲的迭代方法. 如果我们使用的是最简单的线性Q函数,比如$Q(s,a ,w) = \phi(s,a)^Tw$,则Critic本身的模型参数$w$的更新公式可以表示为:$$\delta = R_{t+1} + \gamma Q(S_{t+1},A_{t+1} ) -Q(S_t,A_t)$$$$w = w+ \beta\delta\phi(s,a)$$

    通过对均方误差损失函数求导可以很容易的得到上式。当然实际应用中,我们一般不使用线性Q函数,而使用神经网络表示状态和Q值的关系。

3. Actor-Critic算法流程

    这里给一个Actor-Critic算法的流程总结,评估点基于TD误差,Critic使用神经网络来计算TD误差并更新网络参数,Actor也使用神经网络来更新网络参数  

    算法输入:迭代轮数$T$,状态特征维度$n$, 动作集$A$, 步长$\alpha,\beta$,衰减因子$\gamma$, 探索率$\epsilon$, Critic网络结构和Actor网络结构。

    输出:Actor 网络参数$\theta$, Critic网络参数$w$

    1. 随机初始化所有的状态和动作对应的价值$Q$.  随机初始化Critic网络的所有参数$w$。随机初始化Actor网络的所有参数$\theta$。

    2. for i from 1 to T,进行迭代。

      a) 初始化S为当前状态序列的第一个状态, 拿到其特征向量$\phi(S)$

      b) 在Actor网络中使用$\phi(S)$作为输入,输出动作$A$,基于动作$A$得到新的状态$S'$,反馈$R$。

      c) 在Critic网络中分别使用$\phi(S), \phi(S‘’)$作为输入,得到Q值输出$V(S), V(S’)$

      d) 计算TD误差$\delta = R +\gamma V(S’) -V(S) $

      e) 使用均方差损失函数$\sum\limits(R +\gamma V(S’) -V(S,w))^2$作Critic网络参数$w$的梯度更新

      f)  更新Actor网络参数$\theta$:$$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(S_t,A)\delta$$

    对于Actor的分值函数$\nabla_{\theta}log \pi_{\theta}(S_t,A)$,可以选择softmax或者高斯分值函数。

    上述Actor-Critic算法已经是一个很好的算法框架,但是离实际应用还比较远。主要原因是这里有两个神经网络,都需要梯度更新,而且互相依赖。但是了解这个算法过程后,其他基于Actor-Critic的算法就好理解了。

4. Actor-Critic算法实例

    下面我们用一个具体的例子来演示上面的Actor-Critic算法。仍然使用了OpenAI Gym中的CartPole-v0游戏来作为我们算法应用。CartPole-v0游戏的介绍参见这里。它比较简单,基本要求就是控制下面的cart移动使连接在上面的pole保持垂直不倒。这个任务只有两个离散动作,要么向左用力,要么向右用力。而state状态就是这个cart的位置和速度, pole的角度和角速度,4维的特征。坚持到200分的奖励则为过关。

    算法流程可以参考上面的第三节,这里的分值函数我们使用的是softmax函数,和上一片的类似。完整的代码参见我的Github:https://github.com/ljpzzz/machinelearning/blob/master/reinforcement-learning/actor_critic.py

    代码主要分为两部分,第一部分是Actor,第二部分是Critic。对于Actor部分,大家可以和上一篇策略梯度的代码对比,改动并不大,主要区别在于梯度更新部分,策略梯度使用是蒙特卡罗法计算出的价值$v(t)$,则我们的actor使用的是TD误差。

    在策略梯度部分,对应的位置如下:

  1. self.loss = tf.reduce_mean(self.neg_log_prob * self.tf_vt) # reward guided loss

    而我们的Actor对应的位置的代码是:

  1. self.exp = tf.reduce_mean(self.neg_log_prob * self.td_error)

    此处要注意的是,由于使用的是TD误差,而不是价值$v(t)$,此处需要最大化self.exp,而不是最小化它,这点和策略梯度不同。对应的Actor代码为:

  1. #这里需要最大化当前策略的价值,因此需要最大化self.exp,即最小化-self.exp
  2. self.train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(-self.exp)

    除此之外,Actor部分的代码和策略梯度的代码区别并不大。

    对于Critic部分,我们使用了类似于DQN的三层神经网络。不过我们简化了这个网络的输出,只有一维输出值,而不是之前DQN使用的有多少个可选动作,就有多少维输出值。网络结构如下:

  1. def create_Q_network(self):
  2. # network weights
  3. W1q = self.weight_variable([self.state_dim, 20])
  4. b1q = self.bias_variable([20])
  5. W2q = self.weight_variable([20, ])
  6. b2q = self.bias_variable([])
  7. self.state_input = tf.placeholder(tf.float32, [1, self.state_dim], "state")
  8. # hidden layers
  9. h_layerq = tf.nn.relu(tf.matmul(self.state_input, W1q) + b1q)
  10. # Q Value layer
  11. self.Q_value = tf.matmul(h_layerq, W2q) + b2q

    和之前的DQN相比,这里还有一个区别就是我们的critic没有使用DQN的经验回放,只是使用了反馈和当前网络在下一个状态的输出来拟合当前状态。

          对于算法中Actor和Critic交互的逻辑,在main函数中:

  1. for step in range(STEP):
  2. action = actor.choose_action(state) # e-greedy action for train
  3. next_state,reward,done,_ = env.step(action)
  4. td_error = critic.train_Q_network(state, reward, next_state) # gradient = grad[r + gamma * V(s_) - V(s)]
  5. actor.learn(state, action, td_error) # true_gradient = grad[logPi(s,a) * td_error]
  6. state = next_state
  7. if done:
  8. break

    大家对照第三节的算法流程和代码应该可以比较容易理清这个过程。但是这个程序很难收敛。因此大家跑了后发现分数总是很低的话是可以理解的。我们需要优化这个问题。

5. Actor-Critic算法小结

    基本版的Actor-Critic算法虽然思路很好,但是由于难收敛的原因,还需要做改进。

    目前改进的比较好的有两个经典算法,一个是DDPG算法,使用了双Actor神经网络和双Critic神经网络的方法来改善收敛性。这个方法我们在从DQN到Nature DQN的过程中已经用过一次了。另一个是A3C算法,使用了多线程的方式,一个主线程负责更新Actor和Critic的参数,多个辅线程负责分别和环境交互,得到梯度更新值,汇总更新主线程的参数。而所有的辅线程会定期从主线程更新网络参数。这些辅线程起到了类似DQN中经验回放的作用,但是效果更好。

    在后面的文章中,我们会继续讨论DDPG和A3C。

(欢迎转载,转载请注明出处。欢迎沟通交流: liujianping-ok@163.com)      

    

强化学习(十四) Actor-Critic的更多相关文章

  1. 强化学习(十五) A3C

    在强化学习(十四) Actor-Critic中,我们讨论了Actor-Critic的算法流程,但是由于普通的Actor-Critic算法难以收敛,需要一些其他的优化.而Asynchronous Adv ...

  2. 强化学习(十六) 深度确定性策略梯度(DDPG)

    在强化学习(十五) A3C中,我们讨论了使用多线程的方法来解决Actor-Critic难收敛的问题,今天我们不使用多线程,而是使用和DDQN类似的方法:即经验回放和双网络的方法来改进Actor-Cri ...

  3. 强化学习(十九) AlphaGo Zero强化学习原理

    在强化学习(十八) 基于模拟的搜索与蒙特卡罗树搜索(MCTS)中,我们讨论了MCTS的原理和在棋类中的基本应用.这里我们在前一节MCTS的基础上,讨论下DeepMind的AlphaGo Zero强化学 ...

  4. 强化学习(十八) 基于模拟的搜索与蒙特卡罗树搜索(MCTS)

    在强化学习(十七) 基于模型的强化学习与Dyna算法框架中,我们讨论基于模型的强化学习方法的基本思路,以及集合基于模型与不基于模型的强化学习框架Dyna.本文我们讨论另一种非常流行的集合基于模型与不基 ...

  5. 强化学习(四)用蒙特卡罗法(MC)求解

    在强化学习(三)用动态规划(DP)求解中,我们讨论了用动态规划来求解强化学习预测问题和控制问题的方法.但是由于动态规划法需要在每一次回溯更新某一个状态的价值时,回溯到该状态的所有可能的后续状态.导致对 ...

  6. 【转载】 强化学习(四)用蒙特卡罗法(MC)求解

    原文地址: https://www.cnblogs.com/pinard/p/9492980.html ------------------------------------------------ ...

  7. 强化学习(十二) Dueling DQN

    在强化学习(十一) Prioritized Replay DQN中,我们讨论了对DQN的经验回放池按权重采样来优化DQN算法的方法,本文讨论另一种优化方法,Dueling DQN.本章内容主要参考了I ...

  8. 强化学习(四)—— DQN系列(DQN, Nature DQN, DDQN, Dueling DQN等)

    1 概述 在之前介绍的几种方法,我们对值函数一直有一个很大的限制,那就是它们需要用表格的形式表示.虽说表格形式对于求解有很大的帮助,但它也有自己的缺点.如果问题的状态和行动的空间非常大,使用表格表示难 ...

  9. Scala学习十四——模式匹配和样例类

    一.本章要点 match表达式是更好的switch,不会有意外调入下一个分支 如果没有模式能够匹配,会抛出MatchError,可以用case _模式避免 模式可以包含一个随意定义的条件,称做守卫 你 ...

随机推荐

  1. python笔记:#005#算数运算符

    算数运算符 计算机,顾名思义就是负责进行 数学计算 并且 存储计算结果 的电子设备 目标 算术运算符的基本使用 01. 算数运算符 算数运算符是 运算符的一种 是完成基本的算术运算使用的符号,用来处理 ...

  2. (function($){...})(jQuery) 含义 【转】

    经常用,今天总结一下,下文摘自某网友的总结: (function($){...})(jQuery)实际上是匿名函数,不懂得朋友可以继续往下看. 这里实际上是匿名函数 function(arg){... ...

  3. 简单记录numpy库的某些基本功能

    这里介绍python的一个库,numpy库,这个库是机器学习,数据分析最经常用到的库之一,也是利用python做数据必须用到的一个库,入门机器学习学的第一个python库就是它了. 先对其导入到pyt ...

  4. Java (三、数组)

    Java 数组 数组对于每一门编程语言来说都是重要的数据结构之一,当然不同语言对数组的实现及处理也不尽相同. Java 语言中提供的数组是用来存储固定大小的同类型元素. 声明数组变量 首先必须声明数组 ...

  5. 登录测试用例sql语句注入

    利用SQL注入漏洞登录后台的实现方法 作者: 字体:[增加 减小] 类型:转载 时间:2012-01-12我要评论 工作需要,得好好补习下关于WEB安全方面的相关知识,故撰此文,权当总结,别无它意.读 ...

  6. 机器学习,流式IoT和医疗设备互联

    欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 让我们来看一下机器学习是如何应用于医护行业以及如何借助Apache Spark对患者的监控数据进行处理 现如今,IoT数据,实时流式数据分析 ...

  7. pycharm linux版快捷方式创建

    ****************************pycharm_linux安装and快捷方式创建******************1.下载好安装包之后解压:    tar -xfz 压缩包名 ...

  8. 【手记】ASP.NET提示“未能创建类型”处理

    我是在本机启动IIS Express调试一个ashx(一般处理程序)时遇到这个报错,网上的说法普遍有这么几种: 把bbb.ashx中的Class="aaa.bbb" 改为Class ...

  9. jQuery学习之旅 Item2 选择器【二】

    这里接着上一个Item1 把jQuery的选择器讲完.主要有:属性过滤器和子元素过滤器 点击"名称"会跳转到此方法的jQuery官方说明文档. 5. 属性过滤器 Attribute ...

  10. JavaScript Math 对象的常用方法

    JavaScript Math 对象 Math 对象 Math 对象用于执行数学任务. 使用 Math 的属性和方法的语法: var pi_value=Math.PI; var sqrt_value= ...