common/atari_wrappers.py模块代码如下;

import numpy as np
import os
os.environ.setdefault('PATH', '')
from collections import deque
import gym
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)
from .wrappers import TimeLimit class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
gym.Wrapper.__init__(self, env)
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == 'NOOP' def reset(self, **kwargs):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
if done:
obs = self.env.reset(**kwargs)
return obs def step(self, ac):
return self.env.step(ac) class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3 def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset(**kwargs)
return obs def step(self, ac):
return self.env.step(ac) class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info def reset(self, **kwargs):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
self._skip = skip def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2: self._obs_buffer[0] = obs
if i == self._skip - 1: self._obs_buffer[1] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0) return max_frame, total_reward, done, info def reset(self, **kwargs):
return self.env.reset(**kwargs) class ClipRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env) def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward) class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
"""
Warp frames to 84x84 as done in the Nature paper and later work. If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
observation should be warped.
"""
super().__init__(env)
self._width = width
self._height = height
self._grayscale = grayscale
self._key = dict_space_key
if self._grayscale:
num_colors = 1
else:
num_colors = 3 new_space = gym.spaces.Box(
low=0,
high=255,
shape=(self._height, self._width, num_colors),
dtype=np.uint8,
)
if self._key is None:
original_space = self.observation_space
self.observation_space = new_space
else:
original_space = self.observation_space.spaces[self._key]
self.observation_space.spaces[self._key] = new_space
assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 def observation(self, obs):
if self._key is None:
frame = obs
else:
frame = obs[self._key] if self._grayscale:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(
frame, (self._width, self._height), interpolation=cv2.INTER_AREA
)
if self._grayscale:
frame = np.expand_dims(frame, -1) if self._key is None:
obs = frame
else:
obs = obs.copy()
obs[self._key] = frame
return obs class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Stack k last frames. Returns lazy array, which is much more memory efficient. See Also
--------
baselines.common.atari_wrappers.LazyFrames
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) def reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob() def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames)) class ScaledFloatFrame(gym.ObservationWrapper):
def __init__(self, env):
gym.ObservationWrapper.__init__(self, env)
self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) def observation(self, observation):
# careful! This undoes the memory optimization, use
# with smaller replay buffers only.
return np.array(observation).astype(np.float32) / 255.0 class LazyFrames(object):
def __init__(self, frames):
"""This object ensures that common frames between the observations are only stored once.
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
buffers. This object should only be converted to numpy array before being passed to the model. You'd not believe how complex the previous solution was."""
self._frames = frames
self._out = None def _force(self):
if self._out is None:
self._out = np.concatenate(self._frames, axis=-1)
self._frames = None
return self._out def __array__(self, dtype=None):
out = self._force()
if dtype is not None:
out = out.astype(dtype)
return out def __len__(self):
return len(self._force()) def __getitem__(self, i):
return self._force()[i] def count(self):
frames = self._force()
return frames.shape[frames.ndim - 1] def frame(self, i):
return self._force()[..., i] def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
if frame_stack:
env = FrameStack(env, 4)
return env

该模块功能为对gym的env进行包装,也就是修改reset函数,step函数,同时也可以直接修改observation和reward的值,其实现原理为继承gym环境类的gym.Wrapper gym.RewardWrappergym.ObservationWrapper

其中,覆盖gym.Wrapper中的reset函数和step函数可以实现对env的操作包装,对gym.ObservationWrapper中的observation函数进行覆盖可以实现对环境返回的状态进行修改,对gym.RewardWrapper中的reward函数进行覆盖则可以实现对环境返回的奖励值进行修改的目的。

NoopResetEnv类只对reset函数进行覆盖,操作为在reset的时候随机进行一定数量的noop操作,目的是使游戏开始时的初始状态具有一定的变动。

FireResetEnv类只对reset函数进行覆盖,在reset操作的时候加入动作1和动作2,因为有的游戏需要有动作1作为开始键。

需要注意的是上面两个类在reset操作的时候不会出现游戏终止的情况,即dong=True,但是也写了done=True情况的处理操作,不过该部分可以忽略。

