Hi all,
I made a custom env, which is compatible with Gym and also works with RLlib. Now I would like to use TorchRL, and need to create my custom env. Here is how I made my EnvMaker class:
from gym_floorplan.envs.master_env import SpaceLayoutGym
class EnvMaker:
def __init__(self, env_config):
self.env_config = env_config
self.env_name = env_config['env_name']
self._register_my_env(env_config=env_config)
def _register_my_env(self, env_name: str = 'SpaceLayoutGym-v0', env_config={}):
gym.envs.register(id='SpaceLayoutGym-v0',
entry_point=self._env_creator,
kwargs={'env_config': env_config})
def _env_creator(self, env_config):
return SpaceLayoutGym(env_config)
def make_dummy(self):
test_env = gym.make(self.env_name)
test_env.device = self.env_config['device']
test_env = GymWrapper(test_env)
test_env = self._transforme_env(test_env)
self.obs_norm_sd = self._get_norm_stats(test_env)
return test_env
def make(self, parallel=False, obs_norm_sd=None):
def create_custom_env():
env = gym.make(self.env_config['env_name'])
env.device = self.env_config['device']
# env.batch_size = self.env_config['batch_size']
return GymWrapper(env)
if parallel:
env = ParallelEnv(self.env_config['num_workers'], EnvCreator(create_custom_env))
else:
env = create_custom_env()
env = self._transforme_env(env, self.obs_norm_sd)
check_env_specs(env)
rollout = env.rollout(3)
print("rollout of three steps:", rollout)
return env
def _transforme_env(self, env, obs_norm_sd=None):
if obs_norm_sd is None:
obs_norm_sd = {"standard_normal": True}
env_ = TransformedEnv(
env,
Compose(
StepCounter(),
DoubleToFloat(
in_keys=["observation"],
),
ObservationNorm(in_keys=["observation"], **obs_norm_sd), # TODO comment this for SpaceLayoutGym
),
# device=self.env_config['device'],
)
return env_
def _get_norm_stats(self, env_):
if 'Cnn' in self.env_config['model_name']:
env_.transform[-1].init_stats(num_iter=10, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2))
else:
env_.transform[-1].init_stats(num_iter=10, reduce_dim=0, cat_dim=0)
obs_norm_sd = env_.transform[-1].state_dict()
print("state dict of the observation norm:", obs_norm_sd)
return obs_norm_sd
Now I want to make single and parallel env:
env_maker = EnvMaker(fenv_config) # fenv_config is my env config
# I first make a dummy env in order to compute the obs stats
dummy_env = env_maker.make_dummy() # works well
# Then I make my env using obs stats
env = env_maker.make(parallel=False) # this works well too
# Now I like to make parallel env
envs = env_maker.make(parallel=True) # but this does not work. It does not recognize my fenv_config
I wonder if anyone knows how to correctly pass the env config while wrap custom env with ParraleEnv? Or are there any simpler approaches to making parallel env in TorchRL?
Thanks!