Deep Reinforcement Learning -- DDPG

Unlike algorithms that output action probabilities, DDPG outputs a concrete action directly, so it is suited to predicting continuous actions.
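
In symbols (the μ_θ notation below is mine, not from the original post): a stochastic policy samples its action from a distribution, while DDPG's actor maps the state straight to a real-valued action:

    a \sim \pi_\theta(a \mid s) \quad \text{(stochastic policy)} \qquad \text{vs.} \qquad a = \mu_\theta(s) \in \mathbb{R}^{d_a} \quad \text{(DDPG, deterministic)}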

Formula derivation
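
For reference, these are the standard DDPG updates that the code below implements (the notation is assumed here: Q_w is the critic eval net, Q_{w'} and μ_{θ'} are the target nets, N is the batch size, τ the soft-replacement rate):

    y_i = r_i + \gamma\, Q_{w'}\big(s'_i, \mu_{\theta'}(s'_i)\big)                                  % TD target: target_q in the code
    L(w) = \frac{1}{N}\sum_i \big(y_i - Q_w(s_i, a_i)\big)^2                                        % critic loss: TD_error in the code
    \nabla_\theta J \approx \frac{1}{N}\sum_i \nabla_a Q_w(s_i, a)\big|_{a=\mu_\theta(s_i)}\, \nabla_\theta \mu_\theta(s_i)   % actor update: policy_grads in the code
    w' \leftarrow \tau w + (1-\tau)\, w', \qquad \theta' \leftarrow \tau \theta + (1-\tau)\, \theta'   % soft target replacement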

The code implements gym's Pendulum task, which has a continuous action space.

Pendulum environment introduction
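
A quick way to inspect the environment yourself (a minimal sketch; it assumes the same gym 0.8.x / Pendulum-v0 setup the script below uses, where the observation is [cos θ, sin θ, θ̇] and the single action is a torque bounded at ±2):

import gym

# Peek at the Pendulum environment before training.
env = gym.make('Pendulum-v0')

print(env.observation_space.shape)                  # (3,) -> [cos(theta), sin(theta), theta_dot]
print(env.action_space.shape)                       # (1,) -> a single continuous torque
print(env.action_space.low, env.action_space.high)  # [-2.] [2.] -> matches the np.clip(..., -2, 2) in the training loop

s = env.reset()
s_, r, done, info = env.step(env.action_space.sample())
print(s.shape, r)   # reward is always <= 0; closer to 0 means the pendulum is more upright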

Code implementation

  1. """
  2. Deep Deterministic Policy Gradient (DDPG), Reinforcement Learning.
  3. DDPG is Actor Critic based algorithm.
  4. Pendulum example.
  5.  
  6. View more on my tutorial page: https://morvanzhou.github.io/tutorials/
  7.  
  8. Using:
  9. tensorflow 1.0
  10. gym 0.8.0
  11. """
  12.  
  13. import tensorflow as tf
  14. import numpy as np
  15. import gym
  16. import time
  17.  
  18. np.random.seed(1)
  19. tf.set_random_seed(1)
  20.  
  21. ##################### hyper parameters ####################
  22.  
  23. MAX_EPISODES = 200
  24. MAX_EP_STEPS = 200
  25. lr_a = 0.001 # learning rate for actor
  26. lr_c = 0.001 # learning rate for critic
  27. gamma = 0.9 # reward discount
  28. REPLACEMENT = [
  29. dict(name='soft', tau=0.01),
  30. dict(name='hard', rep_iter_a=600, rep_iter_c=500)
  31. ][0] # you can try different target replacement strategies
  32. MEMORY_CAPACITY = 10000
  33. BATCH_SIZE = 32
  34.  
  35. RENDER = True
  36. OUTPUT_GRAPH = True
  37. ENV_NAME = 'Pendulum-v0'
  38.  

###############################  Actor  ####################################

class Actor(object):
    def __init__(self, sess, action_dim, action_bound, learning_rate, replacement):
        self.sess = sess
        self.a_dim = action_dim
        self.action_bound = action_bound
        self.lr = learning_rate
        self.replacement = replacement
        self.t_replace_counter = 0

        with tf.variable_scope('Actor'):
            # eval_net: its parameters are updated at every training step
            # input s, output a
            self.a = self._build_net(S, scope='eval_net', trainable=True)

            # target_net: its parameters are only replaced periodically; it predicts the action used by the critic's target
            # input s_, output a, get a_ for critic
            self.a_ = self._build_net(S_, scope='target_net', trainable=False)

        self.e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval_net')
        self.t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target_net')

        if self.replacement['name'] == 'hard':
            self.t_replace_counter = 0
            self.hard_replace = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)]
        else:
            self.soft_replace = [tf.assign(t, (1 - self.replacement['tau']) * t + self.replacement['tau'] * e)
                                 for t, e in zip(self.t_params, self.e_params)]

    def _build_net(self, s, scope, trainable):   # policy network: predicts an action from the state
        with tf.variable_scope(scope):
            init_w = tf.random_normal_initializer(0., 0.3)
            init_b = tf.constant_initializer(0.1)
            net = tf.layers.dense(s, 30, activation=tf.nn.relu,
                                  kernel_initializer=init_w, bias_initializer=init_b, name='l1',
                                  trainable=trainable)
            with tf.variable_scope('a'):
                actions = tf.layers.dense(net, self.a_dim, activation=tf.nn.tanh, kernel_initializer=init_w,
                                          bias_initializer=init_b, name='a', trainable=trainable)
                scaled_a = tf.multiply(actions, self.action_bound, name='scaled_a')  # scale output to [-action_bound, action_bound]
        return scaled_a

    def learn(self, s):   # batch update
        self.sess.run(self.train_op, feed_dict={S: s})

        if self.replacement['name'] == 'soft':
            self.sess.run(self.soft_replace)
        else:
            if self.t_replace_counter % self.replacement['rep_iter_a'] == 0:
                self.sess.run(self.hard_replace)
            self.t_replace_counter += 1

    def choose_action(self, s):
        s = s[np.newaxis, :]    # single state
        return self.sess.run(self.a, feed_dict={S: s})[0]   # single action

    def add_grad_to_graph(self, a_grads):
        with tf.variable_scope('policy_grads'):
            # ys = policy;
            # xs = policy's parameters;
            # a_grads = the gradients of Q with respect to the action
            # tf.gradients calculates dys/dxs with an initial gradient for ys, so this is dq/da * da/dparams
            self.policy_grads = tf.gradients(ys=self.a, xs=self.e_params, grad_ys=a_grads)

        with tf.variable_scope('A_train'):
            opt = tf.train.AdamOptimizer(-self.lr)   # (- learning rate) for gradient ascent on Q
            self.train_op = opt.apply_gradients(zip(self.policy_grads, self.e_params))   # updates the eval_net parameters

###############################  Critic  ####################################

class Critic(object):
    def __init__(self, sess, state_dim, action_dim, learning_rate, gamma, replacement, a, a_):
        self.sess = sess
        self.s_dim = state_dim
        self.a_dim = action_dim
        self.lr = learning_rate
        self.gamma = gamma
        self.replacement = replacement

        with tf.variable_scope('Critic'):
            # Input (s, a), output q
            self.a = tf.stop_gradient(a)   # stop the critic update from flowing into the actor
            # eval_net: its parameters are updated at every training step
            self.q = self._build_net(S, self.a, 'eval_net', trainable=True)

            # target_net: its parameters are only replaced periodically; it evaluates the actor
            # Input (s_, a_), output q_ for q_target
            self.q_ = self._build_net(S_, a_, 'target_net', trainable=False)   # target_q is based on a_ from Actor's target_net

            self.e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/eval_net')
            self.t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/target_net')

        with tf.variable_scope('target_q'):
            self.target_q = R + self.gamma * self.q_   # TD target

        with tf.variable_scope('TD_error'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.target_q, self.q))   # critic loss

        with tf.variable_scope('C_train'):
            self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss)   # train the critic

        with tf.variable_scope('a_grad'):
            # dQ/da for each sample, shape (None, a_dim); differentiate w.r.t. self.a,
            # since tf.stop_gradient blocks the path back to the original tensor a
            self.a_grads = tf.gradients(self.q, self.a)[0]

        if self.replacement['name'] == 'hard':
            self.t_replace_counter = 0
            self.hard_replacement = [tf.assign(t, e) for t, e in zip(self.t_params, self.e_params)]
        else:
            self.soft_replacement = [tf.assign(t, (1 - self.replacement['tau']) * t + self.replacement['tau'] * e)
                                     for t, e in zip(self.t_params, self.e_params)]

    def _build_net(self, s, a, scope, trainable):   # Q network: computes Q(s, a)
        with tf.variable_scope(scope):
            init_w = tf.random_normal_initializer(0., 0.1)
            init_b = tf.constant_initializer(0.1)

            with tf.variable_scope('l1'):
                n_l1 = 30
                w1_s = tf.get_variable('w1_s', [self.s_dim, n_l1], initializer=init_w, trainable=trainable)
                w1_a = tf.get_variable('w1_a', [self.a_dim, n_l1], initializer=init_w, trainable=trainable)
                b1 = tf.get_variable('b1', [1, n_l1], initializer=init_b, trainable=trainable)
                net = tf.nn.relu(tf.matmul(s, w1_s) + tf.matmul(a, w1_a) + b1)

            with tf.variable_scope('q'):
                q = tf.layers.dense(net, 1, kernel_initializer=init_w, bias_initializer=init_b, trainable=trainable)   # Q(s,a)
        return q

    def learn(self, s, a, r, s_):
        self.sess.run(self.train_op, feed_dict={S: s, self.a: a, R: r, S_: s_})
        if self.replacement['name'] == 'soft':
            self.sess.run(self.soft_replacement)
        else:
            if self.t_replace_counter % self.replacement['rep_iter_c'] == 0:
                self.sess.run(self.hard_replacement)
            self.t_replace_counter += 1

#####################  Memory  ####################

class Memory(object):
    def __init__(self, capacity, dims):
        self.capacity = capacity
        self.data = np.zeros((capacity, dims))
        self.pointer = 0

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, a, [r], s_))
        index = self.pointer % self.capacity   # replace the old memory with new memory
        self.data[index, :] = transition
        self.pointer += 1

    def sample(self, n):
        assert self.pointer >= self.capacity, 'Memory has not been filled yet'
        indices = np.random.choice(self.capacity, size=n)
        return self.data[indices, :]

env = gym.make(ENV_NAME)
env = env.unwrapped
env.seed(1)

state_dim = env.observation_space.shape[0]   # 3 for Pendulum
action_dim = env.action_space.shape[0]       # 1: a single continuous action
action_bound = env.action_space.high         # [2.]

# all placeholders for tf
with tf.name_scope('S'):
    S = tf.placeholder(tf.float32, shape=[None, state_dim], name='s')
with tf.name_scope('R'):
    R = tf.placeholder(tf.float32, [None, 1], name='r')
with tf.name_scope('S_'):
    S_ = tf.placeholder(tf.float32, shape=[None, state_dim], name='s_')

sess = tf.Session()

# Create actor and critic.
# They are connected to each other; the details can be seen in tensorboard.
actor = Actor(sess, action_dim, action_bound, lr_a, REPLACEMENT)
critic = Critic(sess, state_dim, action_dim, lr_c, gamma, REPLACEMENT, actor.a, actor.a_)
actor.add_grad_to_graph(critic.a_grads)   # feed the critic's dQ/da into the Actor's graph

sess.run(tf.global_variables_initializer())

M = Memory(MEMORY_CAPACITY, dims=2 * state_dim + action_dim + 1)

if OUTPUT_GRAPH:
    tf.summary.FileWriter("logs/", sess.graph)

var = 3   # control exploration

t1 = time.time()
for i in range(MAX_EPISODES):
    s = env.reset()
    ep_reward = 0

    for j in range(MAX_EP_STEPS):

        if RENDER:
            env.render()

        # Add exploration noise
        a = actor.choose_action(s)
        a = np.clip(np.random.normal(a, var), -2, 2)   # add randomness to action selection for exploration
        s_, r, done, info = env.step(a)

        M.store_transition(s, a, r / 10, s_)

        if M.pointer > MEMORY_CAPACITY:
            var *= .9995   # decay the action randomness
            b_M = M.sample(BATCH_SIZE)
            b_s = b_M[:, :state_dim]
            b_a = b_M[:, state_dim: state_dim + action_dim]
            b_r = b_M[:, -state_dim - 1: -state_dim]
            b_s_ = b_M[:, -state_dim:]

            critic.learn(b_s, b_a, b_r, b_s_)
            actor.learn(b_s)

        s = s_
        ep_reward += r

        if j == MAX_EP_STEPS - 1:
            print('Episode:', i, ' Reward: %i' % int(ep_reward), 'Explore: %.2f' % var)
            if ep_reward > -300:
                RENDER = True
            break

print('Running time: ', time.time() - t1)
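
One detail that is easy to misread is how Memory packs a transition into a single row and how the batch slicing unpacks it again. The following sketch (with made-up numbers; state_dim=3 and action_dim=1 match Pendulum) just verifies the index arithmetic used above.

import numpy as np

# Check the replay-memory row layout: [s (3), a (1), r (1), s_ (3)] -> 8 columns.
state_dim, action_dim = 3, 1

s = np.array([0.1, 0.2, 0.3])    # current observation
a = np.array([0.5])              # action taken
r = -1.7                         # (scaled) reward
s_ = np.array([0.4, 0.5, 0.6])   # next observation

row = np.hstack((s, a, [r], s_))                         # same packing as Memory.store_transition
assert row.shape == (2 * state_dim + action_dim + 1,)    # 8 columns, as in dims=2*state_dim+action_dim+1

b_M = row[np.newaxis, :]   # pretend this is a sampled batch of size 1
b_s = b_M[:, :state_dim]
b_a = b_M[:, state_dim: state_dim + action_dim]
b_r = b_M[:, -state_dim - 1: -state_dim]
b_s_ = b_M[:, -state_dim:]

print(b_s, b_a, b_r, b_s_)   # recovers s, a, r, s_ in order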