至于FireResetEnv为什么在reset的时候加入动作2,也就是env.step(2)还是没有找到解释的地方。

EpisodicLifeEnv类对step函数和reset函数都进行了覆盖,也就是考虑了游戏中有多条命的情况,如果是总游戏没有结束但是一个游戏live结束也对其进行reset操作,不过如果是总的游戏结束则调用包装的内层类reset,如果是lives减一但是不等于0的情况则进行0动作的操作,也就是fire操作,即 env.step(0)。在step函数中判断是总游戏回合结束还是丢失了一个游戏生命。

MaxAndSkipEnv类,进行重复动作并对得到的最近的两个observation进行取最大值,如默认的skip输入值为4,则进行4次的动作重复,也就是MaxAndSkipEnv类对象接受到一个动作的step则会调用内层包装的env进行4次相同action的step。

由于调用内层的env进行skip数量的step会得到4个observation,对最新得到的两个observation进行取max操作后作为最终的observation返回。

对于这个里面的step操作有一些个人的看法,感觉这个官方给出的设计有些瑕疵,自己的修改如下:

    def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer[i%2] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0) return max_frame, total_reward, done, info

个人的修改主要是考虑了当done=True时还没有对_obs_buffer进行填充的情况,其实这个修改对性能是没有影响的,这样改并不会提升性能,这个只不过可能有些强迫症类型的修改了。同时我们需要注意在step操作中调用内层env进行的skip次的step所得的奖励进行的加和。

类ClipRewardEnv,对step返回的reward进行了处理,也就是说如果有其他的Env类对ClipRewardEnv进行包装,那么ClipRewardEnv向上返回的reward只会为-1,0, +1 。

类WarpFrame(gym.ObservationWrapper)对所有内层的Env的observation进行处理,进行rgb转为灰度图后进行resize操作。需要注意的是rgb图转为灰度图后为保持ndim维度不变会在-1维度上进行扩充。

该类中需要注意的一个地方:

If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
observation should be warped.

dict_space_key可能是给某种环境类型使用的,该种类型的环境返回的observation不是np.array类型而是dict类型,如果需要获得observation中的np.array就需要调用dict_space_key,即observation[dict_space_key]。

把rgb图片转为gray灰度图后维度变少了,需要在-1维度上进行维度扩展。

        if self._grayscale:
frame = np.expand_dims(frame, -1)

类FrameStack(gym.Wrapper) , 将多个图片在-1维度上进行拼接,在reset函数时将内部传入进行包装的env的reset后返回的observation进行k次copy,其中k个图片的存储使用队列形式  deque([], maxlen=k) 。

    def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info

每个step得到的observation都用队列存储,最后向上返回的observation则是使用self._get_ob()生成的。

    def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))

可以看到每个observation都是k个图片构成的队列生成的LazyFrames对象。

类class LazyFrames(object), 将需要拼接的图片队列转为np.array类型。

类ScaledFloatFrame(gym.ObservationWrapper),将np.uint8类型转为np.float32,范围由0~255转为0~1 。

------------------------------------------------

def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
"""Configure environment for DeepMind-style Atari.
"""
if episode_life:
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if scale:
env = ScaledFloatFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
if frame_stack:
env = FrameStack(env, 4)
return env

上面这两个类非标选择不同的包装器对env进行包装,也就是deepmind类型的atari环境包装和非deepmind的atari环境包装,这两个用法暂时不是很明白。

=============================================

