Understanding Backpropagation In my Module

Hi, I am new to PyTorch, and sorry for this really long post. I have written a model, but it is not working as expected. I think there is a problem with my understanding of how gradients flow in PyTorch.

Basically, I have implemented clustering using deep learning. I have a module with an FC layer whose weights represent n-dimensional cluster centers. Each cluster represents a class of my data. For example, the first 10 neurons represent the first cluster’s center and the next 10 represent the second cluster’s center. The input to this FC layer is a constant 1, and its weights (and thus the cluster centers) are learned from the training data by backpropagation, using two different losses.

The input to this module is a feature vector. The module first encodes the feature vector using some FC layers, shown below:

        #DML
        self.DML = nn.Sequential(OrderedDict([
                        ('fc1', nn.Linear(self.in_Dims, 1024, bias=True)),
                        ('relu1',nn.ReLU()),
                        ('bn1',nn.BatchNorm1d(1024)),
                        ('fc2',nn.Linear(1024, 1024, bias=True)),
                        ('relu2',nn.ReLU()),
                        ('fc3',nn.Linear(1024, self.e, bias=True))
                    ]))
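A minimal sketch of the cluster-center trick described above, assuming small hypothetical sizes (n_classes=3, n_reps=2, e=4) and running on CPU. Feeding a constant 1 into nn.Linear(1, D) returns weight[:, 0] + bias, so the output is just a learnable vector, and backpropagation reaches both the weight and the bias. An nn.Parameter holding the centers directly would be an equivalent, more direct alternative; that swap is a suggestion, not part of the original module.

```python
import torch
from torch import nn

# Hypothetical sizes for illustration only
n_classes, n_reps, e = 3, 2, 4
fc_clusters = nn.Linear(1, n_classes * n_reps * e, bias=True)

# A constant input of 1 (shape (1,), like fc_const in the post)
const = torch.ones(1)
centers = fc_clusters(const)  # shape (n_classes * n_reps * e,)

# With input 1, the output is exactly weight[:, 0] + bias
print(torch.allclose(centers, fc_clusters.weight[:, 0] + fc_clusters.bias))  # True

# Any loss on the centers back-propagates into the layer's parameters
centers.pow(2).sum().backward()
print(fc_clusters.weight.grad is not None, fc_clusters.bias.grad is not None)  # True True

# A more direct alternative: store the centers as a parameter, one row per (class, rep)
centers_param = nn.Parameter(torch.randn(n_classes * n_reps, e))
print(centers_param.shape)  # torch.Size([6, 4])
```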

Then the distance of this encoded feature vector (called the embedding) from each cluster center is computed. To store these distances, I create a new tensor of zeros and then fill it. Does this break the graph? The code is as follows:

    def compute_Dij(self, embeddings, clusters):
        clusters_ = clusters.view(self.n_classes, -1).repeat(embeddings.shape[0],1)
        embeddings_ = embeddings.repeat(1,self.n_reps).repeat(1,self.n_classes).view(-1,embeddings.shape[1]*self.n_reps)

        #sanity check
        assert clusters_.shape == embeddings_.shape, 'embeddings and clusters shape do not match'
        assert (embeddings_[0] == embeddings_[self.n_classes-1]).all(), 'embeddings not reshaped properly'
        assert (clusters_[0] == clusters_[self.n_classes]).all(), 'clusters not reshaped properly'

        squared_diff = (clusters_ - embeddings_)**2
        
        Dij = torch.zeros(embeddings_.shape[0],self.n_reps)
        for idx,chunk in enumerate(squared_diff.chunk(self.n_reps,1)):
            Dij[:,idx] = torch.sqrt(torch.sum(chunk, dim=1))

        return Dij
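A sketch that checks the question above on CPU with small, hypothetical sizes. compute_Dij_fill mirrors compute_Dij as a standalone function, and compute_Dij_cdist is a hypothetical vectorized alternative based on torch.cdist that produces the same (B * n_classes, n_reps) layout. In-place assignment into a preallocated tensor is recorded by autograd, so filling a zeros tensor does not by itself break the graph; creating the zeros tensor on the same device and dtype as the inputs avoids an extra CPU/GPU copy.

```python
import torch

def compute_Dij_fill(embeddings, clusters, n_classes, n_reps):
    """Loop-and-fill version mirroring compute_Dij from the post."""
    B, e = embeddings.shape
    clusters_ = clusters.view(n_classes, -1).repeat(B, 1)
    embeddings_ = embeddings.repeat(1, n_reps).repeat(1, n_classes).view(-1, e * n_reps)
    squared_diff = (clusters_ - embeddings_) ** 2
    # Allocate on the same device/dtype as the inputs
    Dij = torch.zeros(B * n_classes, n_reps, device=embeddings.device, dtype=embeddings.dtype)
    for idx, chunk in enumerate(squared_diff.chunk(n_reps, 1)):
        Dij[:, idx] = torch.sqrt(chunk.sum(dim=1))  # in-place write is tracked by autograd
    return Dij

def compute_Dij_cdist(embeddings, clusters, n_classes, n_reps):
    """Vectorized version: same (B * n_classes, n_reps) layout, no Python loop."""
    B, e = embeddings.shape
    centers = clusters.view(n_classes * n_reps, e)   # row c * n_reps + r
    D = torch.cdist(embeddings, centers)            # (B, n_classes * n_reps)
    return D.view(B, n_classes, n_reps).reshape(B * n_classes, n_reps)

torch.manual_seed(0)
n_classes, n_reps, e, B = 3, 2, 4, 5
emb = torch.randn(B, e, requires_grad=True)
clusters = torch.randn(n_classes * n_reps * e, requires_grad=True)

D1 = compute_Dij_fill(emb, clusters, n_classes, n_reps)
D2 = compute_Dij_cdist(emb, clusters, n_classes, n_reps)
print(D1.grad_fn)                          # a CopySlices node: the graph is intact
print(torch.allclose(D1, D2, atol=1e-5))   # True

D1.sum().backward()
print(emb.grad is not None, clusters.grad is not None)  # True True
```

The vectorized version avoids the repeat/chunk bookkeeping; both versions give row b * n_classes + c the distances from embedding b to the n_reps centers of class c.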

These distances are then used to calculate the probabilities of each sample belonging to each cluster. The code is as follows:

    def compute_posterioirs(self, Dij):
        Pij = torch.exp(-(Dij)/(2*self.sigma**2))
        return Pij
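A sketch of the following steps, assuming the (B * n_classes, n_reps) layout from compute_Dij and mirroring compute_posterioirs and compute_classProbs from the full code, with random stand-in distances. It checks that gradients reach Dij through torch.exp, torch.max and nn.CrossEntropyLoss. The values returned by torch.max are differentiable (the gradient goes to the selected element), while the returned indices are not. Also keep in mind that nn.CrossEntropyLoss applies log_softmax to its input, so it treats class_probs as unnormalized scores, not as probabilities.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_classes, n_reps, B, sigma = 3, 2, 5, 0.5

# Stand-in for the distances returned by compute_Dij (hypothetical random values)
Dij = torch.rand(B * n_classes, n_reps, requires_grad=True)

# compute_posterioirs: exp(-D / (2 * sigma^2))
Pij = torch.exp(-Dij / (2 * sigma ** 2))

# compute_classProbs: best representative per class, then one row per sample
class_probs = torch.max(Pij, 1)[0].view(-1, n_classes)   # (B, n_classes)

gt = torch.randint(0, n_classes, (B,))
loss = nn.CrossEntropyLoss()(class_probs, gt)
loss.backward()

print(class_probs.shape)             # torch.Size([5, 3])
# Only the selected (max) entry of each row of Dij receives a gradient
print((Dij.grad != 0).sum(dim=1))    # one nonzero entry per row
```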

These probabilities are then used to calculate a cross-entropy loss. I have also used a second loss, a triplet loss. In my triplet loss function, I create a new zero tensor and accumulate into it values computed from the distances above. Will this loss keep track of the operations? The code is as follows:

    def compute_tripletLoss(self, Dij, gt):
        Dij_classWise = torch.min(Dij, 1)[0].view(-1, self.n_classes)
        Dij_sorted, indices = torch.sort(Dij_classWise, 1)
        indices = indices.clone().detach().cpu().numpy()
        indices_gt = gt.clone().detach().cpu().numpy()

        min_untrueClass_indices = [x[0] if x[0]!=gt[idx] else x[1] for idx, x in enumerate(indices)]
        
        loss = torch.tensor(0, dtype=torch.float32).to('cuda')
        for i in range(gt.shape[0]):
            l = torch.tensor(Dij_classWise[i,indices_gt[i]] - Dij_classWise[i,min_untrueClass_indices[i]] + self.alpha)
            loss += nn.functional.relu(l)

        loss = loss/gt.shape[0]
        return loss
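In the triplet loss above, torch.tensor(...) is called on an existing tensor, which copies its data into a new tensor with no autograd history (PyTorch emits a UserWarning here and recommends clone().detach() for copies). As a result, the returned loss does not depend on Dij in the graph. The sketch below, on CPU with random stand-in distances, mirrors the original loop and compares it with a hypothetical vectorized version (triplet_loss_vectorized) that stays in the graph; it picks the closest wrong class with a masked min instead of sorting, which gives the same value.

```python
import torch
import torch.nn.functional as F

def triplet_loss_loop(Dij, gt, n_classes, alpha):
    """Mirrors compute_tripletLoss from the post, including the torch.tensor copy."""
    Dij_cw = torch.min(Dij, 1)[0].view(-1, n_classes)
    _, indices = torch.sort(Dij_cw, 1)
    wrong = [int(x[0]) if x[0] != gt[i] else int(x[1]) for i, x in enumerate(indices)]
    loss = torch.tensor(0.0)
    for i in range(gt.shape[0]):
        # torch.tensor copies the value and drops its history (warns at runtime)
        l = torch.tensor(Dij_cw[i, gt[i]] - Dij_cw[i, wrong[i]] + alpha)
        loss += F.relu(l)
    return loss / gt.shape[0]

def triplet_loss_vectorized(Dij, gt, n_classes, alpha):
    """Hypothetical graph-preserving version: mean of relu(d_true - d_closest_wrong + alpha)."""
    Dij_cw = torch.min(Dij, 1)[0].view(-1, n_classes)       # (B, n_classes)
    d_true = Dij_cw.gather(1, gt.view(-1, 1)).squeeze(1)    # distance to the true class
    # Mask the true class with +inf so that min picks the closest wrong class
    mask = F.one_hot(gt, n_classes).bool()
    d_wrong = Dij_cw.masked_fill(mask, float("inf")).min(dim=1).values
    return F.relu(d_true - d_wrong + alpha).mean()

torch.manual_seed(0)
n_classes, n_reps, B, alpha = 3, 2, 4, 1.0
Dij = torch.rand(B * n_classes, n_reps, requires_grad=True)
gt = torch.randint(0, n_classes, (B,))

old = triplet_loss_loop(Dij, gt, n_classes, alpha)
new = triplet_loss_vectorized(Dij, gt, n_classes, alpha)
print(old.requires_grad, new.requires_grad)   # False True
print(torch.allclose(old, new))               # True: same value, but only `new` is in the graph
new.backward()                                # gradients now reach Dij
```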

The overall code is here:

import torch
from torch import nn
from torch.autograd import Function, Variable
import torch.nn.functional as F

from collections import OrderedDict

from sklearn.metrics import confusion_matrix
import numpy as np

class cluster_module(nn.Module):
    """ Clustering Module """
    def __init__(self, n_classes, n_reps, ndims_E):
        super(cluster_module, self).__init__()
        self.n_classes = n_classes
        self.n_reps = n_reps
        self.in_Dims = ndims_E
        self.e = 256

        #DML
        self.DML = nn.Sequential(OrderedDict([
                        ('fc1', nn.Linear(self.in_Dims, 1024, bias=True)),
                        ('relu1',nn.ReLU()),
                        ('bn1',nn.BatchNorm1d(1024)),
                        ('fc2',nn.Linear(1024, 1024, bias=True)),
                        ('relu2',nn.ReLU()),
                        ('fc3',nn.Linear(1024, self.e, bias=True))
                    ]))

        #clustering
        self.fc_const = torch.tensor([1], dtype=torch.float32).cuda()
        self.fc_outDims = self.n_classes * self.n_reps * 256
        self.fc_clusters = nn.Linear(1, self.fc_outDims, bias=True)

        #losses
        self.loss_cls = nn.CrossEntropyLoss()

        #Instance Adaptation

        #other parameters
        self.sigma = 0.5
        self.alpha = 1.0

    def forward(self, pooled_feat, labels_gt,
                    sdt_pooled_feat, sdt_labels_gt, 
                            tgt_pooled_feat, training_stage, training):
        # self.training = training
        self.training_stage = training_stage

        #feed constant 1 to clusters FC layer
        fc_out = self.fc_clusters(self.fc_const).cuda()

        #ship to cuda
        labels_gt = labels_gt.cuda()
        sdt_labels_gt = sdt_labels_gt.cuda()

        #get embeddings
        dml_out = self.DML(pooled_feat).cuda()
        sdt_dml_out = self.DML(sdt_pooled_feat).cuda()
        
        #get distances matrix
        Dij = self.compute_Dij(dml_out, fc_out).cuda()
        sdt_Dij = self.compute_Dij(sdt_dml_out, fc_out).cuda()

        #get posteriors matrix
        posteriors = self.compute_posterioirs(Dij).cuda()
        sdt_posteriors = self.compute_posterioirs(sdt_Dij).cuda()

        # labels_prd = self.compute_labels(posteriors).cuda()
        
        #per class probability
        class_probs = self.compute_classProbs(posteriors)
        sdt_class_probs = self.compute_classProbs(sdt_posteriors)

        #calculate classification and triplet losses
        if training:
            loss_cls = self.compute_loss_cls(class_probs, labels_gt)
            sdt_loss_cls = self.compute_loss_cls(sdt_class_probs, sdt_labels_gt)

            loss_triplet = self.compute_tripletLoss(Dij, labels_gt)
            sdt_loss_triplet = self.compute_tripletLoss(sdt_Dij, sdt_labels_gt)
        else:
            loss_cls = None
            loss_triplet = None
            sdt_loss_cls = None
            sdt_loss_triplet = None
        
        #compute confusion matrices
        conf_mat = self.compute_confMat(labels_gt, class_probs)
        sdt_conf_mat = self.compute_confMat(sdt_labels_gt, sdt_class_probs)

        #-------------------
        # For target domain
        #-------------------
        if self.training_stage == 'post':
            tgt_dml_out = self.DML(tgt_pooled_feat).cuda()
            tgt_Dij = self.compute_Dij(tgt_dml_out, fc_out).cuda()
            tgt_posteriors = self.compute_posterioirs(tgt_Dij).cuda()
            tgt_class_probs = self.compute_classProbs(tgt_posteriors)
            tgt_labels_pred = self.compute_labels(tgt_class_probs)
        else:
            tgt_labels_pred = None
            tgt_dml_out = None
            tgt_class_probs = None

        #----------------------------------------------------
        # Returns: 
        #   1. source losses
        #   2. domain transferred source losses
        #   3. DML outputs for instance domain adaptation
        #   4. predicted class probs for consistency loss
        #   5. target labels for pseudo labeling
        #   6. confusion matrix
        #---------------------------------------------------
        return loss_cls, loss_triplet, \
                    sdt_loss_cls, sdt_loss_triplet, \
                        dml_out, sdt_dml_out, tgt_dml_out, \
                            class_probs, sdt_class_probs, tgt_class_probs, \
                                tgt_labels_pred, conf_mat+sdt_conf_mat
    
    def compute_Dij(self, embeddings, clusters):
        clusters_ = clusters.view(self.n_classes, -1).repeat(embeddings.shape[0],1)
        embeddings_ = embeddings.repeat(1,self.n_reps).repeat(1,self.n_classes).view(-1,embeddings.shape[1]*self.n_reps)

        #sanity check
        assert clusters_.shape == embeddings_.shape, 'embeddings and clusters shape do not match'
        assert (embeddings_[0] == embeddings_[self.n_classes-1]).all(), 'embeddings not reshaped properly'
        assert (clusters_[0] == clusters_[self.n_classes]).all(), 'clusters not reshaped properly'

        squared_diff = (clusters_ - embeddings_)**2
        
        Dij = torch.zeros(embeddings_.shape[0],self.n_reps)
        for idx,chunk in enumerate(squared_diff.chunk(self.n_reps,1)):
            Dij[:,idx] = torch.sqrt(torch.sum(chunk, dim=1))

        return Dij

    def compute_posterioirs(self, Dij):
        Pij = torch.exp(-(Dij)/(2*self.sigma**2))
        return Pij
    
    def compute_classProbs(self, Pij):
        class_probs = torch.max(Pij, 1)[0].view(-1, self.n_classes)

        #set background prob as suggested by paper
        # class_probs[:,0] = 1 - torch.max(class_probs[:,1:], 1)[0]
        return class_probs

    def compute_labels(self, posteriors):
        labels = torch.max(posteriors, 1)[1]
        return labels

    def compute_loss_cls(self,labels_prd,  gt):
        # gt_ = torch.unsqueeze(gt, 0).type(torch.float32)
        # gt_ = torch.unsqueeze(gt, 0)
        # labels_prd_ = torch.unsqueeze(labels_prd, 0)
        loss = self.loss_cls(labels_prd, gt)
        return loss

    def compute_tripletLoss(self, Dij, gt):
        Dij_classWise = torch.min(Dij, 1)[0].view(-1, self.n_classes)
        Dij_sorted, indices = torch.sort(Dij_classWise, 1)
        indices = indices.clone().detach().cpu().numpy()
        indices_gt = gt.clone().detach().cpu().numpy()

        min_untrueClass_indices = [x[0] if x[0]!=gt[idx] else x[1] for idx, x in enumerate(indices)]
        
        loss = torch.tensor(0, dtype=torch.float32).to('cuda')
        for i in range(gt.shape[0]):
            l = torch.tensor(Dij_classWise[i,indices_gt[i]] - Dij_classWise[i,min_untrueClass_indices[i]] + self.alpha)
            loss += nn.functional.relu(l)

        loss = loss/gt.shape[0]
        return loss

    def compute_confMat(self, labels_gt, labels_pred):
        y_prd = labels_pred.clone()
        y_prd = torch.argmax(y_prd, 1).detach().cpu().numpy()
        y_true = labels_gt.clone().detach().cpu().numpy()
        labels = np.arange(self.n_classes)
        mat = confusion_matrix(y_true, y_prd, labels=labels)
        return mat
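One way to locate where the graph breaks in a model like the one above is to run a single backward pass and list which parameters received a gradient. A minimal sketch, assuming any nn.Module and a scalar loss; report_grads is a hypothetical helper, and the tiny Toy model only stands in for cluster_module, with one branch deliberately copied out of the graph.

```python
import torch
from torch import nn

def report_grads(model):
    """Return {parameter name: gradient norm, or None if no gradient arrived}."""
    return {name: (None if p.grad is None else p.grad.norm().item())
            for name, p in model.named_parameters()}

class Toy(nn.Module):
    """Hypothetical stand-in: one branch stays in the graph, one is cut off."""
    def __init__(self):
        super().__init__()
        self.kept = nn.Linear(4, 1)
        self.cut = nn.Linear(4, 1)

    def forward(self, x):
        a = self.kept(x).sum()
        b = torch.tensor(self.cut(x).sum().item())  # value copied out of the graph
        return a + b

torch.manual_seed(0)
model = Toy()
loss = model(torch.randn(8, 4))
loss.backward()
for name, g in report_grads(model).items():
    print(name, g)   # cut.weight and cut.bias print None
```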

Can anyone have a look at it?

Hi,

Autograd keeps track of all ops in PyTorch, with the following exceptions:

  • If you detach() a Tensor, because this function means: “Give me a new Tensor without tracking gradients”
  • If you perform ops inside a with torch.no_grad() block, for the same reason as above
  • If you perform non-PyTorch ops (for example, NumPy operations on the result of .numpy())
  • If you perform non-differentiable PyTorch ops (like argmax).
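A small sketch that checks each case in the list above by inspecting requires_grad on the result. The last case, torch.tensor(...) applied to an existing tensor, is not in the list, but it behaves like detach(): it copies the data into a new tensor without history. That is the pattern used in compute_tripletLoss.

```python
import numpy as np
import torch

x = torch.randn(3, requires_grad=True)

tracked = (x * 2).sum()             # ordinary PyTorch op: recorded by autograd
detached = x.detach() * 2           # detach(): new tensor without history
with torch.no_grad():
    in_no_grad = x * 2              # ops inside no_grad() are not recorded
# Non-PyTorch op: .numpy() even requires detach() first
via_numpy = torch.from_numpy(np.exp(x.detach().numpy()))
arg = torch.argmax(x)               # integer indices: not differentiable
copied = torch.tensor(x[0] - x[1])  # copy without history (PyTorch warns here)

for name, t in [("tracked", tracked), ("detached", detached), ("no_grad", in_no_grad),
                ("numpy", via_numpy), ("argmax", arg), ("torch.tensor copy", copied)]:
    print(f"{name:18s} requires_grad={t.requires_grad}")  # only 'tracked' is True
```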