I have the following class. I’d like to save this class in one script and load it in another one. However, I don’t know how I can correctly initialize the attributes of a class when I load it somewhere else?!
class Dataset(nn.Module):
def __init__(
self,
observation_spec,
action_spec,
size,
):
super(Dataset, self).__init__()
self._size = size
obs_shape = list(observation_spec.shape)
obs_type = observation_spec.dtype
action_shape = list(action_spec.shape)
action_type = action_spec.dtype
self._s1 = self._zeros([size] + obs_shape, obs_type)
self._s2 = self._zeros([size] + obs_shape, obs_type)
self._a1 = self._zeros([size] + action_shape, action_type)
self._a2 = self._zeros([size] + action_shape, action_type)
self._discount = self._zeros([size], torch.float32)
self._reward = self._zeros([size], torch.float32)
self._data = Transition(
s1=self._s1, s2=self._s2, a1=self._a1, a2=self._a2,
discount=self._discount, reward=self._reward)
self._current_size = torch.autograd.Variable(torch.tensor(0), requires_grad=False)
self._current_idx = torch.autograd.Variable(torch.tensor(0), requires_grad=False)
self._capacity = torch.autograd.Variable(torch.tensor(self._size))
self._config = collections.OrderedDict(
observation_spec=observation_spec,
action_spec=action_spec,
size=size
@property
def config(self):
return self._config
@property
def data(self):
return self._data
@property
def capacity(self):
return self._size
@property
def size(self):
return self._current_size.numpy()
def _zeros(self, shape, dtype):
"""Create a variable initialized with zeros."""
return torch.autograd.Variable(torch.zeros(shape, dtype = dtype))
#save the model/class
assert data.size == data.capacity
data_ckpt_name = os.path.join(log_dir, 'data_{}.pt'.format(env_name))
torch.save([data.capacity, data.state_dict()], data_ckpt_name)
whole_data_ckpt_name = os.path.join(log_dir, 'data_{}.pth'.format(env_name))
with open( whole_data_ckpt_name, 'wb') as filehandler:
pickle.dump(data, filehandler)
when I tried to load this class based on this answer with its attribute inside another script
dm_env = gym.spec(env_name).make()
env = alf_gym_wrapper.AlfGymWrapper(dm_env)
observation_spec = env.observation_spec()
action_spec = env.action_spec()
# Prepare data.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_ckpt_name = os.path.join(data_file, 'data_{}.pt'.format(env_name))
whole_data_ckpt_name = os.path.join(data_file, 'data_{}.pth'.format(env_name))
data_size, state = torch.load(data_ckpt_name, map_location=device)
full_data = dc.Dataset(observation_spec, action_spec, data_size)
with open(whole_data_ckpt_name, 'wb') as filehandler:
full_data = pickle.load(filehandler)
full_data.load_state_dict( state)
print(f"loaded data : {full_data.size}")
I got this error message
raise proxy.with_traceback(exception.__traceback__) from None
File "/home/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
return fn(*new_args, **new_kwargs)
File "train_eval_offline.py", line 79, in train_eval_offline
full_data = pickle.load(filehandler)
io.UnsupportedOperation: read
In call to configurable 'train_eval_offline' (<function train_eval_offline at 0x2b5ceb9dc8b0>)
I am wondering how I can extract the size
attribute of the class from the saved model when I load it, in order to properly initialize this attribute?