I am trying to deepcopy a model:
```python
import torch


class Model(torch.nn.Module):
    def __init__(self, ann_private, rnn_public, policy):
        super(Model, self).__init__()
        self.ann_private = ann_private  # applied to the private input
        self.rnn_public = rnn_public    # applied to the public state
        self.policy = policy            # returns mu, sigma and state values

    def forward(self, private_input, state):
        private_input = self.ann_private(private_input)
        processed_state = self.rnn_public(state)
        # Stored as an attribute on the module, not as a local variable
        self.inference_state = torch.cat([processed_state, private_input], dim=-1)
        mu, sigma, state_values = self.policy(self.inference_state)
        return mu, sigma, state_values
```
Here the policy stores attributes such as `self.mu`, `self.sigma`, etc. I suspect these are the cause: before, I used plain local variables `mu` and `sigma` and `deepcopy` worked, but now I need them stored on `self`.
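A minimal, self-contained reproduction of the suspected cause, assuming a hypothetical `Policy` (the `mu_head`/`sigma_head` layers are made up for illustration) that caches its outputs on `self`. `copy.deepcopy` succeeds before the first forward pass and raises a `RuntimeError` afterwards, because `self.mu` is then a non-leaf tensor attached to the autograd graph, and PyTorch only supports deepcopy for graph leaves. The same would apply to `self.inference_state` in `Model` above.

```python
import copy

import torch


class Policy(torch.nn.Module):
    """Hypothetical stand-in for the real policy: caches its outputs on self."""

    def __init__(self, in_features):
        super().__init__()
        self.mu_head = torch.nn.Linear(in_features, 1)
        self.sigma_head = torch.nn.Linear(in_features, 1)

    def forward(self, x):
        # These results are non-leaf tensors (they depend on trainable weights),
        # and assigning them to self keeps them in the module's __dict__.
        self.mu = self.mu_head(x)
        self.sigma = torch.nn.functional.softplus(self.sigma_head(x))
        return self.mu, self.sigma


def try_deepcopy(module):
    """Return True if copy.deepcopy succeeds, False if it raises RuntimeError."""
    try:
        copy.deepcopy(module)
        return True
    except RuntimeError:
        return False


policy = Policy(4)
print(try_deepcopy(policy))  # True: no cached tensors yet
policy(torch.randn(2, 4))
print(try_deepcopy(policy))  # False: self.mu is now a non-leaf tensor
```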
Any solution?