baselines算法库common/atari_wrappers.py模块分析的更多相关文章

  1. openstack 中 log模块分析

    1 . 所在模块,一般在openstack/common/log.py,其实最主要的还是调用了python中的logging模块: 入口函数在 def setup(product_name, vers ...

  2. 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块

    详细解读Python的web.py框架下的application.py模块   这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...

  3. Python标准库笔记(9) — functools模块

    functools 作用于函数的函数 functools 模块提供用于调整或扩展函数和其他可调用对象的工具,而无需完全重写它们. 装饰器 partial 类是 functools 模块提供的主要工具, ...

  4. python标准库介绍——12 time 模块详解

    ==time 模块== ``time`` 模块提供了一些处理日期和一天内时间的函数. 它是建立在 C 运行时库的简单封装. 给定的日期和时间可以被表示为浮点型(从参考时间, 通常是 1970.1.1 ...

  5. mahout算法库(四)

    mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法               Log ...

  6. scikit-learn 支持向量机算法库使用小结

    之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...

  7. OpenRisc-43-or1200的IF模块分析

    引言 “喂饱饥饿的CPU”,是计算机体系结构设计者时刻要考虑的问题.要解决这个问题,方法大体可分为两部分,第一就是利用principle of locality而引进的cache技术,缩短取指时间,第 ...

  8. 【转】python模块分析之unittest测试(五)

    [转]python模块分析之unittest测试(五) 系列文章 python模块分析之random(一) python模块分析之hashlib加密(二) python模块分析之typing(三) p ...

  9. 【转】python模块分析之hashlib加密(二)

    [转]python模块分析之hashlib加密(二) hashlib模块是用来对字符串进行hash加密的模块,明文与密文是一一对应不变的关系:用于注册.登录时用户名.密码等加密使用.一.函数分析:1. ...

  10. 【转】python之random模块分析(一)

    [转]python之random模块分析(一) random是python产生伪随机数的模块,随机种子默认为系统时钟.下面分析模块中的方法: 1.random.randint(start,stop): ...

随机推荐

  1. 在线HMAC加密工具

    在线HMAC加密工具提供一站式服务,支持MD5至SHA512.RIPEMD160及SM3等多种哈希算法,用户可便捷选择算法并生成安全的HMAC散列值,确保消息完整性与验证来源.适用于开发调试.网络安全 ...

  2. Hibernate Validator 校验注解

    Hibernate Validator 校验注解/** * 认识一些校验注解Hibernate Validator * * @NotNull 值不能为空 * @Null 值必须为空 * @Patter ...

  3. CloseableHttpClient设置超时时间demo 未设置默认是2分钟

    # CloseableHttpClient设置超时时间demo 未设置默认是2分钟 import org.apache.http.HttpHeaders; import org.apache.http ...

  4. 高级前端开发需要知道的 25 个 JavaScript 单行代码

    1. 不使用临时变量来交换变量的值 例如我们想要将 a 于 b 的值交换 let a = 1, b = 2; // 交换值 [a, b] = [b, a]; // 结果: a = 2, b = 1 这 ...

  5. 制作Jdk镜像

    本文介绍用Dockerfile的方式构建Jdk镜像,请保证安装了Docker环境. 首先创建/opt/jdk目录,后续步骤都在该目录下进行操作. 准备好Jdk安装文件,放到/opt/jdk目录下. 编 ...

  6. Flash驱动控制--芯片擦除(SPI协议)

    摘要: 本篇博客具体包括SPI协议的基本原理.模式选择以及时序逻辑要求,采用FPGA(EPCE4),通过SPI通信协议,对flash(W25Q16BV)存储的固化程序进行芯片擦除操作. 关键词:SPI ...

  7. [翻译].NET 8 的原生AOT及高性能Web开发中的应用[附性能测试结果]

    原文: [A Dive into .Net 8 Native AOT and Efficient Web Development] 作者: [sharmila subbiah] 引言 随着 .NET ...

  8. Apache Kylin(三)Kylin上手

    Kylin 上手 根据Kylin 官方给出的测试数据,我们实际操作一下 Kylin. 1. 导入 Hive 数据 首先创建一个project,在界面左上角有个"Add Project&quo ...

  9. STM32 学习:IAP 简单的IAP例子

    --- title: mcu-stm32-IAP-1-sample date: 2020-05-27 18:21:53 categories: tags: - iap - stm32 --- 章节概述 ...

  10. C#中?.、??、?:、及?等符号用途

    1.可空类型修饰符(?)   众所周知,在C#中引用类型可以使用一个null引用来表示一个不存在的值,比如 string str = null 是正确的: 但是值类型却不能为空,比如 int k = ...