Full source of the retro_wrappers.py module:

```python
from collections import deque
import cv2
cv2.ocl.setUseOpenCL(False)
from .atari_wrappers import WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame
from .wrappers import TimeLimit
import numpy as np
import gym

class StochasticFrameSkip(gym.Wrapper):
    def __init__(self, env, n, stickprob):
        gym.Wrapper.__init__(self, env)
        self.n = n
        self.stickprob = stickprob
        self.curac = None
        self.rng = np.random.RandomState()
        self.supports_want_render = hasattr(env, "supports_want_render")

    def reset(self, **kwargs):
        self.curac = None
        return self.env.reset(**kwargs)

    def step(self, ac):
        done = False
        totrew = 0
        for i in range(self.n):
            # First step after reset, use action
            if self.curac is None:
                self.curac = ac
            # First substep, delay with probability=stickprob
            elif i==0:
                if self.rng.rand() > self.stickprob:
                    self.curac = ac
            # Second substep, new action definitely kicks in
            elif i==1:
                self.curac = ac
            if self.supports_want_render and i<self.n-1:
                ob, rew, done, info = self.env.step(self.curac, want_render=False)
            else:
                ob, rew, done, info = self.env.step(self.curac)
            totrew += rew
            if done: break
        return ob, totrew, done, info

    def seed(self, s):
        self.rng.seed(s)

class PartialFrameStack(gym.Wrapper):
    def __init__(self, env, k, channel=1):
        """
        Stack one channel (channel keyword) from previous frames
        """
        gym.Wrapper.__init__(self, env)
        shp = env.observation_space.shape
        self.channel = channel
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(shp[0], shp[1], shp[2] + k - 1),
            dtype=env.observation_space.dtype)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape

    def reset(self):
        ob = self.env.reset()
        assert ob.shape[2] > self.channel
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, ac):
        ob, reward, done, info = self.env.step(ac)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return np.concatenate([frame if i==self.k-1 else frame[:,:,self.channel:self.channel+1]
                               for (i, frame) in enumerate(self.frames)], axis=2)

class Downsample(gym.ObservationWrapper):
    def __init__(self, env, ratio):
        """
        Downsample images by a factor of ratio
        """
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, oldc) = env.observation_space.shape
        newshape = (oldh//ratio, oldw//ratio, oldc)
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=newshape, dtype=np.uint8)

    def observation(self, frame):
        height, width, _ = self.observation_space.shape
        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
        if frame.ndim == 2:
            frame = frame[:,:,None]
        return frame

class Rgb2gray(gym.ObservationWrapper):
    def __init__(self, env):
        """
        Downsample images by a factor of ratio
        """
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, _oldc) = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(oldh, oldw, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return frame[:,:,None]

class MovieRecord(gym.Wrapper):
    def __init__(self, env, savedir, k):
        gym.Wrapper.__init__(self, env)
        self.savedir = savedir
        self.k = k
        self.epcount = 0
    def reset(self):
        if self.epcount % self.k == 0:
            self.env.unwrapped.movie_path = self.savedir
        else:
            self.env.unwrapped.movie_path = None
            self.env.unwrapped.movie = None
        self.epcount += 1
        return self.env.reset()

class AppendTimeout(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.action_space = env.action_space
        self.timeout_space = gym.spaces.Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32)
        self.original_os = env.observation_space
        if isinstance(self.original_os, gym.spaces.Dict):
            import copy
            ordered_dict = copy.deepcopy(self.original_os.spaces)
            ordered_dict['value_estimation_timeout'] = self.timeout_space
            self.observation_space = gym.spaces.Dict(ordered_dict)
            self.dict_mode = True
        else:
            self.observation_space = gym.spaces.Dict({
                'original': self.original_os,
                'value_estimation_timeout': self.timeout_space
            })
            self.dict_mode = False
        self.ac_count = None
        while 1:
            if not hasattr(env, "_max_episode_steps"): # Looking for TimeLimit wrapper that has this field
                env = env.env
                continue
            break
        self.timeout = env._max_episode_steps

    def step(self, ac):
        self.ac_count += 1
        ob, rew, done, info = self.env.step(ac)
        return self._process(ob), rew, done, info

    def reset(self):
        self.ac_count = 0
        return self._process(self.env.reset())

    def _process(self, ob):
        fracmissing = 1 - self.ac_count / self.timeout
        if self.dict_mode:
            ob['value_estimation_timeout'] = fracmissing
        else:
            return { 'original': ob, 'value_estimation_timeout': fracmissing }

class StartDoingRandomActionsWrapper(gym.Wrapper):
    """
    Warning: can eat info dicts, not good if you depend on them
    """
    def __init__(self, env, max_random_steps, on_startup=True, every_episode=False):
        gym.Wrapper.__init__(self, env)
        self.on_startup = on_startup
        self.every_episode = every_episode
        self.random_steps = max_random_steps
        self.last_obs = None
        if on_startup:
            self.some_random_steps()

    def some_random_steps(self):
        self.last_obs = self.env.reset()
        n = np.random.randint(self.random_steps)
        #print("running for random %i frames" % n)
        for _ in range(n):
            self.last_obs, _, done, _ = self.env.step(self.env.action_space.sample())
            if done: self.last_obs = self.env.reset()

    def reset(self):
        return self.last_obs

    def step(self, a):
        self.last_obs, rew, done, info = self.env.step(a)
        if done:
            self.last_obs = self.env.reset()
            if self.every_episode:
                self.some_random_steps()
        return self.last_obs, rew, done, info

def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
    import retro
    if state is None:
        state = retro.State.DEFAULT
    env = retro.make(game, state, **kwargs)
    env = StochasticFrameSkip(env, n=4, stickprob=0.25)
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps=max_episode_steps)
    return env

def wrap_deepmind_retro(env, scale=True, frame_stack=4):
    """
    Configure environment for retro games, using config similar to DeepMind-style Atari in wrap_deepmind
    """
    env = WarpFrame(env)
    env = ClipRewardEnv(env)
    if frame_stack > 1:
        env = FrameStack(env, frame_stack)
    if scale:
        env = ScaledFloatFrame(env)
    return env

class SonicDiscretizer(gym.ActionWrapper):
    """
    Wrap a gym-retro environment and make it use discrete
    actions for the Sonic game.
    """
    def __init__(self, env):
        super(SonicDiscretizer, self).__init__(env)
        buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"]
        actions = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'],
                   ['DOWN', 'B'], ['B']]
        self._actions = []
        for action in actions:
            arr = np.array([False] * 12)
            for button in action:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a): # pylint: disable=W0221
        return self._actions[a].copy()

class RewardScaler(gym.RewardWrapper):
    """
    Bring rewards to a reasonable scale for PPO.
    This is incredibly important and effects performance
    drastically.
    """
    def __init__(self, env, scale=0.01):
        super(RewardScaler, self).__init__(env)
        self.scale = scale

    def reward(self, reward):
        return reward * self.scale

class AllowBacktracking(gym.Wrapper):
    """
    Use deltas in max(X) as the reward, rather than deltas
    in X. This way, agents are not discouraged too heavily
    from exploring backwards if there is no way to advance
    head-on in the level.
    """
    def __init__(self, env):
        super(AllowBacktracking, self).__init__(env)
        self._cur_x = 0
        self._max_x = 0

    def reset(self, **kwargs): # pylint: disable=E0202
        self._cur_x = 0
        self._max_x = 0
        return self.env.reset(**kwargs)

    def step(self, action): # pylint: disable=E0202
        obs, rew, done, info = self.env.step(action)
        self._cur_x += rew
        rew = max(0, self._cur_x - self._max_x)
        self._max_x = max(self._max_x, self._cur_x)
        return obs, rew, done, info
```

As its name suggests, this module provides wrappers for environments from the retro library (gym-retro).

The wrappers are similar to those for the Atari library, but differ in a few places; the best-known retro environments are probably Super Mario, Tetris, and Sonic the Hedgehog.

Because this module uses OpenCV to process frames, OpenCL is disabled at the top of the file to avoid conflicts with CUDA/GPU usage:

```python
cv2.ocl.setUseOpenCL(False)
```
```python
class StochasticFrameSkip(gym.Wrapper):
    def __init__(self, env, n, stickprob):
        gym.Wrapper.__init__(self, env)
        self.n = n
        self.stickprob = stickprob
        self.curac = None
        self.rng = np.random.RandomState()
        self.supports_want_render = hasattr(env, "supports_want_render")

    def reset(self, **kwargs):
        self.curac = None
        return self.env.reset(**kwargs)

    def step(self, ac):
        done = False
        totrew = 0
        for i in range(self.n):
            # First step after reset, use action
            if self.curac is None:
                self.curac = ac
            # First substep, delay with probability=stickprob
            elif i==0:
                if self.rng.rand() > self.stickprob:
                    self.curac = ac
            # Second substep, new action definitely kicks in
            elif i==1:
                self.curac = ac
            if self.supports_want_render and i<self.n-1:
                ob, rew, done, info = self.env.step(self.curac, want_render=False)
            else:
                ob, rew, done, info = self.env.step(self.curac)
            totrew += rew
            if done: break
        return ob, totrew, done, info

    def seed(self, s):
        self.rng.seed(s)
```

The key part of the StochasticFrameSkip wrapper is its step function.

The class implements frame skipping: after receiving a single action it interacts with the environment n times. Unlike an ordinary frame skip, this one is stochastic ("sticky"): on the first of the n substeps, with probability stickprob the wrapper keeps the action from the previous interaction instead of switching to the newly received action.

From the second substep onward (i.e. from i==1), every interaction uses the action passed to this call of step.

One small detail: if the underlying environment supports the want_render flag, rendering is only requested on the last of the n substeps; the earlier substeps are executed with want_render=False.

Since a single received action triggers n interactions with the environment, the returned reward is the sum of the rewards collected over these n substeps.
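To make the sticky-action logic concrete, here is a small standalone sketch (the helper name `simulate_substep_actions` is invented for illustration) that reproduces the branching in `step` and shows which action is actually executed on each of the n substeps:

```python
import numpy as np

def simulate_substep_actions(prev_action, new_action, n=4, stickprob=0.25, rng=None):
    """Return the action actually executed on each of the n substeps,
    mirroring the branching inside StochasticFrameSkip.step."""
    rng = rng or np.random.RandomState(0)
    curac = prev_action              # action left over from the previous step() call
    executed = []
    for i in range(n):
        if curac is None:            # first step after reset: always use the new action
            curac = new_action
        elif i == 0:                 # first substep: keep the old action with prob. stickprob
            if rng.rand() > stickprob:
                curac = new_action
        elif i == 1:                 # from the second substep on, the new action kicks in
            curac = new_action
        executed.append(curac)
    return executed

print(simulate_substep_actions("LEFT", "RIGHT"))
# e.g. ['RIGHT', 'RIGHT', 'RIGHT', 'RIGHT'], or with prob. stickprob
#      ['LEFT', 'RIGHT', 'RIGHT', 'RIGHT']
```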

```python
class PartialFrameStack(gym.Wrapper):
    def __init__(self, env, k, channel=1):
        """
        Stack one channel (channel keyword) from previous frames
        """
        gym.Wrapper.__init__(self, env)
        shp = env.observation_space.shape
        self.channel = channel
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(shp[0], shp[1], shp[2] + k - 1),
            dtype=env.observation_space.dtype)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape

    def reset(self):
        ob = self.env.reset()
        assert ob.shape[2] > self.channel
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, ac):
        ob, reward, done, info = self.env.step(ac)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return np.concatenate([frame if i==self.k-1 else frame[:,:,self.channel:self.channel+1]
                               for (i, frame) in enumerate(self.frames)], axis=2)
```

This wrapper stacks k game frames along the channel dimension.

Note that it is not the usual full frame stacking but a partial, per-channel stack: of the k frames to be concatenated, the first k-1 (older) frames contribute only the single channel selected by the channel argument, and only the last (most recent) frame contributes all of its channels:

```python
frame if i==self.k-1 else frame[:,:,self.channel:self.channel+1]
```
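A quick shape check (a made-up example assuming 84×84 RGB observations, k=4 and channel=1) shows why the wrapped observation space has shp[2] + k - 1 channels:

```python
import numpy as np

k, channel = 4, 1
frames = [np.zeros((84, 84, 3), dtype=np.uint8) for _ in range(k)]   # oldest ... newest

stacked = np.concatenate(
    [f if i == k - 1 else f[:, :, channel:channel + 1] for i, f in enumerate(frames)],
    axis=2)
print(stacked.shape)  # (84, 84, 6): one channel from each of the 3 older frames + 3 from the newest
```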
```python
class Downsample(gym.ObservationWrapper):
    def __init__(self, env, ratio):
        """
        Downsample images by a factor of ratio
        """
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, oldc) = env.observation_space.shape
        newshape = (oldh//ratio, oldw//ratio, oldc)
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=newshape, dtype=np.uint8)

    def observation(self, frame):
        height, width, _ = self.observation_space.shape
        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
        if frame.ndim == 2:
            frame = frame[:,:,None]
        return frame
```

This is an observation wrapper that downsamples the frame by a factor of ratio. Note that every observation returned by this wrapper is a 3-D np.array with an explicit channel dimension.

If the resized frame comes back 2-D (which happens when the input has a single channel, because cv2.resize drops a trailing channel of size 1), the channel dimension is added back:

```python
if frame.ndim == 2:
    frame = frame[:,:,None]
```
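A small illustration of the observation method (the frame here is random data; with a 3-channel input cv2.resize keeps the channel axis, so the ndim == 2 branch only matters for single-channel inputs):

```python
import numpy as np
import cv2

ratio = 2
frame = np.random.randint(0, 256, size=(224, 320, 3), dtype=np.uint8)  # fake RGB frame
h, w = 224 // ratio, 320 // ratio
small = cv2.resize(frame, (w, h), interpolation=cv2.INTER_AREA)
if small.ndim == 2:                 # only reached for single-channel inputs
    small = small[:, :, None]
print(small.shape)                  # (112, 160, 3)
```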
```python
class Rgb2gray(gym.ObservationWrapper):
    def __init__(self, env):
        """
        Downsample images by a factor of ratio
        """
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, _oldc) = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(oldh, oldw, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return frame[:,:,None]
```

This wrapper uses OpenCV to convert RGB frames to grayscale (the docstring is a copy-paste leftover from Downsample). Note that the returned observation has its channel dimension restored, so the observation is always 3-D.

The dimension-expansion operation:

```python
return frame[:,:,None]
```
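A short illustration of why the channel axis has to be added back (random data again):

```python
import numpy as np
import cv2

rgb = np.random.randint(0, 256, size=(224, 320, 3), dtype=np.uint8)
gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
print(gray.shape)               # (224, 320)   -- cvtColor drops the channel axis
print(gray[:, :, None].shape)   # (224, 320, 1) -- restored so downstream wrappers see 3-D arrays
```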
```python
class MovieRecord(gym.Wrapper):
    def __init__(self, env, savedir, k):
        gym.Wrapper.__init__(self, env)
        self.savedir = savedir
        self.k = k
        self.epcount = 0
    def reset(self):
        if self.epcount % self.k == 0:
            self.env.unwrapped.movie_path = self.savedir
        else:
            self.env.unwrapped.movie_path = None
            self.env.unwrapped.movie = None
        self.epcount += 1
        return self.env.reset()
```

On every reset, this wrapper sets env.unwrapped.movie_path to the save directory once every k episodes (and clears it otherwise).

What exactly setting this movie path does is not explained here, and the attribute is not referenced anywhere else in the baselines project; presumably it is consumed by gym-retro itself, which can record gameplay movies.

```python
class AppendTimeout(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.action_space = env.action_space
        self.timeout_space = gym.spaces.Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32)
        self.original_os = env.observation_space
        if isinstance(self.original_os, gym.spaces.Dict):
            import copy
            ordered_dict = copy.deepcopy(self.original_os.spaces)
            ordered_dict['value_estimation_timeout'] = self.timeout_space
            self.observation_space = gym.spaces.Dict(ordered_dict)
            self.dict_mode = True
        else:
            self.observation_space = gym.spaces.Dict({
                'original': self.original_os,
                'value_estimation_timeout': self.timeout_space
            })
            self.dict_mode = False
        self.ac_count = None
        while 1:
            if not hasattr(env, "_max_episode_steps"): # Looking for TimeLimit wrapper that has this field
                env = env.env
                continue
            break
        self.timeout = env._max_episode_steps

    def step(self, ac):
        self.ac_count += 1
        ob, rew, done, info = self.env.step(ac)
        return self._process(ob), rew, done, info

    def reset(self):
        self.ac_count = 0
        return self._process(self.env.reset())

    def _process(self, ob):
        fracmissing = 1 - self.ac_count / self.timeout
        if self.dict_mode:
            ob['value_estimation_timeout'] = fracmissing
        else:
            return { 'original': ob, 'value_estimation_timeout': fracmissing }
```

If the observation space is a Dict, the wrapper adds a key 'value_estimation_timeout' whose value is the fraction of the episode budget still remaining (1 - ac_count / timeout, where timeout is the maximum number of episode steps).

If the observation is a plain np.array, it is wrapped into a dict under the key 'original', and the same 'value_estimation_timeout' entry is added.

So this class mainly wraps the observation: it turns it into a dict and attaches the 'value_estimation_timeout' entry. (Note that in dict mode _process modifies ob in place but has no explicit return statement, which looks like an oversight in the upstream code.)

In the constructor, the wrapper repeatedly unwraps env = env.env until it finds a wrapper that carries the _max_episode_steps attribute (i.e. the TimeLimit wrapper), and uses that value as the episode timeout.

In short, the class records in the returned observation how far the current step count is from the maximum number of episode steps.
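A quick numeric illustration of the fracmissing value attached to the observation (the numbers are made up):

```python
timeout = 4500            # _max_episode_steps found on the inner TimeLimit wrapper
ac_count = 450            # actions taken so far in the current episode
fracmissing = 1 - ac_count / timeout
print(fracmissing)        # 0.9 -> 90% of the episode budget still remains
```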

```python
class StartDoingRandomActionsWrapper(gym.Wrapper):
    """
    Warning: can eat info dicts, not good if you depend on them
    """
    def __init__(self, env, max_random_steps, on_startup=True, every_episode=False):
        gym.Wrapper.__init__(self, env)
        self.on_startup = on_startup
        self.every_episode = every_episode
        self.random_steps = max_random_steps
        self.last_obs = None
        if on_startup:
            self.some_random_steps()

    def some_random_steps(self):
        self.last_obs = self.env.reset()
        n = np.random.randint(self.random_steps)
        #print("running for random %i frames" % n)
        for _ in range(n):
            self.last_obs, _, done, _ = self.env.step(self.env.action_space.sample())
            if done: self.last_obs = self.env.reset()

    def reset(self):
        return self.last_obs

    def step(self, a):
        self.last_obs, rew, done, info = self.env.step(a)
        if done:
            self.last_obs = self.env.reset()
            if self.every_episode:
                self.some_random_steps()
        return self.last_obs, rew, done, info
```

This wrapper optionally performs a number of random actions at the start of an episode.

The core loop:

```python
for _ in range(n):
    self.last_obs, _, done, _ = self.env.step(self.env.action_space.sample())
    if done: self.last_obs = self.env.reset()
```

You can choose whether this random warm-up runs only once when the wrapper is constructed (on_startup) or additionally at the start of every episode (every_episode); a minimal usage sketch follows.
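A minimal usage sketch, assuming this module is importable as baselines.common.retro_wrappers and an environment with the old 4-tuple gym step API (CartPole-v1 is used purely for illustration):

```python
import gym
from baselines.common.retro_wrappers import StartDoingRandomActionsWrapper

env = gym.make("CartPole-v1")
env = StartDoingRandomActionsWrapper(env, max_random_steps=30, every_episode=True)

obs = env.reset()    # returns the observation left over from the random warm-up
for _ in range(100):
    obs, rew, done, info = env.step(env.action_space.sample())
```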

```python
def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
    import retro
    if state is None:
        state = retro.State.DEFAULT
    env = retro.make(game, state, **kwargs)
    env = StochasticFrameSkip(env, n=4, stickprob=0.25)
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps=max_episode_steps)
    return env
```

This function combines the wrappers defined above.

The environment created by retro is wrapped with StochasticFrameSkip and, if max_episode_steps is given, with TimeLimit; a hypothetical usage example follows.
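Hypothetical usage (requires gym-retro to be installed; "Airstriker-Genesis" is the free ROM bundled with retro and serves only as an example game here):

```python
from baselines.common.retro_wrappers import make_retro

env = make_retro(game="Airstriker-Genesis", max_episode_steps=4500)
obs = env.reset()
obs, rew, done, info = env.step(env.action_space.sample())
env.close()
```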

```python
def wrap_deepmind_retro(env, scale=True, frame_stack=4):
    """
    Configure environment for retro games, using config similar to DeepMind-style Atari in wrap_deepmind
    """
    env = WarpFrame(env)
    env = ClipRewardEnv(env)
    if frame_stack > 1:
        env = FrameStack(env, frame_stack)
    if scale:
        env = ScaledFloatFrame(env)
    return env
```

This applies the Atari wrapper classes to retro games.

WarpFrame converts the frame to grayscale and resizes it (to 84×84 by default).

ClipRewardEnv clips rewards to -1, 0, or +1 by taking their sign.

FrameStack stacks k consecutive frames along the channel dimension.

ScaledFloatFrame converts the uint8 pixel values in [0, 255] to float32 values in [0, 1]. A sketch combining make_retro and wrap_deepmind_retro follows.
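A sketch of the full pipeline (the shape shown assumes WarpFrame's default 84×84 grayscale output and the example game from above):

```python
from baselines.common.retro_wrappers import make_retro, wrap_deepmind_retro

env = make_retro(game="Airstriker-Genesis")
env = wrap_deepmind_retro(env, scale=True, frame_stack=4)
print(env.observation_space.shape)   # (84, 84, 4): four stacked 84x84 grayscale frames
obs = env.reset()
env.close()
```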

```python
class SonicDiscretizer(gym.ActionWrapper):
    """
    Wrap a gym-retro environment and make it use discrete
    actions for the Sonic game.
    """
    def __init__(self, env):
        super(SonicDiscretizer, self).__init__(env)
        buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"]
        actions = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'],
                   ['DOWN', 'B'], ['B']]
        self._actions = []
        for action in actions:
            arr = np.array([False] * 12)
            for button in action:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a): # pylint: disable=W0221
        return self._actions[a].copy()
```

This wrapper discretizes the action space.

The discrete actions exposed to the agent correspond to the following button combinations:

```python
actions = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'],
           ['DOWN', 'B'], ['B']]
```

The agent passes in an integer: 0 means ['LEFT'], 1 means ['RIGHT'], 2 means ['LEFT', 'DOWN'], and so on.

So the external action is a number from 0 to 6, while the underlying retro environment expects a boolean vector over its buttons:

```python
buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"]
```

There are 12 buttons in total. The action forwarded to the inner retro environment is a 12-element boolean array: essentially a one-hot encoding, except that several buttons may be pressed at once (a multi-hot encoding), which is exactly what the seven combinations above require:

```python
for action in actions:
    arr = np.array([False] * 12)
    for button in action:
        arr[buttons.index(button)] = True
    self._actions.append(arr)
```
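For example, discrete action 3 (['RIGHT', 'DOWN']) maps to the following 12-element button array (a standalone re-implementation of the lookup, just for illustration):

```python
import numpy as np

buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"]
actions = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'],
           ['DOWN', 'B'], ['B']]

def to_button_array(a):
    """Discrete action index -> 12-element boolean button array (multi-hot)."""
    arr = np.array([False] * 12)
    for button in actions[a]:
        arr[buttons.index(button)] = True
    return arr

print(to_button_array(3).astype(int))  # RIGHT+DOWN -> [0 0 0 0 0 1 0 1 0 0 0 0]
```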
```python
class RewardScaler(gym.RewardWrapper):
    """
    Bring rewards to a reasonable scale for PPO.
    This is incredibly important and effects performance
    drastically.
    """
    def __init__(self, env, scale=0.01):
        super(RewardScaler, self).__init__(env)
        self.scale = scale

    def reward(self, reward):
        return reward * self.scale
```

This is a reward wrapper that simply rescales the reward. According to the docstring it is mainly intended for PPO, where bringing rewards to a reasonable scale has a drastic effect on performance.

```python
class AllowBacktracking(gym.Wrapper):
    """
    Use deltas in max(X) as the reward, rather than deltas
    in X. This way, agents are not discouraged too heavily
    from exploring backwards if there is no way to advance
    head-on in the level.
    """
    def __init__(self, env):
        super(AllowBacktracking, self).__init__(env)
        self._cur_x = 0
        self._max_x = 0

    def reset(self, **kwargs): # pylint: disable=E0202
        self._cur_x = 0
        self._max_x = 0
        return self.env.reset(**kwargs)

    def step(self, action): # pylint: disable=E0202
        obs, rew, done, info = self.env.step(action)
        self._cur_x += rew
        rew = max(0, self._cur_x - self._max_x)
        self._max_x = max(self._max_x, self._cur_x)
        return obs, rew, done, info
```

As I understand it, this wrapper is meant for horizontally scrolling games (it comes from the Sonic benchmark, and the same idea applies to Super Mario): by wrapping reset and step it customizes the reward.

Since this class reshapes the reward and the exact reward design of the underlying game is not documented here, the behavior has to be inferred from the code:

Take a side-scroller such as Super Mario as an example: the agent (Mario) gets a positive reward for moving right, towards the goal, and a negative reward for moving left, away from it.

Here self._max_x records the right-most position the agent has ever reached, and self._cur_x is the running sum of the raw rewards, i.e. a proxy for the current x coordinate. Whenever the agent sits to the left of the historical maximum self._max_x, the shaped reward is necessarily 0 no matter what the last action was, because:

```python
rew = max(0, self._cur_x - self._max_x)
```

To some extent this reward design simply does not reward the agent for re-exploring ground it has already covered (anything to the left of self._max_x): whatever actions the agent takes, as long as its position is left of self._max_x the shaped reward stays at 0, but it is never negative either, so backtracking is not punished. A short numeric trace follows.
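A made-up numeric trace of the shaping logic, treating each raw reward as the per-step change in x position:

```python
# raw rewards: move right, right, back left, right a little, right a lot
raw_rewards = [+2, +3, -4, +1, +5]
cur_x = max_x = 0
for raw in raw_rewards:
    cur_x += raw
    shaped = max(0, cur_x - max_x)
    max_x = max(max_x, cur_x)
    print(cur_x, max_x, shaped)
# cur_x:  2  5  1  2  7
# max_x:  2  5  5  5  7
# shaped: 2  3  0  0  2   -> reward only for new right-most progress, never negative
```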
