How to make a ParallelEnv from a custom env that requires env_config?

Hi @vmoens,
Many thanks for your quick reply.

I tested what you suggested, but it did not work.

However, I found a workaround, shown below. I know it might not be the best way, but it seems to work.

from gym_floorplan.envs.master_env import SpaceLayoutGym

class EnvMaker:
    """Factory that wraps the custom ``SpaceLayoutGym`` env for TorchRL.

    On construction the env is registered with gym under the name found in
    ``fenv_config['env_name']``.  The ``make_*`` methods then return the env
    wrapped in ``GymWrapper`` (optionally inside a ``ParallelEnv``) and in the
    transform stack built by ``_transforme_env``.

    Keys read from ``fenv_config`` by this class: ``'env_name'``,
    ``'device'``, ``'batch_size'`` and ``'num_workers'``.  The whole dict is
    also forwarded to ``SpaceLayoutGym`` as its ``env_config``.
    """

    def __init__(self, fenv_config):
        """Store the config and register the env under its configured name.

        Args:
            fenv_config: Mapping holding the keys listed in the class
                docstring.
        """
        self.fenv_config = fenv_config
        self.env_name = fenv_config['env_name']
        # Observation-normalisation kwargs; populated by make_dummy().
        # None means "not computed yet" (see _ensure_norm_stats).
        self.obs_norm_sd = None

        # Register under the very name that make_dummy()/make_single() later
        # hand to gym.make(), so the two can never drift apart.
        self._register_my_env(env_name=self.env_name, env_config=fenv_config)

    def _register_my_env(self, env_name: str = 'SpaceLayoutGym-v0', env_config=None):
        """Register ``SpaceLayoutGym`` with gym.

        Args:
            env_name: Id under which the env is registered.
            env_config: Config passed to the env constructor at ``gym.make``
                time; an empty dict is used when omitted.
        """
        if env_config is None:
            # Fresh dict per call: a `{}` default would be shared by all calls.
            env_config = {}
        gym.envs.register(id=env_name,
                          entry_point=self._env_creator,
                          kwargs={'env_config': env_config})

    def _env_creator(self, env_config):
        """Entry point used by gym: build the raw ``SpaceLayoutGym``."""
        return SpaceLayoutGym(env_config)

    def make_dummy(self):
        """Build a throw-away env and record normalisation stats from it.

        Side effect: stores the result of ``self._get_norm_stats(env)`` in
        ``self.obs_norm_sd``.

        Returns:
            The transformed env that was used to gather the stats.
        """
        test_env = gym.make(self.env_name)
        # NOTE(review): this sets attributes on the gym env *before* it is
        # wrapped; confirm GymWrapper actually picks them up.
        test_env.device = self.fenv_config['device']
        test_env = GymWrapper(test_env)
        test_env = self._transforme_env(test_env)
        # NOTE(review): _get_norm_stats is not defined in the visible part of
        # this class; it must be provided elsewhere.
        self.obs_norm_sd = self._get_norm_stats(test_env)
        return test_env

    def make_single(self):
        """Build one transformed env through ``gym.make``.

        Normalisation stats are computed via ``make_dummy()`` first if that
        has not happened yet.

        Returns:
            The transformed env, after ``check_env_specs`` and a three-step
            rollout have been run on it.
        """
        obs_norm_sd = self._ensure_norm_stats()

        env = gym.make(self.env_name)
        env.device = self.fenv_config['device']
        env.batch_size = self.fenv_config['batch_size']
        env = GymWrapper(env)

        env = self._transforme_env(env, obs_norm_sd)
        return self._check_and_preview(env)

    def make_parallel(self):
        """Build a ``ParallelEnv`` of ``fenv_config['num_workers']`` envs.

        Each worker constructs ``SpaceLayoutGym`` directly from the config
        instead of going through ``gym.make``.  Normalisation stats are
        computed via ``make_dummy()`` first if that has not happened yet.

        Returns:
            The transformed parallel env, after ``check_env_specs`` and a
            three-step rollout have been run on it.
        """
        obs_norm_sd = self._ensure_norm_stats()

        # Close over the config only, not `self`, so the factory handed to
        # ParallelEnv does not carry the whole EnvMaker along with it.
        env_config = self.fenv_config

        def create_custom_env():
            return GymWrapper(SpaceLayoutGym(env_config))

        env = ParallelEnv(self.fenv_config['num_workers'], create_custom_env)
        env = self._transforme_env(env, obs_norm_sd)
        return self._check_and_preview(env)

    def _ensure_norm_stats(self):
        """Return ``self.obs_norm_sd``, running ``make_dummy()`` if unset."""
        if self.obs_norm_sd is None:
            self.make_dummy()
        return self.obs_norm_sd

    def _check_and_preview(self, env):
        """Validate the env specs, print a three-step rollout, return ``env``."""
        check_env_specs(env)
        rollout = env.rollout(3)
        print("rollout of three steps:", rollout)
        return env

    def _transforme_env(self, env, obs_norm_sd=None):
        """Wrap ``env`` in the standard transform stack.

        The stack is ``StepCounter`` -> ``DoubleToFloat`` on
        ``"observation"`` -> ``ObservationNorm`` on ``"observation"``, placed
        on ``fenv_config['device']``.

        Args:
            env: Env to wrap.
            obs_norm_sd: Keyword arguments for ``ObservationNorm``.  When
                ``None``, ``{"standard_normal": True}`` is used.

        Returns:
            The resulting ``TransformedEnv``.
        """
        if obs_norm_sd is None:
            obs_norm_sd = {"standard_normal": True}

        transforms = Compose(
            StepCounter(),
            DoubleToFloat(in_keys=["observation"]),
            ObservationNorm(in_keys=["observation"], **obs_norm_sd),
        )
        return TransformedEnv(env, transforms, device=self.fenv_config['device'])