https://github.com/zle1992/Reinforcement_Learning_Game

DeepQNetwork.py
import logging

import numpy as np
import tensorflow as tf
from abc import ABCMeta, abstractmethod

np.random.seed(1)
tf.set_random_seed(1)

# Configure the logging module; with the level set to DEBUG, all messages
# logged below will be printed to the console.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True


class DeepQNetwork(object):
    """Abstract base class for a Deep Q-Network; subclasses supply the network via _build_q_net."""
    __metaclass__ = ABCMeta
    def __init__(self,
                 n_actions,
                 n_features,
                 learning_rate,
                 reward_decay,
                 e_greedy,
                 replace_target_iter,
                 memory_size,
                 e_greedy_increment,
                 output_graph,
                 log_dir,
                 ):
        super(DeepQNetwork, self).__init__()
        self.n_actions = n_actions
        self.n_features = n_features
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon_max = e_greedy
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.epsilon_increment = e_greedy_increment
        self.output_graph = output_graph
        self.log_dir = log_dir
        # Start fully exploratory while epsilon is being annealed; otherwise use the fixed value.
        self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max

        # total learning steps
        self.learn_step_counter = 0

        # placeholders for states, actions and rewards
        self.s = tf.placeholder(tf.float32, [None] + self.n_features, name='s')
        self.s_next = tf.placeholder(tf.float32, [None] + self.n_features, name='s_next')
        self.r = tf.placeholder(tf.float32, [None, ], name='r')
        self.a = tf.placeholder(tf.int32, [None, ], name='a')

        # online network (trained) and target network (frozen between hard updates)
        self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
        self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False)

        with tf.variable_scope('q_target'):
            self.q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')  # shape=(None,)
        with tf.variable_scope('q_eval'):
            a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
            self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices)  # shape=(None,)
        with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
        with tf.variable_scope('train'):
            self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)

        # hard replacement: copy the online-network weights into the target network
        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
        with tf.variable_scope("hard_replacement"):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session(config=tfconfig)
        if self.output_graph:
            tf.summary.FileWriter(self.log_dir, self.sess.graph)
        self.sess.run(tf.global_variables_initializer())
        self.cost_his = []

    @abstractmethod
    def _build_q_net(self, x, scope, trainable):
        raise NotImplementedError

    def learn(self, data):
        # periodically copy the online weights into the target network
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)
            print('\ntarget_params_replaced\n')

        batch_memory_s = data['s']
        batch_memory_a = data['a']
        batch_memory_r = data['r']
        batch_memory_s_ = data['s_']

        _, cost = self.sess.run(
            [self._train_op, self.loss],
            feed_dict={
                self.s: batch_memory_s,
                self.a: batch_memory_a,
                self.r: batch_memory_r,
                self.s_next: batch_memory_s_,
            })
        self.cost_his.append(cost)

        # anneal epsilon towards epsilon_max so exploration decays over time
        self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max
        self.learn_step_counter += 1

    def choose_action(self, s):
        s = s[np.newaxis, :]
        if np.random.uniform() < self.epsilon:
            # exploit: pick the action with the highest predicted Q-value
            action_value = self.sess.run(self.q_eval, feed_dict={self.s: s})
            action = np.argmax(action_value)
        else:
            # explore: pick a random action
            action = np.random.randint(0, self.n_actions)
        return action
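In equation form, the training op above minimizes the standard DQN temporal-difference loss, with the max taken over the frozen target network:

    L(\theta) = \mathbb{E}\left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta) \right)^{2} \right]

where \theta are the eval_net weights and \theta^{-} are the target_net weights, hard-copied from \theta every replace_target_iter learning steps.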
Memory.py
import numpy as np

np.random.seed(1)


class Memory(object):
    """A simple ring-buffer replay memory."""
    def __init__(self,
                 n_actions,
                 n_features,
                 memory_size):
        super(Memory, self).__init__()
        self.memory_size = memory_size
        self.cnt = 0
        self.s = np.zeros([memory_size] + n_features)
        self.a = np.zeros([memory_size, ])
        self.r = np.zeros([memory_size, ])
        self.s_ = np.zeros([memory_size] + n_features)

    def store_transition(self, s, a, r, s_):
        # overwrite the oldest transition once the buffer is full
        index = self.cnt % self.memory_size
        self.s[index] = s
        self.a[index] = a
        self.r[index] = r
        self.s_[index] = s_
        self.cnt += 1

    def sample(self, n):
        # sample only from the slots that have been filled so far
        N = min(self.memory_size, self.cnt)
        indices = np.random.choice(N, size=n)
        d = {}
        d['s'] = self.s[indices]
        d['s_'] = self.s_[indices]
        d['r'] = self.r[indices]
        d['a'] = self.a[indices]
        return d
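A minimal sketch of how the replay buffer is used (the shapes assume CartPole's 4-dimensional observation; the random transitions are placeholders, not real environment data):

import numpy as np
from Memory import Memory

memory = Memory(n_actions=2, n_features=[4], memory_size=2000)

# store a few placeholder transitions (s, a, r, s_)
for _ in range(10):
    s = np.random.randn(4)
    s_ = np.random.randn(4)
    memory.store_transition(s, np.random.randint(2), 1.0, s_)

batch = memory.sample(4)   # dict with keys 's', 'a', 'r', 's_'
print(batch['s'].shape)    # (4, 4): batch size x observation dimension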

Main program

import logging

import gym
import numpy as np
import tensorflow as tf

from Memory import Memory
from DeepQNetwork import DeepQNetwork

np.random.seed(1)
tf.set_random_seed(1)

# Same logging setup as in DeepQNetwork.py; the GPU config lives there,
# next to the session it applies to.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')


class DeepQNetwork4CartPole(DeepQNetwork):
"""docstring for ClassName"""
def __init__(self, **kwargs):
super(DeepQNetwork4CartPole, self).__init__(**kwargs) def _build_q_net(self,x,scope,trainable):
w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1) with tf.variable_scope(scope):
e1 = tf.layers.dense(inputs=x,
units=32,
bias_initializer = b_initializer,
kernel_initializer=w_initializer,
activation = tf.nn.relu,
trainable=trainable)
q = tf.layers.dense(inputs=e1,
units=self.n_actions,
bias_initializer = b_initializer,
kernel_initializer=w_initializer,
activation = tf.nn.sigmoid,
trainable=trainable) return q batch_size = 64 memory_size =2000
#env = gym.make('Breakout-v0')  # discrete action space
env = gym.make('CartPole-v0')   # discrete action space
n_features = list(env.observation_space.shape)
n_actions = env.action_space.n
env = env.unwrapped


def run():
    RL = DeepQNetwork4CartPole(
        n_actions=n_actions,
        n_features=n_features,
        learning_rate=0.01,
        reward_decay=0.9,
        e_greedy=0.9,
        replace_target_iter=200,
        memory_size=memory_size,
        e_greedy_increment=None,
        output_graph=True,
        log_dir='log/DeepQNetwork4CartPole/',
    )
    memory = Memory(n_actions, n_features, memory_size=memory_size)

    step = 0
    ep_r = 0
    for episode in range(2000):
        # initial observation
        observation = env.reset()

        while True:
            # RL chooses an action based on the current observation
            action = RL.choose_action(observation)

            # RL takes the action and gets the next observation and reward
            observation_, reward, done, info = env.step(action)

            # reward shaping: the smaller theta and the closer to center, the better
            x, x_dot, theta, theta_dot = observation_
            r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
            r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
            reward = r1 + r2

            memory.store_transition(observation, action, reward, observation_)

            # start learning once enough transitions are stored, then every 5 steps
            if (step > 200) and (step % 5 == 0):
                data = memory.sample(batch_size)
                RL.learn(data)

            # swap observation
            observation = observation_
            ep_r += reward

            if episode > 700:
                env.render()  # render on the screen

            # break the while loop at the end of this episode
            if done:
                print('episode: ', episode,
                      'ep_r: ', round(ep_r, 2),
                      ' epsilon: ', round(RL.epsilon, 2))
                ep_r = 0
                break
            step += 1

    # end of game
    print('game over')
    env.close()


def main():
    run()


if __name__ == '__main__':
    main()
