模块代码:

import multiprocessing as mp

import numpy as np
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars def worker(remote, parent_remote, env_fn_wrappers):
def step_env(env, action):
ob, reward, done, info = env.step(action)
if done:
ob = env.reset()
return ob, reward, done, info parent_remote.close()
envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x]
try:
while True:
cmd, data = remote.recv()
if cmd == 'step':
remote.send([step_env(env, action) for env, action in zip(envs, data)])
elif cmd == 'reset':
remote.send([env.reset() for env in envs])
elif cmd == 'render':
remote.send([env.render(mode='rgb_array') for env in envs])
elif cmd == 'close':
remote.close()
break
elif cmd == 'get_spaces_spec':
remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec)))
else:
raise NotImplementedError
except KeyboardInterrupt:
print('SubprocVecEnv worker: got KeyboardInterrupt')
finally:
for env in envs:
env.close() class SubprocVecEnv(VecEnv):
"""
VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes.
Recommended to use when num_envs > 1 and step() can be a bottleneck.
"""
def __init__(self, env_fns, spaces=None, context='spawn', in_series=1):
"""
Arguments: env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
in_series: number of environments to run in series in a single process
(e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series)
"""
self.waiting = False
self.closed = False
self.in_series = in_series
nenvs = len(env_fns)
assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series"
self.nremotes = nenvs // in_series
env_fns = np.array_split(env_fns, self.nremotes)
ctx = mp.get_context(context)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])
self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
for p in self.ps:
p.daemon = True # if the main process crashes, we should not cause things to hang
with clear_mpi_env_vars():
p.start()
for remote in self.work_remotes:
remote.close() self.remotes[0].send(('get_spaces_spec', None))
observation_space, action_space, self.spec = self.remotes[0].recv().x
self.viewer = None
VecEnv.__init__(self, nenvs, observation_space, action_space) def step_async(self, actions):
self._assert_not_closed()
actions = np.array_split(actions, self.nremotes)
for remote, action in zip(self.remotes, actions):
remote.send(('step', action))
self.waiting = True def step_wait(self):
self._assert_not_closed()
results = [remote.recv() for remote in self.remotes]
results = _flatten_list(results)
self.waiting = False
obs, rews, dones, infos = zip(*results)
return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos def reset(self):
self._assert_not_closed()
for remote in self.remotes:
remote.send(('reset', None))
obs = [remote.recv() for remote in self.remotes]
obs = _flatten_list(obs)
return _flatten_obs(obs) def close_extras(self):
self.closed = True
if self.waiting:
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(('close', None))
for p in self.ps:
p.join() def get_images(self):
self._assert_not_closed()
for pipe in self.remotes:
pipe.send(('render', None))
imgs = [pipe.recv() for pipe in self.remotes]
imgs = _flatten_list(imgs)
return imgs def _assert_not_closed(self):
assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()" def __del__(self):
if not self.closed:
self.close() def _flatten_obs(obs):
assert isinstance(obs, (list, tuple))
assert len(obs) > 0 if isinstance(obs[0], dict):
keys = obs[0].keys()
return {k: np.stack([o[k] for o in obs]) for k in keys}
else:
return np.stack(obs) def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l]) return [l__ for l_ in l for l__ in l_]

该模块实现多进程下的多环境交互,这里假设每个进程可以与多个环境进行串行的交互,每个进程可以交互的环境个数为in_series,默认该值为1。

每个子进程的交互的环境个数为nremotes。

self.nremotes = nenvs // in_series

将生成环境env的函数进行nremotes个数的划分:

env_fns = np.array_split(env_fns, self.nremotes)

在这个模块中多进程使用multiprocessing来实现的。

进程自己的信息传递使用Pipe管道来实现:

self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])

从管道的生成可以看出一个有nremotes个子进程,也为其分别生成了nremotes个数个Pipe。

每个pipe都有两个端口,这里使用remotes和work_remotes来表示。

