https://github.com/zle1992/Reinforcement_Learning_Game

DeepQNetwork.py
import numpy as np
import tensorflow as tf
from abc import ABCMeta, abstractmethod

np.random.seed(1)
tf.set_random_seed(1)

import logging  # import the logging module
# logging.basicConfig configures the output format and destination of the log;
# since the level is set to DEBUG, all messages below will be printed to the console.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
session = tf.Session(config=tfconfig)


class DeepQNetwork(object):
    """Abstract base class for Deep Q-Networks; subclasses implement _build_q_net."""
    __metaclass__ = ABCMeta

    def __init__(self,
                 n_actions,
                 n_features,
                 learning_rate,
                 reward_decay,
                 e_greedy,
                 replace_target_iter,
                 memory_size,
                 e_greedy_increment,
                 output_graph,
                 log_dir,
                 ):
        super(DeepQNetwork, self).__init__()
        self.n_actions = n_actions
        self.n_features = n_features
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon_max = e_greedy
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.epsilon_increment = e_greedy_increment
        # start fully exploratory when epsilon is annealed, otherwise use the fixed value
        self.epsilon = 0.0 if e_greedy_increment is not None else self.epsilon_max
        self.output_graph = output_graph
        self.log_dir = log_dir

        # total learning steps
        self.learn_step_counter = 0

        # placeholders
        self.s = tf.placeholder(tf.float32, [None] + self.n_features, name='s')
        self.s_next = tf.placeholder(tf.float32, [None] + self.n_features, name='s_next')
        self.r = tf.placeholder(tf.float32, [None, ], name='r')
        self.a = tf.placeholder(tf.int32, [None, ], name='a')

        # evaluation network (trained) and target network (periodically copied)
        self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
        self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False)

        with tf.variable_scope('q_target'):
            self.q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')  # shape=(None, )
        with tf.variable_scope('q_eval'):
            a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
            self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices)  # shape=(None, )
        with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
        with tf.variable_scope('train'):
            self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)

        # hard replacement: copy eval_net weights into target_net
        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
        with tf.variable_scope('hard_replacement'):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session()
        if self.output_graph:
            tf.summary.FileWriter(self.log_dir, self.sess.graph)
        self.sess.run(tf.global_variables_initializer())
        self.cost_his = []

    @abstractmethod
    def _build_q_net(self, x, scope, trainable):
        raise NotImplementedError

    def learn(self, data):
        # check whether to replace the target network parameters
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)
            print('\ntarget_params_replaced\n')

        batch_memory_s = data['s']
        batch_memory_a = data['a']
        batch_memory_r = data['r']
        batch_memory_s_ = data['s_']

        _, cost = self.sess.run(
            [self._train_op, self.loss],
            feed_dict={
                self.s: batch_memory_s,
                self.a: batch_memory_a,
                self.r: batch_memory_r,
                self.s_next: batch_memory_s_,
            })
        self.cost_his.append(cost)

        # gradually increase epsilon toward epsilon_max
        if self.epsilon_increment is not None and self.epsilon < self.epsilon_max:
            self.epsilon = min(self.epsilon + self.epsilon_increment, self.epsilon_max)
        self.learn_step_counter += 1

    def choose_action(self, s):
        s = s[np.newaxis, :]
        if np.random.uniform() < self.epsilon:
            # exploit: pick the action with the highest estimated Q-value
            action_value = self.sess.run(self.q_eval, feed_dict={self.s: s})
            action = np.argmax(action_value)
        else:
            # explore: pick a random action
            action = np.random.randint(0, self.n_actions)
        return action
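
For reference, the `q_target` and `loss` ops above are the standard one-step Q-learning target and mean squared TD error over a batch of size B:

$$ y = r + \gamma \max_{a'} Q_{\text{target}}(s', a'), \qquad L = \frac{1}{B}\sum_{i}\big(y_i - Q_{\text{eval}}(s_i, a_i)\big)^2 $$

Note that this graph has no `done` placeholder, so the bootstrap term is never masked out for terminal transitions; the target is always the discounted maximum over the target network's outputs.
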
Memory.py
import numpy as np

np.random.seed(1)


class Memory(object):
    """Simple ring-buffer replay memory for (s, a, r, s_) transitions."""
    def __init__(self,
                 n_actions,
                 n_features,
                 memory_size):
        super(Memory, self).__init__()
        self.memory_size = memory_size
        self.cnt = 0
        self.s = np.zeros([memory_size] + n_features)
        self.a = np.zeros([memory_size, ])
        self.r = np.zeros([memory_size, ])
        self.s_ = np.zeros([memory_size] + n_features)

    def store_transition(self, s, a, r, s_):
        # overwrite the oldest transition once the buffer is full
        index = self.cnt % self.memory_size
        self.s[index] = s
        self.a[index] = a
        self.r[index] = r
        self.s_[index] = s_
        self.cnt += 1

    def sample(self, n):
        # draw a batch of n transitions from the filled part of the buffer
        N = min(self.memory_size, self.cnt)
        indices = np.random.choice(N, size=n)
        d = {}
        d['s'] = self.s[indices]
        d['s_'] = self.s_[indices]
        d['r'] = self.r[indices]
        d['a'] = self.a[indices]
        return d
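
A minimal usage sketch of the replay buffer (not part of the repository; the observation values and sizes below are made up, assuming a 4-dimensional state as in CartPole):

import numpy as np
from Memory import Memory

memory = Memory(n_actions=2, n_features=[4], memory_size=2000)

# store one hypothetical transition (s, a, r, s_)
s = np.array([0.01, -0.02, 0.03, 0.04])
s_ = np.array([0.02, -0.01, 0.02, 0.05])
memory.store_transition(s, 1, 1.0, s_)

# sample() draws indices with replacement from the filled part of the buffer,
# so it can be called as soon as at least one transition has been stored
batch = memory.sample(32)
print(batch['s'].shape, batch['a'].shape)  # (32, 4) (32,)

Because the index wraps around with the modulo operation, the oldest transitions are overwritten once more than memory_size transitions have been stored.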

Main script

import gym
import numpy as np
import tensorflow as tf

from Memory import Memory
from DeepQNetwork import DeepQNetwork

np.random.seed(1)
tf.set_random_seed(1)

import logging  # import the logging module
# logging.basicConfig configures the output format and destination of the log;
# since the level is set to DEBUG, all messages below will be printed to the console.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
session = tf.Session(config=tfconfig)


class DeepQNetwork4CartPole(DeepQNetwork):
    """Concrete DQN: a two-layer fully connected Q-network for CartPole."""
    def __init__(self, **kwargs):
        super(DeepQNetwork4CartPole, self).__init__(**kwargs)

    def _build_q_net(self, x, scope, trainable):
        w_initializer = tf.random_normal_initializer(0., 0.3)
        b_initializer = tf.constant_initializer(0.1)
        with tf.variable_scope(scope):
            e1 = tf.layers.dense(inputs=x,
                                 units=32,
                                 bias_initializer=b_initializer,
                                 kernel_initializer=w_initializer,
                                 activation=tf.nn.relu,
                                 trainable=trainable)
            # linear output layer: Q-values are unbounded (and can be negative with the
            # shaped reward below), so no squashing activation is applied here
            q = tf.layers.dense(inputs=e1,
                                units=self.n_actions,
                                bias_initializer=b_initializer,
                                kernel_initializer=w_initializer,
                                activation=None,
                                trainable=trainable)
        return q


batch_size = 64
memory_size = 2000

# env = gym.make('Breakout-v0')  # discrete action space
env = gym.make('CartPole-v0')  # discrete action space
n_features = list(env.observation_space.shape)
n_actions = env.action_space.n
env = env.unwrapped


def run():
    RL = DeepQNetwork4CartPole(
        n_actions=n_actions,
        n_features=n_features,
        learning_rate=0.01,
        reward_decay=0.9,
        e_greedy=0.9,
        replace_target_iter=200,
        memory_size=memory_size,
        e_greedy_increment=None,
        output_graph=True,
        log_dir='log/DeepQNetwork4CartPole/',
    )
    memory = Memory(n_actions, n_features, memory_size=memory_size)

    step = 0
    ep_r = 0
    for episode in range(2000):
        # initial observation
        observation = env.reset()

        while True:
            # RL chooses an action based on the current observation
            action = RL.choose_action(observation)

            # RL takes the action and gets the next observation and reward
            observation_, reward, done, info = env.step(action)

            # reward shaping: the smaller theta is and the closer the cart is to the center, the better
            x, x_dot, theta, theta_dot = observation_
            r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
            r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
            reward = r1 + r2

            memory.store_transition(observation, action, reward, observation_)

            if (step > 200) and (step % 5 == 0):
                data = memory.sample(batch_size)
                RL.learn(data)
                # print('step:%d----reward:%f---action:%d' % (step, reward, action))

            # swap observation
            observation = observation_
            ep_r += reward

            if episode > 700:
                env.render()  # render on the screen

            # break the while loop at the end of this episode
            if done:
                print('episode: ', episode,
                      'ep_r: ', round(ep_r, 2),
                      ' epsilon: ', round(RL.epsilon, 2))
                ep_r = 0
                break
            step += 1

    # end of game
    print('game over')
    env.close()


def main():
    run()


if __name__ == '__main__':
    main()
    # run2()
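
The shaped reward in the loop above replaces CartPole's constant +1 per step with a denser signal that rewards staying near the track center and keeping the pole upright:

$$ r = \frac{x_{th} - |x|}{x_{th}} - 0.8 \;+\; \frac{\theta_{th} - |\theta|}{\theta_{th}} - 0.5 $$

where $x_{th}$ is `env.x_threshold` and $\theta_{th}$ is `env.theta_threshold_radians`. The offsets 0.8 and 0.5 shift the reward so that poor states yield negative values, which is also why the Q-network's output layer is kept linear rather than squashed into (0, 1).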
