Multiplication matrix weights update issue

I am trying to define a matrix that is multiplied with a vector, where the weights of this matrix should be learned during training. The issue is that the weights of this matrix stay unchanged during training. Here is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

class MyNetwork(nn.Module):
    def __init__(self, device):
        super(MyNetwork, self).__init__()
        self.model = torchvision.models.densenet161(pretrained=True)
        self.model.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(2208, 256),
                nn.ReLU(),
                nn.Dropout(0.5)
        )
        self.fc1 = nn.Linear(256, 23)
        self.fc2 = nn.Linear(23, 10)
        self.drop = nn.Dropout(0.5)

        self.w = nn.Parameter(torch.randn(10, 23).to(device), requires_grad=True)

    def forward(self, input1):
        output1 = self.fc1(F.relu(self.model(input1)))
        feaco1 = self.fc2(F.relu(self.drop(output1)))
        feaco2 = torch.nn.functional.one_hot(torch.argmax(feaco1, dim=1), num_classes=10).float()
        shap1 = feaco2.matmul(self.w)
        output1 = torch.add(shap1, output1)

        return output1, shap1
----------------------------------------------------------------------------
device = torch.device("cuda:2") 
model = MyNetwork(device).to(device) 
optimizer = optim.Adam(model.parameters(), lr = lr)

optimizer.zero_grad()
loss.backward()
optimizer.step() 

Did I miss anything here?

Can you share how you are computing the loss (e.g., what loss function is being used and what the inputs are) here?

Hi, thanks for replying.
Actually, my network is different from what I posted; it is a Siamese architecture and the loss function has 4 parts.
The loss was working correctly before I added the multiplication matrix, whose weights are not updated during training.

loss = loss1 + loss2 + loss3 + 0.3 * loss4

loss2 is calculated from the shap1 output using a focal loss function.

Is there a reason you cannot use another linear layer (self.matmul = nn.Linear(10, 23)) rather than defining the parameters manually?
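For what it’s worth, a bias-free nn.Linear(10, 23) computes the same product as feaco2.matmul(self.w) when self.w has shape (10, 23), just with the weight stored transposed. A minimal sketch with made-up data (not your model):

import torch
import torch.nn as nn

feaco2 = torch.randn(4, 10)                      # stand-in for a (batch, 10) activation

w = nn.Parameter(torch.randn(10, 23))            # manual parameter, as in the original post
out_manual = feaco2.matmul(w)                    # (batch, 23)

fc3 = nn.Linear(10, 23, bias=False)              # equivalent layer; its weight is stored as (23, 10)
with torch.no_grad():
    fc3.weight.copy_(w.t())                      # copy the same values, transposed
out_layer = fc3(feaco2)                          # (batch, 23)

print(torch.allclose(out_manual, out_layer, atol=1e-6))  # True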

Actually, I tried to replace it with a conv layer like this:

self.w = nn.Conv1d(in_channels=10, out_channels=23, kernel_size=1, stride=1)

But the weights are still unchanged.

Can you post what the actual forward function looks like here?

I still see shap1 = feaco2.matmul(self.w) rather than a layer being used. What happens when you replace the feaco2.matmul(self.w) with something like self.fc3(feaco2)?

However, if you are referring to the rest of the model weights not changing, I think this is because torch.nn.functional.one_hot(torch.argmax(feaco1, dim=1), num_classes=10).float() is not differentiable. You might want to consider replacing this with a softmax instead, which is differentiable.
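Roughly, the swap could look like this (a sketch with random logits; feaco1 here just stands in for the fc2 output):

import torch
import torch.nn.functional as F

feaco1 = torch.randn(4, 10, requires_grad=True)   # stand-in logits, e.g. the fc2 output

# hard one-hot of the argmax: detached from the graph, so no gradients flow back through it
hard = F.one_hot(torch.argmax(feaco1, dim=1), num_classes=10).float()
print(hard.requires_grad)   # False

# soft assignment: stays in the graph, gradients reach feaco1 (and everything before it)
soft = F.softmax(feaco1, dim=1)
print(soft.requires_grad)   # True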

Hi, thanks for your reply.
I got the same result when I used w as a conv layer:

self.w = nn.Conv1d(in_channels=10, out_channels=23, kernel_size=1, stride=1)

I will use the output features instead of the one-hot vector and update you.
Best regards

I couldn’t reproduce the issue on my setup. Note that you cannot use ll to compare the weights after each iteration, since the underlying storage does not change and every element appended to the list points to the same memory. The changes to the weights can also be very small, depending on the learning rate.
Code I used:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

ll = list()

class SiameseNetwork(nn.Module):
    def __init__(self, device):
        super(SiameseNetwork, self).__init__()
        self.model = torchvision.models.densenet161(pretrained=True)
        self.model.classifier = nn.Linear(2208, 23)
        #self.name = './ModelsCh/Densent161_focal_best.pt'
        #self.model.load_state_dict(torch.load(self.name))
        self.model.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(2208, 256),
                nn.ReLU(),
                nn.Dropout(0.5)
        )
        self.fc1 = nn.Linear(256, 23)
        self.fc2 = nn.Linear(46, 10)

        self.ptsigmoid = nn.Sigmoid()
        self.drop = nn.Dropout(0.5)
        self.w = nn.Linear(10, 23)


    def forward_once(self, x):
        output = self.model(x)
        output = self.fc1(F.relu(output))
        if self.training:
            output2 = 1
        else:
            output2 = self.fc2(F.relu(torch.cat((output, output), dim=1)))
            output = torch.add(self.w(output2), output)
        return output, output2

    def forward(self, input1, input2):
        output1, out1 = self.forward_once(input1)
        output2, out2 = self.forward_once(input2)
        feaco =  torch.cat((output1, output2), dim=1)
        feaco1 = self.fc2(F.relu(self.drop(feaco)))

        shap1 = self.w(feaco1)
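        # note: this appends the same tensor object every iteration (no copy), so all list entries alias one storage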
        ll.append(self.w.weight.data)
        print(torch.sum(self.w.weight))
        output1 = torch.add(shap1, output1)
        output2 =  torch.add(shap1, output2)
        return output1, output2, feaco1

net = SiameseNetwork('cuda')
net = net.cuda()
in1 = torch.randn(1, 3, 224, 224, device='cuda')
in2 = torch.randn(1, 3, 224, 224, device='cuda')
target = torch.tensor([3], device='cuda', dtype=torch.long)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), 1e-1,
                momentum=0.9,
                weight_decay=1e-4)
for i in range(0, 10):
    _, out, _ = net(in1, in2)
    loss = criterion(out, target)
    optimizer.zero_grad()
    loss.backward()
    print(loss.item())
    optimizer.step()

print([data.storage().data_ptr() for data in ll])

Output:

tensor(1.6240, device='cuda:0', grad_fn=<SumBackward0>)
3.283719062805176
tensor(1.6240, device='cuda:0', grad_fn=<SumBackward0>)
1.6470584869384766
tensor(1.6240, device='cuda:0', grad_fn=<SumBackward0>)
0.0
tensor(1.6239, device='cuda:0', grad_fn=<SumBackward0>)
0.0
tensor(1.6239, device='cuda:0', grad_fn=<SumBackward0>)
0.0
tensor(1.6238, device='cuda:0', grad_fn=<SumBackward0>)
0.0
tensor(1.6237, device='cuda:0', grad_fn=<SumBackward0>)
0.0
tensor(1.6236, device='cuda:0', grad_fn=<SumBackward0>)
0.0
tensor(1.6235, device='cuda:0', grad_fn=<SumBackward0>)
0.0
tensor(1.6234, device='cuda:0', grad_fn=<SumBackward0>)
0.0
[139980355652608, 139980355652608, 139980355652608, 139980355652608, 139980355652608, 139980355652608, 139980355652608, 139980355652608, 139980355652608, 139980355652608]
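If you do want to compare the weight values across iterations, you would need to snapshot them rather than append the live tensor; for example (a small sketch, referring to ll and self.w from the snippet above):

ll.append(self.w.weight.detach().clone())   # copy the current values instead of aliasing them

# after training, consecutive snapshots will differ (by a small amount, depending on the lr)
for prev, curr in zip(ll[:-1], ll[1:]):
    print((curr - prev).abs().max().item())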

Thanks

The weights are updating

Best Regards