这里说明下pipe的两个端口都是可以进行读信息和存信息的操作的,不过如果有多个进程同时操作同一个端口会导致消息的操作出错,这样的话如果一个管道设计固定为两个进程进行消息传递可以自行指定两个端口分别给两个进程进行操作。

一个pipe的两个端口就像一个真实管道的两个端口一样。

生成nremotes个子进程,分别将nremotes个管道的两端和生成环境的函数可序列化包装后传递给子进程:

self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]

在这个类的设计中为1个主进程分别和nremotes个子进程进行交互。

在父进程中只对self.remotes进行操作,对其进行消息的读取和存储。

在子进程中只对self.work_remotes进行操作,对其进行消息的读取和存储。

在父进程中关闭对self.work_remotes的操作:

for remote in self.work_remotes:
remote.close()

在子进程中关闭对self.remotes的操作:

parent_remote.close()

在父进程中生成多个子进程,并设置子进程的模式为daemon,同时在生成子进程时要清空mpi的环境变量否则mpi的引入会导致子进程的挂起:

for p in self.ps:
p.daemon = True # if the main process crashes, we should not cause things to hang
with clear_mpi_env_vars():
p.start()

函数:

def step_async(self, actions):
def step_async(self, actions):

对动作进行划分使用pipe传递给对应的子进程并从管道中取出子进程返回的消息。

在对收到的observation后需要做处理,如果observation为dict类型则需要将多个环境的observation的对应的key保持不变将value进行堆叠,reward,done信息也是使用np.array进行堆叠,info则用list做集合。

如果各个环境的observation为np.array类型则对他们进行np.stack堆叠。

在图像绘制上继承父类vec_env.py中的特性:

    def render(self, mode='human'):
imgs = self.get_images()
bigimg = tile_images(imgs)
if mode == 'human':
self.get_viewer().imshow(bigimg)
return self.get_viewer().isopen
elif mode == 'rgb_array':
return bigimg
else:
raise NotImplementedError def get_viewer(self):
if self.viewer is None:
from gym.envs.classic_control import rendering
self.viewer = rendering.SimpleImageViewer()
return self.viewer

subproc_vec_env.py中的:

    def get_images(self):
self._assert_not_closed()
for pipe in self.remotes:
pipe.send(('render', None))
imgs = [pipe.recv() for pipe in self.remotes]
imgs = _flatten_list(imgs)
return imgs

可以看到render函数也是将收集到的各个子进程的图片拼接后进行绘制。

在worker函数中:

    def step_env(env, action):
ob, reward, done, info = env.step(action)
if done:
ob = env.reset()
return ob, reward, done, info

如果在单个进程中单个环境的done=True,则直接进行env.rest(),返回的状态为reset后的observation。

close函数:

    def close(self):
if self.closed:
return
if self.viewer is not None:
self.viewer.close()
self.close_extras()
self.closed = True
    def close_extras(self):
self.closed = True
if self.waiting:
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(('close', None))
for p in self.ps:
p.join()

在close操作时要通知子进程进行close操作,在主进程中需要阻塞等待子进程。

========================================

