How to make ParrallelEnv from a custom env that requires env_config?

Hi all,

I made a custom env, which is compatible with Gym and also works with RLlib. Now I would like to use TorchRL, and need to create my custom env. Here is how I made my EnvMaker class:

from gym_floorplan.envs.master_env import SpaceLayoutGym

class EnvMaker:
    def __init__(self, env_config):
        self.env_config = env_config
        self.env_name = env_config['env_name']
        
        self._register_my_env(env_config=env_config)
            
    
    def _register_my_env(self, env_name: str = 'SpaceLayoutGym-v0', env_config={}):
        gym.envs.register(id='SpaceLayoutGym-v0',
                          entry_point=self._env_creator,
                          kwargs={'env_config': env_config})
    
     
        
    def _env_creator(self, env_config):
        return SpaceLayoutGym(env_config)
    
    
    
    def make_dummy(self):
        test_env = gym.make(self.env_name)
        test_env.device = self.env_config['device']
        test_env = GymWrapper(test_env)
        test_env = self._transforme_env(test_env)
        
        self.obs_norm_sd = self._get_norm_stats(test_env)
        
        return test_env
    
    
    
    def make(self, parallel=False, obs_norm_sd=None):
        def create_custom_env():
            env = gym.make(self.env_config['env_name'])
            env.device = self.env_config['device']
            # env.batch_size = self.env_config['batch_size']
            return GymWrapper(env)
        
        if parallel:
            env = ParallelEnv(self.env_config['num_workers'], EnvCreator(create_custom_env))
        else:
            env = create_custom_env()
    
        env = self._transforme_env(env, self.obs_norm_sd)
        check_env_specs(env)
        
        rollout = env.rollout(3)
        print("rollout of three steps:", rollout)
        return env
        
        
    def _transforme_env(self, env, obs_norm_sd=None):
        if obs_norm_sd is None:
            obs_norm_sd = {"standard_normal": True}
            
        env_ = TransformedEnv(
                env,
                Compose(
                    StepCounter(),
                    DoubleToFloat(
                        in_keys=["observation"],
                    ),
                    ObservationNorm(in_keys=["observation"], **obs_norm_sd), # TODO comment this for SpaceLayoutGym
                ),
                # device=self.env_config['device'],
            )
        return env_
    
    
    
    def _get_norm_stats(self, env_):
        if 'Cnn' in self.env_config['model_name']:
            env_.transform[-1].init_stats(num_iter=10, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2))
        else:
            env_.transform[-1].init_stats(num_iter=10, reduce_dim=0, cat_dim=0)
        obs_norm_sd = env_.transform[-1].state_dict()
        print("state dict of the observation norm:", obs_norm_sd)
        return obs_norm_sd

Now I want to make single and parallel env:

env_maker = EnvMaker(fenv_config) # fenv_config is my env config
# I first make a dummy env in order to compute the obs stats
dummy_env = env_maker.make_dummy() # works well
# Then I make my env using obs stats
env = env_maker.make(parallel=False) # this works well too
# Now I like to make parallel env 
envs = env_maker.make(parallel=True) # but this does not work. It does not recognize my fenv_config

I wonder if anyone knows how to correctly pass the env config while wrap custom env with ParraleEnv? Or are there any simpler approaches to making parallel env in TorchRL?

Thanks!

Have you tried

        def create_custom_env(env_name=self.env_config['env_name'], device=self.env_config['device'], batch_size=self.env_config['batch_size']):
            env = gym.make(env_name)
            env.device = device
            # env.batch_size = batch_size
            return GymWrapper(env)

usually kwargs help me code this sort of stuff

Hi @vmoens,
Many thanks for your quick reply.

I tested what you suggested, but it did not work.

However, I found a kind of workaround as follows. I know it might not be the best way, but seems working.

from gym_floorplan.envs.master_env import SpaceLayoutGym

class EnvMaker:
    def __init__(self, fenv_config):
        self.fenv_config = fenv_config
        self.env_name = fenv_config['env_name']
        
        self._register_my_env(env_config=fenv_config)
            
    
    def _register_my_env(self, env_name: str = 'SpaceLayoutGym-v0', env_config={}):
        gym.envs.register(id='SpaceLayoutGym-v0',
                          entry_point=self._env_creator,
                          kwargs={'env_config': env_config})
    
     
        
    def _env_creator(self, env_config):
        return SpaceLayoutGym(env_config)
    
    
    
    def make_dummy(self):
        test_env = gym.make(self.env_name)
        test_env.device = self.fenv_config['device']
        test_env = GymWrapper(test_env)
        test_env = self._transforme_env(test_env)
        self.obs_norm_sd = self._get_norm_stats(test_env)
        return test_env
    
    
    
    def make_single(self):
        def create_custom_env():
            env = gym.make(self.fenv_config['env_name'])
            env.device = self.fenv_config['device']
            env.batch_size = self.fenv_config['batch_size']
            return GymWrapper(env)
        env = create_custom_env()
        env = self._transforme_env(env, self.obs_norm_sd)
        check_env_specs(env)
        rollout = env.rollout(3)
        print("rollout of three steps:", rollout)
        return env
    
    
    
    def make_parallel(self):
        def create_custom_env():
            env = GymWrapper(SpaceLayoutGym(self.fenv_config))
            return env
        env = ParallelEnv(self.fenv_config['num_workers'], create_custom_env)
        env = self._transforme_env(env, self.obs_norm_sd)
        check_env_specs(env)
        rollout = env.rollout(3)
        print("rollout of three steps:", rollout)
        return env
        
        
        
    def _transforme_env(self, env, obs_norm_sd=None):
        if obs_norm_sd is None:
            obs_norm_sd = {"standard_normal": True}
            
        env_ = TransformedEnv(
                env,
                Compose(
                    StepCounter(),
                    DoubleToFloat(
                        in_keys=["observation"],
                    ),
                    ObservationNorm(in_keys=["observation"], **obs_norm_sd),
                ),
                device=self.fenv_config['device'],
            )
        return env_