baselines算法库baselines/common/input.py模块分析
baselines算法库baselines/common/input.py模块代码:
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiDiscrete def observation_placeholder(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space batch_size: int size of the batch to be fed into input. Can be left None in most cases. name: str name of the placeholder Returns:
------- tensorflow placeholder tensor
''' assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
'Can only deal with Discrete and Box observation spaces for now' dtype = ob_space.dtype
if dtype == np.int8:
dtype = np.uint8 return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) def observation_input(ob_space, batch_size=None, name='Ob'):
'''
Create placeholder to feed observations into of the size appropriate to the observation space, and add input
encoder of the appropriate type.
''' placeholder = observation_placeholder(ob_space, batch_size, name)
return placeholder, encode_observation(ob_space, placeholder) def encode_observation(ob_space, placeholder):
'''
Encode input in the way that is appropriate to the observation space Parameters:
---------- ob_space: gym.Space observation space placeholder: tf.placeholder observation input placeholder
'''
if isinstance(ob_space, Discrete):
return tf.to_float(tf.one_hot(placeholder, ob_space.n))
elif isinstance(ob_space, Box):
return tf.to_float(placeholder)
elif isinstance(ob_space, MultiDiscrete):
placeholder = tf.cast(placeholder, tf.int32)
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
else:
raise NotImplementedError
可以看到input.py模块中一共有三个函数,其中只有一个函数对外提供服务,也就是 observation_input 。
可以看到observation_placeholder函数和encode_observation函数都已经被observation_input函数包装到了一起。
在observation_placeholder函数中根据传入的 env.observation_space变量即可生成对应shape的tf.placeholder变量:
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
也就是说observation_input函数中的placeholder已经是tf.placeholder类型的了。
encoder_observation 函数根据gym.spaces.observation_space的类型对tf.placeholder进行reshape操作,而tf.placeholder已经是TensorFlow的tensor变量,因此这里对tf.placeholder的操作都是在图中的构建操作,属于TensorFlow的操作。
根据encoder_observation中的代码:
我们可以知道对placeholder的reshape操作主要是对gym.space.observation_space属于Discrete和MultiDiscrete类型进行的。
如果传入的gym.space.observation_space为Discrete类型则对其对应的placeholder进行 tf.one_hot 操作,即:
tf.one_hot(placeholder, ob_space.n)
如果传入的gym.space.observation_space为MultiDiscrete类型则对其对应的placeholder中的每个Discrete进行 tf.one_hot 操作然后在concat拼接,即:
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
其中,如果gym.space.observation_space为Discrete,则其observation空间的大小为env.observation_space.n 。
其中,如果gym.space.observation_space为MultiDiscrete,则其包含的第i个Discrete对应的observation空间的大小为env.observation_space.nvec[i] 。
=======================================
observation_input函数返回的两个变量,其中placeholder是为了给网络feed数据的,而对placeholder进行reshape后的变量是为了方便后面构建神经网络的。
可以知道在gym中gym.space.MultiDiscrete主要是任天堂的游戏使用的,Nintendo Game Controller 。
import gym env_space=gym.spaces.MultiDiscrete([ 5, 2, 2 ]) for i in range(env_space.shape[-1]):
print(env_space.nvec[i])
可以看到MultiDiscrete中的每个Discrete的空间大小使用nvec对应的索引查询。
一般,gym的常见observation空间类型有:
gym.space.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
gym.spaces.Tuple
gym.spaces.Dict
其中:
gym.spaces.Box
gym.spaces.Discrete
gym.spaces.MultiDiscrete
gym.spaces.MultiBinary
对应的observation均为np.array类型,在这里:
gym.spaces.Discrete
gym.spaces.MultiDiscrete
由于状态空间可以进行one_hot处理以便于后面的计算,因此在common/input.py模块中的obervation_input函数对这样的情况进行处理,对原始的placeholder进行one_hot处理。
由于gym.spaces.MultiBinary没法进行one_hot操作,而observation_space属于类型:gym.spaces.Tuple和gym.spaces.Dict本身已经在baselines库中的其他模块被处理,由于在baselines库中已经对gym的env的observation进行了包装处理,所以可以保证在env.step和env.reset后获得的observation一定是np.arrray类型的,也就是说gym.spaces.Dict和gym.spaces.Tuple类型已经被处理了,feed给神经网络的observation只能是:gym.spaces.Box、gym.spaces.Discrete、gym.spaces.MultiDiscrete、gym.spaces.MultiBinary类型。
查看了一下baselines库的源码对env.observation的处理的代码,该部分代码在common/vec_env/util.py代码中,在该代码中如果observation_space如果是 gym.spaces.Dict和gym.spaces.Tuple则均转为gym.spaces.Dict类型,其中gym.spaces.Tuple转为gym.spaces.Dict时key用从0开始的数字代替,也就是说baselines库中没有对gym.spaces.Dict和gym.spaces.Tuple类型的observation进行过多的处理,也就是说如果原生的env环境为gym.spaces.Dict和gym.spaces.Tuple则传给算法模块的env类型也只能是gym.spaces.Dict类型,这时如果对这样的observation生成的gym.spaces.Dict类型的observation进行placeholder操作则会报错:
或者说在deepq算法中baselines库只能接收的observation类型只可以为gym.spaces.Discrete、gym.spaces.Box、gym.spaces.MultiDiscrete 。
===============================================
baselines算法库baselines/common/input.py模块分析的更多相关文章
- 图像滤镜艺术---ZPhotoEngine超级算法库
原文:图像滤镜艺术---ZPhotoEngine超级算法库 一直以来,都有个想法,想要做一个属于自己的图像算法库,这个想法,在经过了几个月的努力之后,终于诞生了,这就是ZPhotoEngine算法库. ...
- snowland-smx密码算法库
snowland-smx密码算法库 一.snowland-smx密码算法库的介绍 snowland-smx是python实现的国密套件,对标python实现的gmssl,包含国密SM2,SM3,SM4 ...
- scikit-learn 支持向量机算法库使用小结
之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...
- 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块
详细解读Python的web.py框架下的application.py模块 这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...
- 第三百零六节,Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- 使用织梦开源的分词算法库编写的YII获取分词扩展
在编辑文章中,很多时候都需要自动根据文章内容获取关键字的功能,因此,本文主要是说明如何在yii中使用织梦开源的分词算法编写一个独立的扩展,可以在不同的模块中使用,步骤如下: 1 到这里下载其他朋友整理 ...
- 四 Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置
Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...
- mahout算法库(四)
mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法 Log ...
- 操作MySQL-数据库的安装及Pycharm模块的导入
操作MySQL-数据库的安装及Pycharm模块的导入 1.基于pyCharm开发环境,在CMD控制台输入依次输入以下步骤: (1)pip3 install PyMySQL < 安装 PyMy ...
- Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装
Magicodes.Pay,打造开箱即用的统一支付库,已提供ABP模块封装 简介 Magicodes.Pay,是心莱科技团队提供的统一支付库,相关库均使用.NET标准库编写,支持.NET Framew ...
随机推荐
- Selenium模块的使用(二)
selenium处理iframe - 如果定位的标签存在于iframe标签之中,则必须使用switch_to.frame(id) - 动作链(拖动):from selenium.webdriver i ...
- 如何生成war包
pom.xml <packaging>war</packaging> 引入tomcat <dependency> <groupId>org.spring ...
- flutter 调用环信sdk 实现即时通讯
首先下载依赖 导包 import 'package:im_flutter_sdk/im_flutter_sdk.dart';登录 import 'package:flutter/material.da ...
- Centos7部署FytSoa项目至Docker——第二步:安装Mysql、Redis
FytSoa项目地址:https://gitee.com/feiyit/FytSoaCms 部署完成地址:http://82.156.127.60:8001/ 先到腾讯云申请一年的云服务器,我买的是一 ...
- HttpServletRequest获取header参数 sign
HttpServletRequest获取header参数 sign //从请求头中获取参数 private static Map<String, String> getHeaders(Ht ...
- Python使用.NET开发的类库来提高你的程序执行效率
Python由于本身的特性原因,执行程序期间可能效率并不是很理想.在某些需要自己提高一些代码的执行效率的时候,可以考虑使用C#.C++.Rust等语言开发的库来提高python本身的执行效率.接下来, ...
- js-文件读写和上传下载的简单例子01
现下,网络越来越快,浏览器的功能和性能越来越好,所以很多时候,已经不需要一些复杂的框架来实现不是非常复杂的功能. 我们只有在以下情况才会考虑使用框架或者现成的第三方组件: 1.功能复杂,自己写没有必要 ...
- 前端 Array.sort() 源码学习
源码地址 V8源码Array 710行开始为sort()相关 Array.sort()方法是那种排序呢? 去看源码主要是源于这个问题 // In-place QuickSort algorithm. ...
- 【排行榜】Carla leaderboard 排行榜 运行与参与手把手教学
此分支主要供参与leaderboard排名使用,介绍如何构建队伍,提交自己代码,此部分较为简单,主要是基本教学与演示:后续可以参考更多的开源代码进行学习等. 基本参与此榜单的大多都是学校和实验室,还是 ...
- 使用VS Code 学习算法(第四版)
最近在学习算法(第四版),书中一直在使用命令行来执行Java程序,而使用Eclipse时,很难使用命令行,或者说我根本就不会用,于是就想研究一下使用VS Code来编写代码,使用命令行来执行程序.看了 ...