baselines算法库common/vec_env/subproc_vec_env.py模块分析的更多相关文章

  1. 图像滤镜艺术---ZPhotoEngine超级算法库

    原文:图像滤镜艺术---ZPhotoEngine超级算法库 一直以来,都有个想法,想要做一个属于自己的图像算法库,这个想法,在经过了几个月的努力之后,终于诞生了,这就是ZPhotoEngine算法库. ...

  2. scikit-learn 支持向量机算法库使用小结

    之前通过一个系列对支持向量机(以下简称SVM)算法的原理做了一个总结,本文从实践的角度对scikit-learn SVM算法库的使用做一个小结.scikit-learn SVM算法库封装了libsvm ...

  3. 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块

    详细解读Python的web.py框架下的application.py模块   这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...

  4. 第三百零六节,Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置

    Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...

  5. 使用织梦开源的分词算法库编写的YII获取分词扩展

    在编辑文章中,很多时候都需要自动根据文章内容获取关键字的功能,因此,本文主要是说明如何在yii中使用织梦开源的分词算法编写一个独立的扩展,可以在不同的模块中使用,步骤如下: 1 到这里下载其他朋友整理 ...

  6. 四 Django框架,models.py模块,数据库操作——创建表、数据类型、索引、admin后台,补充Django目录说明以及全局配置文件配置

    Django框架,models.py模块,数据库操作——创建表.数据类型.索引.admin后台,补充Django目录说明以及全局配置文件配置 数据库配置 django默认支持sqlite,mysql, ...

  7. mahout算法库(四)

    mahout算法库 分为三大块 1.聚类算法 2.协同过滤算法(一般用于推荐) 协同过滤算法也可以称为推荐算法!!! 3.分类算法 算法类 算法名 中文名 分类算法               Log ...

  8. 操作MySQL-数据库的安装及Pycharm模块的导入

    操作MySQL-数据库的安装及Pycharm模块的导入 1.基于pyCharm开发环境,在CMD控制台输入依次输入以下步骤: (1)pip3 install PyMySQL  < 安装 PyMy ...

  9. 痞子衡嵌入式:对比MbedTLS算法库纯软件实现与i.MXRT上DCP,CAAM硬件加速器实现性能差异

    大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是MbedTLS算法库纯软件实现与i.MXRT上DCP,CAAM硬件加速器实现性能差异. 近期有 i.MXRT 客户在集成 OTA SBL ...

  10. snowland-smx密码算法库

    snowland-smx密码算法库 一.snowland-smx密码算法库的介绍 snowland-smx是python实现的国密套件,对标python实现的gmssl,包含国密SM2,SM3,SM4 ...

随机推荐

  1. jwt 加密和解密demo

    jwt 加密和解密demo JSON Web Token(JWT)是一个非常轻巧的规范.这个规范允许我们使用 JWT 在用户和服务器之间传递安全可靠的信息.导入jar <dependency&g ...

  2. 关于Lecture2建立一个Git远程仓库的补充

    Smiling & Weeping ---- 心之何如,有似万丈迷津, 遥亘千里. 其中并无舟子可渡人, 除了自渡,他人爱莫能助. Git 远程仓库(Github) Git 并不像 SVN 那 ...

  3. GIT 生成变更历史文件清单

    脚本搞定git文件版本变化信息,解决部署种变更的审核和统计信息工作复杂问题 git diff --name-status --ignore-cr-at-eol --ignore-space-at-eo ...

  4. Linux/Unix-stty命令详解

    文章目录 介绍 stty命令的使用方法 stty的参数 我常用的选项 所有选项 介绍 stty用于查询和设置当前终端的配置. 如果你的终端回车不换行.输入命令不显示等各种奇葩问题,那么stty命令可以 ...

  5. 20-Docker镜像制作

    查看镜像构建的历史 docker image history 26a5 #查看镜像26a5的构建历史 使用commit命令构建镜像 使用commit命令可以将容器构建成镜像. 将容器webserver ...

  6. Linux设备模型:4、sysfs

    作者:wowo 发布于:2014-3-14 18:31 分类:统一设备模型 http://www.wowotech.net/device_model/dm_sysfs.html 前言 sysfs是一个 ...

  7. Ubuntu 查看用户历史记录

    Ubuntu 查看用户历史记录 1. 查看用户命令行历史记录 1. 查看当前登录账号所属用户的历史命令行记录 打开命令行,输入 history 就会看到当前登录账号所属用户的历史记录 2. 查看系统所 ...

  8. yb 课堂实战之视频列表接口开发+API权限路径规划 《三》

    开发JsonData工具类 package net.ybclass.online_ybclass.utils; public class JsonData { /** * 状态码,0表示成功过,1表示 ...

  9. P1387

    #include<iostream> #include<utility> using namespace std; typedef long long ll; #define ...

  10. 洛谷P1464

    搜索优化后是记忆化搜索,再优化就是dp了 #include <iostream> #include <utility> using namespace std; typedef ...