Assign new weights to a layer, PyTorch

For a toy CNN architecture:

import torch
import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self):
    # def __init__(self, beta = 1.0):
        super().__init__()
        
        # Trainable parameter for swish activation function-
        # self.beta = nn.Parameter(torch.tensor(beta, requires_grad = True))
        
        self.conv1 = nn.Conv2d(
            in_channels = 1, out_channels = 6, 
            kernel_size = 5, stride = 1,
            padding = 0, bias = False 
        )
        self.bn1 = nn.BatchNorm2d(num_features = 6)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv2d(
            in_channels = 6, out_channels = 16,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn2 = nn.BatchNorm2d(num_features = 16)
        self.fc1 = nn.Linear(
            in_features = 256, out_features = 120,
            bias = False
        )
        self.bn3 = nn.BatchNorm1d(num_features = 120)
        self.fc2 = nn.Linear(
            in_features = 120, out_features = 84,
            bias = False
        )
        self.bn4 = nn.BatchNorm1d(num_features = 84)
        self.fc3 = nn.Linear(
            in_features = 84, out_features = 10,
            bias = True
        )
        
        # self.initialize_weights()

        
    def initialize_weights(self):
        for m in self.modules():
            # print(m)
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                
                # Do not initialize bias (due to batchnorm)-
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                
            elif isinstance(m, nn.BatchNorm2d):
                # Standard initialization for batch normalization-
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
    
    
    def swish_fn(self, x):
        # Swish with a trainable beta; requires the commented-out self.beta above.
        # With beta = 1 this reduces to SiLU, which forward() uses via nn.SiLU().
        return x * torch.sigmoid(x * self.beta)

    
    def forward(self, x):
        
        x = nn.SiLU()(self.pool(self.bn1(self.conv1(x))))
        x = nn.SiLU()(self.pool(self.bn2(self.conv2(x))))
        x = x.view(-1, 256)
        x = nn.SiLU()(self.bn3(self.fc1(x)))
        x = nn.SiLU()(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x


model = LeNet5()
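
The fc1 layer's in_features = 256 assumes 28 x 28 inputs (e.g. MNIST): the spatial size goes 28 -> 24 -> 12 -> 8 -> 4 through the two conv/pool stages, so the flattened size is 16 * 4 * 4 = 256. A quick sanity check:

x = torch.randn(2, 1, 28, 28)  # batch of 2; BatchNorm1d needs > 1 sample in train mode
print(model(x).shape)
# torch.Size([2, 10])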

To access the named parameters/layers:

for layer_name, params in model.named_parameters():
    print(f"{layer_name} has {params.size()} params")
'''
conv1.weight has torch.Size([6, 1, 5, 5]) params
bn1.weight has torch.Size([6]) params
bn1.bias has torch.Size([6]) params
conv2.weight has torch.Size([16, 6, 5, 5]) params
bn2.weight has torch.Size([16]) params
bn2.bias has torch.Size([16]) params
fc1.weight has torch.Size([120, 256]) params
bn3.weight has torch.Size([120]) params
bn3.bias has torch.Size([120]) params
fc2.weight has torch.Size([84, 120]) params
bn4.weight has torch.Size([84]) params
bn4.bias has torch.Size([84]) params
fc3.weight has torch.Size([10, 84]) params
fc3.bias has torch.Size([10]) params
'''

# Or-
[x[0] for x in model.named_parameters()]
"""
['conv1.weight',
 'bn1.weight',
 'bn1.bias',
 'conv2.weight',
 'bn2.weight',
 'bn2.bias',
 'fc1.weight',
 'bn3.weight',
 'bn3.bias',
 'fc2.weight',
 'bn4.weight',
 'bn4.bias',
 'fc3.weight',
 'fc3.bias']
"""

layer_name = [x[0] for x in model.named_parameters()][3]

layer_name
# 'conv2.weight'
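
Equivalently, a name-based lookup avoids hard-coding the index, since named_parameters() yields (name, parameter) pairs:

dict(model.named_parameters())[layer_name].shape
# torch.Size([16, 6, 5, 5])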

To assign new values/weights to, say, ‘conv2.weight’, I tried:

# Assign new weights-
model.state_dict()[layer_name] = torch.randn(16, 6, 5, 5)

# Or-
with torch.no_grad():
    model.state_dict()[layer_name] = torch.randn(16, 6, 5, 5)

But neither works: checking some values before and after the assignment, I see no change-

model.state_dict()[layer_name][0, 0, :3, :3]

How can I assign new values to ‘conv2.weight’?

You can either directly manipulate the parameter in-place, or you need to load the manipulated state_dict back into the model; your assignment has no effect because you are working on a temporary object.
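
Each call to state_dict() builds a fresh OrderedDict, so rebinding one of its keys changes only that temporary dict, never the model's parameters. A quick check with the LeNet5 model from the question:

sd = model.state_dict()
print(sd is model.state_dict())  # False: every call returns a new dict object

# Rebinding the key replaces the entry in this dict only
sd[layer_name] = torch.randn(16, 6, 5, 5)
print(torch.equal(model.state_dict()[layer_name], sd[layer_name]))
# False: the model's weights are unchanged

Both working approaches, shown here on a torchvision model: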

import torchvision

model = torchvision.models.resnet18()

with torch.no_grad():
    model.conv1.weight.copy_(torch.ones_like(model.conv1.weight))

print(model.conv1.weight)
# Parameter containing:
# tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
#           [1., 1., 1.,  ..., 1., 1., 1.],
#           [1., 1., 1.,  ..., 1., 1., 1.],
#           ...,

sd = model.state_dict()
sd["conv1.weight"] = torch.ones_like(model.conv1.weight) * 2
model.load_state_dict(sd)
print(model.conv1.weight)
# Parameter containing:
# tensor([[[[2., 2., 2.,  ..., 2., 2., 2.],
#           [2., 2., 2.,  ..., 2., 2., 2.],
#           [2., 2., 2.,  ..., 2., 2., 2.],
#           ...,
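
Applied back to the LeNet5 model from the question, either approach updates ‘conv2.weight’ (a minimal sketch; the shape comes from the question's architecture):

model = LeNet5()
layer_name = 'conv2.weight'
new_weights = torch.randn(16, 6, 5, 5)

# Option 1: copy into the parameter in-place
with torch.no_grad():
    model.conv2.weight.copy_(new_weights)

# Option 2: modify a state_dict copy and load it back
sd = model.state_dict()
sd[layer_name] = new_weights
model.load_state_dict(sd)

# Verify with the same slice used above
print(model.state_dict()[layer_name][0, 0, :3, :3])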