Hi @vmoens,
Many thanks for your quick reply.
I tested what you suggested, but it did not work.
However, I found a workaround, shown below. It might not be the best approach, but it seems to work.
from gym_floorplan.envs.master_env import SpaceLayoutGym
class EnvMaker:
    """Register ``SpaceLayoutGym`` with gym and build TorchRL envs from it.

    ``fenv_config`` is read for the keys ``'env_name'`` and ``'device'``
    (always), ``'batch_size'`` (make_single) and ``'num_workers'``
    (make_parallel), and is passed unchanged to the ``SpaceLayoutGym``
    constructor.

    Intended call order: ``make_dummy()`` first, to compute the observation
    normalization stats, then ``make_single()`` or ``make_parallel()``, which
    reuse those stats.
    """

    def __init__(self, fenv_config):
        self.fenv_config = fenv_config
        self.env_name = fenv_config['env_name']
        # Populated by make_dummy(); required by make_single()/make_parallel().
        self.obs_norm_sd = None
        # Register under the configured name so gym.make(self.env_name) resolves
        # to this env (previously the id was hard-coded and env_name ignored).
        self._register_my_env(env_name=self.env_name, env_config=fenv_config)

    def _register_my_env(self, env_name: str = 'SpaceLayoutGym-v0', env_config=None):
        """Register ``SpaceLayoutGym`` with gym under the id ``env_name``.

        ``env_config`` is stored in the registration kwargs. Every later
        ``gym.make(env_name)`` therefore builds the env with this config.
        """
        # None default avoids sharing one mutable dict across calls.
        if env_config is None:
            env_config = {}
        gym.envs.register(id=env_name,
                          entry_point=self._env_creator,
                          kwargs={'env_config': env_config})

    def _env_creator(self, env_config):
        """Gym entry point: build a ``SpaceLayoutGym`` from ``env_config``."""
        return SpaceLayoutGym(env_config)

    def make_dummy(self):
        """Build a throwaway env and record observation-normalization stats.

        The stats are stored in ``self.obs_norm_sd`` for make_single() /
        make_parallel(). Returns the transformed env.
        """
        test_env = gym.make(self.env_name)
        test_env.device = self.fenv_config['device']
        test_env = GymWrapper(test_env)
        test_env = self._transforme_env(test_env)
        # NOTE(review): _get_norm_stats is not defined in this snippet. It is
        # assumed to initialise the ObservationNorm and return its stats as a
        # kwargs dict for ObservationNorm -- confirm.
        self.obs_norm_sd = self._get_norm_stats(test_env)
        return test_env

    def make_single(self):
        """Build one transformed env, check its specs and print a 3-step rollout.

        Raises:
            RuntimeError: if make_dummy() has not been called first.
        """
        self._require_norm_stats()

        def create_custom_env():
            env = gym.make(self.env_name)
            env.device = self.fenv_config['device']
            env.batch_size = self.fenv_config['batch_size']
            return GymWrapper(env)

        return self._transform_and_check(create_custom_env())

    def make_parallel(self):
        """Build a ``ParallelEnv`` of ``num_workers`` transformed envs.

        Checks its specs and prints a 3-step rollout.

        Raises:
            RuntimeError: if make_dummy() has not been called first.
        """
        self._require_norm_stats()

        def create_custom_env():
            return GymWrapper(SpaceLayoutGym(self.fenv_config))

        env = ParallelEnv(self.fenv_config['num_workers'], create_custom_env)
        return self._transform_and_check(env)

    def _require_norm_stats(self):
        """Raise a clear error if make_dummy() has not produced norm stats yet."""
        # getattr also covers instances created without running __init__.
        if getattr(self, 'obs_norm_sd', None) is None:
            raise RuntimeError(
                "observation normalization stats are missing; "
                "call make_dummy() before make_single()/make_parallel()"
            )

    def _transform_and_check(self, env):
        """Apply the shared transforms, validate specs and print a short rollout."""
        env = self._transforme_env(env, self.obs_norm_sd)
        check_env_specs(env)
        rollout = env.rollout(3)
        print("rollout of three steps:", rollout)
        return env

    def _transforme_env(self, env, obs_norm_sd=None):
        """Wrap ``env`` with StepCounter, DoubleToFloat and ObservationNorm.

        ``obs_norm_sd`` is a kwargs dict for ``ObservationNorm``. When it is
        None, the norm is created with ``standard_normal=True`` and no stats;
        presumably these are filled in afterwards by _get_norm_stats.
        """
        if obs_norm_sd is None:
            obs_norm_sd = {"standard_normal": True}
        env_ = TransformedEnv(
            env,
            Compose(
                StepCounter(),
                DoubleToFloat(
                    in_keys=["observation"],
                ),
                ObservationNorm(in_keys=["observation"], **obs_norm_sd),
            ),
            device=self.fenv_config['device'],
        )
        return env_