How is the gradient of the classifier calculated with respect to the loss?

I am indexing the model's classifier logits so that I can compute the loss on certain nodes only. I assumed that, since the loss is defined only over those nodes, only their connections (weights) in the classifier would be updated, and the gradients of the other connections would remain zero. However, when I checked the gradients before optimizer.step(), I found that the whole gradient matrix was filled with non-zero values. I can't debug my code because I am not sure where the problem stems from. Any help would be appreciated.
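The assumption itself can be checked in isolation. Below is a minimal sketch in which a plain nn.Linear stands in for the Bayesian classifier (the layer sizes, task_classes, and labels are made up for illustration, and the logits are [batch, classes] rather than [batch, classes, samples]). When the cross-entropy sees only the indexed logits, the weight rows of the other classes should receive exactly zero gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
fc = nn.Linear(8, 6)                     # hypothetical classifier: 8 features, 6 classes
features = torch.randn(4, 8)             # batch of 4 feature vectors
task_classes = torch.tensor([1, 3, 4])   # hypothetical subset of classes used in the loss
labels = torch.tensor([0, 2, 1, 0])      # targets already relabeled to 0..len(task_classes)-1
unused = torch.tensor([0, 2, 5])         # classes that never enter the loss

logits = fc(features)                                    # shape [4, 6]
loss = F.cross_entropy(logits[:, task_classes], labels)  # loss over the indexed logits only
loss.backward()

# Rows of unused classes get zero gradient; rows of task classes do not
print(fc.weight.grad[unused].abs().sum().item())
print(fc.weight.grad[task_classes].abs().sum().item())
```

If this holds for the real model too, the non-zero rows must come from some other term in the total loss.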

This is the loss code:

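import torch
import torch.nn as nn
import torch.nn.functional as F

# _accuracy(p, target, topk) is a helper defined elsewhere in the training code

class ClassificationLossVI(nn.Module):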
    def __init__(self, args, topk=3):
        super(ClassificationLossVI, self).__init__()
        self._topk = tuple(range(1, topk+1))
        self.label_trick = args.label_trick

    def forward(self, output_dict, target_dict):
        samples = 64
        prediction_mean = output_dict['prediction_mean'].unsqueeze(dim=2).expand(-1, -1, samples)
        prediction_variance = output_dict['prediction_variance'].unsqueeze(dim=2).expand(-1, -1, samples)

        target = target_dict['target1'] # this is the batch target
        #tensor([101, 150, 133,  40,  40, 133, 129,  57,  40,  40], device='cuda:0')
        target_expanded = target.unsqueeze(dim=1).expand(-1, samples) #torch.Size([10, 64])
        normal_dist = torch.distributions.normal.Normal(torch.zeros_like(prediction_mean), torch.ones_like(prediction_mean))
        
        if self.training:
            losses = {}
            normals =  normal_dist.sample()
            prediction = prediction_mean + torch.sqrt(prediction_variance) * normals  #torch.Size([10, 170, 64])
            
            # -------------------------------------------------------------------------------
            #                                 Labels trick
            # -------------------------------------------------------------------------------
            if self.label_trick is False:
                loss = F.cross_entropy(prediction, target_expanded, reduction='mean')
                kl_div = output_dict['kl_div']
                losses['total_loss'] = loss + kl_div()
            
                with torch.no_grad():
                    p = F.softmax(prediction, dim=1).mean(dim=2)
                    losses['xe'] = F.cross_entropy(prediction, target_expanded, reduction='mean')
                    acc_k = _accuracy(p, target, topk=self._topk)
                    for acc, k in zip(acc_k, self._topk):
                        losses["top%i" % k] = acc
            else:
            
                task_targets = target_dict['task_labels'][0] #shape: [10, 10]
                ordered_task_targets = torch.unique(task_targets, sorted=True)
                #tensor([ 40,  48,  51,  57,  94, 101, 109, 129, 133, 150])
                
                # Get current batch labels (and sort them for reassignment)
                labels = target.clone().detach() #tensor([101, 150, 133,  40,  40, 133, 129,  57,  40,  40], device='cuda:0')
                #unq_labels = torch.unique(labels, sorted=True) # in an ascending order - tensor([0, 1], device='cuda:0')
                
                # ---------- ToDo: the unq_targets must be relabeled according to the indexes of task_targets ------------
                for t_idx, t in enumerate(ordered_task_targets):
                    labels[labels==t] = t_idx
                '''
                # Assign new labels (0,1 ...)
                for t_idx, t in enumerate(unq_labels):
                    labels[labels == t] = t_idx
                '''      
                # expand the target here 
                labels_expanded = labels.unsqueeze(dim=1).expand(-1, samples)  #torch.Size([batch_size, 64])
  
                #loss_label_trick = F.cross_entropy(prediction[:, unq_labels, :], labels_expanded, reduction='mean')
                # should we use an ordered task_targets
                loss_label_trick = F.cross_entropy(prediction[:, ordered_task_targets, :], labels_expanded, reduction='mean')
                kl_div = output_dict['kl_div']
                losses['total_loss'] = loss_label_trick + kl_div()
                
                with torch.no_grad():
                    # ToInvestigate: should we apply the label trick for calculating the accuracy and the xe as well?
                    p = F.softmax(prediction[:, ordered_task_targets, :], dim=1).mean(dim=2)
                    losses['xe'] =  F.cross_entropy(prediction[:, ordered_task_targets, :], labels_expanded, reduction='mean')
                    acc_k = _accuracy(p, labels, topk=self._topk)
                    for acc, k in zip(acc_k, self._topk):
                        losses["top%i" % k] = acc

            return losses
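As a side note on the label trick, the loop that maps global class ids to their positions in ordered_task_targets can likely be replaced by a single torch.searchsorted call. This is a sketch under the assumption that every label in the batch belongs to the current task; the helper name relabel is hypothetical.

```python
import torch


def relabel(labels: torch.Tensor, task_targets: torch.Tensor) -> torch.Tensor:
    """Map each global class id in `labels` to its index in the sorted task classes."""
    ordered = torch.unique(task_targets, sorted=True)
    idx = torch.searchsorted(ordered, labels)  # insertion points in the sorted sequence
    # An insertion point is only a valid label if the class is actually in the task
    if not torch.equal(ordered[idx.clamp(max=len(ordered) - 1)], labels):
        raise ValueError("some labels are not in task_targets")
    return idx


target = torch.tensor([101, 150, 133, 40, 40, 133, 129, 57, 40, 40])
task = torch.tensor([40, 48, 51, 57, 94, 101, 109, 129, 133, 150])
print(relabel(target, task).tolist())  # [5, 9, 8, 0, 0, 8, 7, 3, 0, 0]
```

Unlike the in-place loop, this never overwrites labels while it is still reading them, and it fails loudly if a batch label falls outside the task.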

Did you check how the indexed subset of the output is calculated, and verify that not all weights are actually used to compute it?
For example, take a look at this simple model:

import torch
import torch.nn as nn

lin1 = nn.Linear(3, 3, bias=False)
lin2 = nn.Linear(3, 3, bias=False)

x = torch.randn(1, 3)

out = lin1(x)
out = lin2(out)

loss = out[:, 1]
loss.backward()

# subset of weights contains valid grads
print(lin2.weight.grad)
# tensor([[-0.0000,  0.0000, -0.0000],
#         [-0.0782,  0.4657, -0.2253],
#         [-0.0000,  0.0000, -0.0000]])

# all weights contain valid grads, as the entire weight matrix was used to compute out[:, 1]
print(lin1.weight.grad)
# tensor([[ 0.4778, -0.2856, -0.1076],
#         [ 0.1383, -0.0827, -0.0311],
#         [ 0.0974, -0.0582, -0.0219]])

Here you can see that the first layer gets non-zero gradients for all of its parameters, while the last layer gets them only for the row that produced out[:, 1].
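Writing out the chain rule for this toy model shows why. With $W_1$ = lin1.weight, $W_2$ = lin2.weight, $h$ = lin1(x), and nn.Linear computing $x W^\top$:

$$
h_j = \sum_k W_1[j,k]\,x_k, \qquad \mathrm{out}_i = \sum_j W_2[i,j]\,h_j, \qquad \mathcal{L} = \mathrm{out}_1
$$

$$
\frac{\partial \mathcal{L}}{\partial W_2[i,j]} = \delta_{i1}\,h_j, \qquad
\frac{\partial \mathcal{L}}{\partial W_1[j,k]} = W_2[1,j]\,x_k
$$

Only row 1 of $W_2$ receives a gradient, while every entry of $W_1$ does for generic $x$ and $W_2$. The gradient of $W_1$ is the outer product of row 1 of $W_2$ with $x$, which is why each row of the printed lin1 gradient is a scalar multiple of the others.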

Thank you for your reply.
I checked the cross_entropy loss and it behaves correctly. The culprit is the KLD term (kl_div()) in my loss function:

losses['total_loss'] = loss + kl_div()
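A minimal sketch of why the KL term fills the whole gradient matrix, using a hypothetical mean-field Gaussian layer. GaussianLinear, its mu/log_var parameters, and the standard-normal prior are assumptions for illustration, not the actual model. The KL is summed over every weight, so its gradient with respect to mu is mu itself, which is non-zero even for rows the cross-entropy never touches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianLinear(nn.Module):
    """Hypothetical mean-field Gaussian linear layer (stand-in for the real VI classifier)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.mu = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.log_var = nn.Parameter(torch.full((out_features, in_features), -4.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reparameterized weight sample: w = mu + sigma * eps
        eps = torch.randn_like(self.mu)
        w = self.mu + torch.exp(0.5 * self.log_var) * eps
        return x @ w.t()

    def kl_div(self) -> torch.Tensor:
        # KL(N(mu, sigma^2) || N(0, 1)), summed over every weight in the layer
        return 0.5 * (self.log_var.exp() + self.mu.pow(2) - 1.0 - self.log_var).sum()


torch.manual_seed(0)
layer = GaussianLinear(8, 6)
x = torch.randn(4, 8)
task_classes = torch.tensor([1, 3, 4])  # hypothetical classes used in the cross-entropy
labels = torch.tensor([0, 2, 1, 0])
unused = torch.tensor([0, 2, 5])

# Cross-entropy alone: the unused rows of mu get zero gradient
F.cross_entropy(layer(x)[:, task_classes], labels).backward()
ce_only = layer.mu.grad[unused].abs().sum().item()

# Cross-entropy + KL: the KL term reaches every row, including the unused ones
layer.zero_grad()
(F.cross_entropy(layer(x)[:, task_classes], labels) + layer.kl_div()).backward()
with_kl = layer.mu.grad[unused].abs().sum().item()

print(ce_only, with_kl)
```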

I am training a Bayesian neural network with a prior and a variational posterior. The KLD term measures the divergence between them and acts as a regularizer; it depends on the parameters of every layer, including the classifier. Is it good practice to zero out the gradients of the connections that I discarded in the cross_entropy call, after loss.backward() and before opt.step()? If so, what is the best way to do it? My optimizer uses momentum, and I read in this post that zeroing the gradients may not stop the weights from being updated, and that the best approach is to save the original weights before opt.step() and then copy them back into the parameter afterwards.
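The momentum concern can be reproduced with a toy example. This sketch assumes a plain nn.Linear, SGD(momentum=0.9), a stand-in loss, and made-up frozen_rows. It compares zeroing the gradient alone with the save-and-restore approach mentioned above. The first step plays the role of an earlier task in which every row was updated, so the momentum buffer already holds values for the rows that are later frozen.

```python
import torch
import torch.nn as nn


def frozen_rows_drift(restore: bool) -> float:
    """Train a toy classifier for two steps; return how far the 'frozen' rows moved in step 2."""
    torch.manual_seed(0)
    fc = nn.Linear(4, 5)  # hypothetical classifier: 4 features, 5 classes
    opt = torch.optim.SGD(fc.parameters(), lr=0.1, momentum=0.9)
    frozen_rows = torch.tensor([0, 2, 4])  # hypothetical classes to keep fixed
    x = torch.randn(8, 4)

    def backward() -> None:
        opt.zero_grad()
        fc(x).pow(2).mean().backward()  # stand-in loss that reaches every row

    # Step 1 (e.g. an earlier task): all rows are updated and the momentum buffer fills up
    backward()
    opt.step()

    # Step 2: try to freeze some rows by zeroing their gradients
    backward()
    fc.weight.grad[frozen_rows] = 0.0
    saved = fc.weight[frozen_rows].detach().clone()  # save the rows before the step
    opt.step()  # the momentum buffer still moves the frozen rows
    if restore:
        with torch.no_grad():
            fc.weight[frozen_rows] = saved  # copy the saved rows back into the parameter
    return (fc.weight[frozen_rows].detach() - saved).abs().max().item()


print(frozen_rows_drift(restore=False))  # > 0: zeroing the gradient is not enough
print(frozen_rows_drift(restore=True))   # 0.0: the rows are unchanged
```

The same effect appears with weight_decay (SGD adds it from the parameter itself inside step()) and with adaptive optimizers such as Adam, whose running averages keep moving a parameter after its gradient has been zeroed. Restoring after the step sidesteps all of these.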

I should add that I have already tried zeroing the gradients of those weights, but it led to very poor validation accuracy and loss curves, so I don't think this is the right approach. Any guidance is much appreciated.

This is the code that shows my attempt:

        if self._args.freeze_classifier_weights is True:
                
            unq_task_targets = torch.unique(example_dict['task_labels'][0], sorted=True)
            # tensor([ 40,  48,  51,  57,  94, 101, 109, 129, 133, 150])

            #unq_task_targets = torch.unique(example_dict['target1'], sorted=True)
            row_indexes = list(set(self.classes) - set(unq_task_targets.tolist()))
            row_indexes_tensor = torch.tensor(row_indexes)
        
            for i, param in enumerate(self._model_and_loss._model.resnet.fc.parameters()):
                # param.shape? There are 3 tensors in the parameter list: [torch.Size([512]), torch.Size([170, 512]), torch.Size([10])]
                # I think the first one is the multiplicative noise, the middle one is the weight matrix and the last one is the bias
                if i == 1:  # index 1 is the weight matrix (we want the weights, not the biases)
                    param.grad[row_indexes_tensor,:] = torch.zeros_like(param.grad[row_indexes_tensor,:])
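The same masking can be attached once to the weight with Tensor.register_hook instead of editing param.grad after every backward call. This is a sketch, not a fix for the momentum problem: the hook only changes the gradient, so it still needs the restore step shown for the momentum case. It assumes self.classes is range(170), PyTorch 1.10 or newer (for torch.isin), and a plain nn.Linear standing in for resnet.fc. The helper names frozen_row_mask and mask_weight_grad are hypothetical.

```python
import torch
import torch.nn as nn


def frozen_row_mask(num_classes: int, task_classes: torch.Tensor) -> torch.Tensor:
    """True for every classifier row (class) that is NOT part of the current task."""
    all_classes = torch.arange(num_classes, device=task_classes.device)
    return ~torch.isin(all_classes, task_classes)


def mask_weight_grad(weight: nn.Parameter, frozen: torch.Tensor):
    """Register a hook that zeros the gradient of the frozen rows on every backward()."""
    keep = (~frozen).to(device=weight.device, dtype=weight.dtype).unsqueeze(1)  # [num_classes, 1]
    return weight.register_hook(lambda grad: grad * keep)


fc = nn.Linear(512, 170)  # stand-in for resnet.fc (the real layer is a VI layer)
task_classes = torch.tensor([40, 48, 51, 57, 94, 101, 109, 129, 133, 150])
frozen = frozen_row_mask(fc.out_features, task_classes)
handle = mask_weight_grad(fc.weight, frozen)

loss = fc(torch.randn(10, 512)).pow(2).mean()  # stand-in loss that touches every row
loss.backward()
handle.remove()  # remove the hook when the task (and therefore the mask) changes
```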