How to copy a torch.nn.Module and assert that the copy was succefull

My code:

    ddpg_agent_actor = centralized_ddpg_agent_actor(num_actions, num_states)
    ddpg_agent_target_actor = copy.deepcopy(ddpg_agent_actor) #assert fails
    ddpg_agent_target_actor = pickle.loads(pickle.dumps(ddpg_agent_actor)) #assert fails
    assert ddpg_agent_actor == ddpg_agent_target_actor

To give some context, I am working on implementing DDPG, and I need to initialize a target actor, which should have the same parameters as the actor.

Here is the torch.nn.Module I am trying to copy.(nothing weird, just a standard torch.nn.Module)

class centralized_ddpg_agent_actor(torch.nn.Module):
    def __init__(self, action_space_size, observation_state_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(observation_state_size, 128)
        self.linear2 = torch.nn.Linear(128, 256)
        self.linear3 = torch.nn.Linear(256, action_space_size)

    def forward(self, state):
        output = torch.tanh(self.linear1(torch.tensor(state, dtype=torch.float32)))
        output = torch.tanh(self.linear2(output))
        output = torch.tanh(self.linear3(output))
        return output

The copy needs to be a deep copy (i.e. the parameters being able to be updated independently).

It is possible I am asserting wrong.

you mean you can not make a deep copy?But i have run your code and it seems that it can work well

import copy
import numpy as np
import torch
import pandas as pd
import pickle
import torch.nn as nn

class centralized_ddpg_agent_actor(torch.nn.Module):
    def __init__(self, action_space_size, observation_state_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(observation_state_size, 128)
        self.linear2 = torch.nn.Linear(128, 256)
        self.linear3 = torch.nn.Linear(256, action_space_size)

    def forward(self, state):
        output = torch.tanh(self.linear1(torch.tensor(state, dtype=torch.float32)))
        output = torch.tanh(self.linear2(output))
        output = torch.tanh(self.linear3(output))
        return output
def init_zeros(m):
    if type(m) == nn.Linear:
        nn.init.zeros_(m.weight)
        nn.init.zeros_(m.bias)

def init_ones(m):
    if type(m) == nn.Linear:
        nn.init.ones_(m.weight)
        nn.init.ones_(m.bias)


num_actions, num_states = 10,2
ddpg_agent_actor = centralized_ddpg_agent_actor(num_actions, num_states)
ddpg_agent_target_actor = copy.deepcopy(ddpg_agent_actor) #assert fails
ddpg_agent_target_actor = pickle.loads(pickle.dumps(ddpg_agent_actor)) #assert fails
print(id(ddpg_agent_target_actor))
print(id(ddpg_agent_actor))

ddpg_agent_actor.apply(init_zeros)
ddpg_agent_target_actor.apply(init_ones)
print(ddpg_agent_actor.linear1.weight.data == ddpg_agent_target_actor.linear1.weight.data) # result is false

assert ddpg_agent_actor == ddpg_agent_target_actor

When I run your code, I get this

$ py test.py
139817036878880
139817036890064
tensor([[False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False],
        [False, False]])
Traceback (most recent call last):
  File "/home/master-andreas/test.py", line 42, in <module>
    assert ddpg_agent_actor == ddpg_agent_target_actor
AssertionError

assertion fails, but torch.nn.Module ID’s are different

yeah,because ddpg_agent_actor != ddpg_agent_target_actor.I think you need to check the way to use assert if you question is why assertion is triggered but deep copy is successful.

I guess my assert statement was wrong

== and != check the IDs, not the contents (like it does in C++)