Define network architecture based on state_dict

Let’s say I have a simple network:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5, bias=False)
        self.fc2 = nn.Linear(5, 2, bias=False)

    def forward(self, x):
        x = x.view(-1, 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Net()

With my code I obtain a state_dict whose tensors have different dimensions, for example:

fc1.weight
tensor([[ 0.1149,  0.2626, -0.0651,  0.0000, -0.0510,  0.0335,  0.2863,  0.0000,
         -0.1991,  0.0000],
        [ 0.0166, -0.1621,  0.0535, -0.2953, -0.2285, -0.1630,  0.1995,  0.1854,
         -0.1402, -0.0114]])
fc2.weight
tensor([[ 0.1775,  0.2999],
        [-0.3467, -0.2310]])

This state_dict cannot be loaded into the model because of the dimension mismatch. I cannot manually change the Net class because different executions may lead to different state_dicts.
Is there a way to modify the architecture of the network, while preserving its forward logic and overall structure (type and order of layers), so that I can load this state_dict without manually changing the class?

Thank you in advance.

What you can do is define a BaseNet which implements the forward pass:

class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

    def forward(self, x):
        x = x.view(-1, 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

Later on, you can define each concrete network from kwargs, inheriting from the class that implements the forward, like:

class Net(BaseNet):
    def __init__(self, **kwargs):
        super(Net, self).__init__()
        # each value is [in_features, out_features, bias]
        for key, values in kwargs.items():
            setattr(self, key, nn.Linear(*values))


kwargs = {'fc1': [10, 5, False], 'fc2': [5, 2, False]}
model = Net(**kwargs)

It’s just a matter of playing around with how you pass the arguments that build the network in the __init__; there are plenty of ways of doing it.
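For example, here is a minimal sketch (not from the original post) of how the kwargs could be derived directly from the pruned state_dict itself. It assumes every entry is a bias-free nn.Linear weight named like fc1.weight, and the checkpoint path pruned.pth is hypothetical:

import torch

# hypothetical checkpoint produced by the pruning procedure
state_dict = torch.load('pruned.pth')

# nn.Linear weights have shape [out_features, in_features]
kwargs = {
    name.split('.')[0]: [w.shape[1], w.shape[0], False]  # [in_features, out_features, bias]
    for name, w in state_dict.items()
}
model = Net(**kwargs)
model.load_state_dict(state_dict)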


Just curious, but why are we not able to load the “state_dict” from the same class again?

@JuanFMontesinos Thank you, I’ll give it a shot.
@ecdrid Because the “new” state_dict has different sizes: for example, the original fc1 is a [5, 10] tensor while the one in the state_dict is a [2, 10], and it would give the following error:

RuntimeError: Error(s) in loading state_dict for Net:
	size mismatch for fc1.weight: copying a param with shape torch.Size([2, 10]) from checkpoint, the shape in current model is torch.Size([5, 10]).
	size mismatch for fc2.weight: copying a param with shape torch.Size([2, 2]) from checkpoint, the shape in current model is torch.Size([2, 5]).
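A quick way to see where the mismatch comes from (a small sketch, not part of the original post) is to compare the shapes in the checkpoint with those of the original fixed-size Net from the first post; state_dict here is the pruned checkpoint:

# hypothetical illustration: checkpoint shapes vs. the original Net (fc1: [5, 10], fc2: [2, 5])
original = Net()
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape), 'vs', tuple(original.state_dict()[name].shape))
# fc1.weight (2, 10) vs (5, 10)
# fc2.weight (2, 2) vs (2, 5)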

Sorry, by copy-pasting I forgot to write that the further nets should inherit from BaseNet. I modified the pseudocode.


But that would really imply that we have different archs altogether, right? And you are attempting to load the previous state_dict into this new arch by trimming the layer weights, let’s say?
What if your new arch has more neurons than the old one?

Yes, this is indeed a (toy) example of the result of a pruning procedure that removed a series of neurons. The new network will never have more neurons, but I don’t think that matters.
The two architectures (old and pruned) differ “only” in the number of neurons; they maintain the same number, type, and order of layers.

To further explain, what I’d like to do is take these pruned tensors and use them to define a network that follows the original architecture but has different tensor shapes. I hope this clears things up.
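To make that concrete, here is a hypothetical sanity check built on the BaseNet/kwargs approach above; the layer sizes are simply read off the pruned state_dict shown in the first post:

import torch

# rebuild the pruned model and run a forward pass
pruned_kwargs = {'fc1': [10, 2, False], 'fc2': [2, 2, False]}  # sizes taken from the pruned state_dict
pruned_model = Net(**pruned_kwargs)
pruned_model.load_state_dict(state_dict)  # the pruned checkpoint from above

x = torch.randn(4, 10)   # batch of 4 samples, 10 input features
out = pruned_model(x)
print(out.shape)         # torch.Size([4, 2]) -- same forward logic, smaller layers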


Thanks; now I get what you are trying to do :slight_smile: