None gradients right after parameter initialization

Hi there,

I’ve seen similar posts, but they don’t address the following issue.

I’m experiencing None gradients in a custom layer I have written, which is used as part of a bigger model. I have checked the gradients of the parameters right after initialization (with PyCharm’s debugger) and also during the forward pass; in both cases .requires_grad is True, but .grad is None.
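
For reference, this is roughly the check I run, written out as a minimal sketch of what I look at in the debugger (the shapes are just the ones I use):

layer = HardSelfAttentionLayer(Q_shape=(400, 221, 3), K_shape=(200, 221, 3))
for name, param in layer.named_parameters():
    # every parameter prints requires_grad=True and grad=None
    print(name, param.requires_grad, param.grad)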

I’m new to writing custom layers, so I apologize in advance if this question is trivial.
Should more information be needed to answer, please let me know.

The layer I have written:

import numpy as np
import torch
import torch.nn.functional as F

class HardSelfAttentionLayer(torch.nn.Module):
    def __init__(self, Q_shape, K_shape):
        super().__init__()
        assert len(Q_shape) == 3, f'Unsupported Q tensor dim {len(Q_shape)}'
        assert len(K_shape) == 3, f'Unsupported K tensor dim {len(K_shape)}'
        assert Q_shape[1:] == K_shape[1:], 'Last two dimensions must match'
        self.Q_shape = Q_shape
        self.K_shape = K_shape
        self.k = self.K_shape[0]
        self.q = self.Q_shape[0]
        self.n = self.Q_shape[1] * self.Q_shape[2]

        self.w1 = torch.nn.Parameter(torch.randn(size=[self.n],
                                                 requires_grad=True))
        self.w2 = torch.nn.Parameter(torch.randn(size=[self.q],
                                                 requires_grad=True))
        self.att_th = torch.nn.Parameter(torch.tensor(0.5))
        self.sigmoid = torch.nn.Sigmoid()

    def mats2vecs(self, inp):
        return inp.transpose(1, 2).flatten(start_dim=2, end_dim=3)

    def vecs2mats(self, inp):
        return torch.stack(inp.split(self.K_shape[1], dim=2), dim=1)

    def forward(self, Q, K):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        # we first convert the 3d tensors to 2d tensors, by flattening
        Q = self.mats2vecs(Q)  # (BS, self.q, self.n)
        K = self.mats2vecs(K)  # (BS, self.k, self.n)

        # we create D, a diagonal matrix with w1 values
        D = self.w1.diag_embed()  # (self.n, self.n)

        # we multiply Q*D*K^T
        QDK = Q.matmul(D).matmul(K.transpose(1, 2))  # (BS, self.q, self.k)

        # we normalize QDK to obtain the Similarity matrix
        S = QDK.div(np.sqrt(self.n))  # (BS, self.q, self.k)

        # we apply a ReLU activation
        act_S = F.relu(S)  # (BS, self.q, self.k)

        # we calculate the Importance matrix
        I = (act_S.transpose(1, 2)).matmul(self.w2).div(np.sqrt(self.q))
        # (BS, self.k)

        # we apply the sigmoid activation and reshape
        w3 = self.sigmoid(I).reshape(-1, self.k, 1)  # (BS, self.k, 1)

        # we attend using w3: entries where w3 is below the threshold are masked
        K[torch.broadcast_tensors(w3 < self.att_th, K)[0]] = -1.0
        # (BS, self.k, self.n)

        return self.vecs2mats(K)

This is expected, since no gradients are calculated after the parameters are initialized or after only the forward pass has been executed.
The gradients are calculated in the backward() call, and you should see valid values in the .grad attributes afterwards.
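
As a minimal, self-contained illustration (using a plain nn.Linear, not your custom layer):

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
print(layer.weight.grad)    # None right after initialization

x = torch.randn(8, 4)
out = layer(x).sum()        # forward pass only: .grad is still None
print(layer.weight.grad)    # None

out.backward()              # the backward pass populates .grad
print(layer.weight.grad)    # tensor with the same shape as layer.weight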

Thanks for the reply, I was not aware of that. However, even when I check .grad during the forward pass after the first iteration, most of my parameters have grad=torch.zeros(some_size), while the parameters belonging to the custom layer above have grad=None. These parameters do not change during training, unlike the other parameters of my model.
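
Concretely, when I iterate over the model’s parameters during training I see something like this (rough sketch):

for name, param in self.named_parameters():
    print(name, None if param.grad is None else param.grad.shape)
# the parameters of the attention layer print None,
# all other parameters print a valid shape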

I’m using PyTorch Lightning, so I implemented a forward function, attached here:

    def forward(self, image, reads_metadata, metadata, maternal_labels,
                paternal_labels):
        cfDNA_reads, parental_reads = torch.split(image, 3, dim=1)

        attention_output = self.soft_self_attention_layer(
            K=cfDNA_reads,
            Q=torch.cat([cfDNA_reads, parental_reads], dim=2)
        )

        cnn_input = torch.cat([attention_output, parental_reads], dim=1)

        x1 = self.cnn(cnn_input)

        x2 = metadata

        bayesian_logits = self.get_bayesian_predictions(
            maternal_labels=maternal_labels,
            paternal_labels=paternal_labels,
            reads_metadata=reads_metadata,
            cffdna_ratio_mean=\
            metadata[:, POS_META_FEATURES.index(CFFDNA_RATIO_MEAN)]
        )

        x = torch.cat((x1, x2, bayesian_logits.cuda()), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x, bayesian_logits

The zeros would come from zeroing out the gradients via optimizer.zero_grad(). Since your custom module still has None gradients in its parameters, I guess you might have broken the computation graph somehow.
Could you post an executable code snippet using the custom modules with random input tensors, which would show this behavior?
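
In the meantime, a quick check inside your forward would be something like this (sketch, using the names from your snippet):

attention_output = self.soft_self_attention_layer(
    K=cfDNA_reads,
    Q=torch.cat([cfDNA_reads, parental_reads], dim=2)
)
print(attention_output.requires_grad, attention_output.grad_fn)
# if this prints "False None", the output is detached from the layer's
# parameters and no gradients can flow back into them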

Thanks for the offer!
I made it as small and simple as I could, though it still ended up being ~200 lines of code.
The first ~80 lines are the custom layer, followed by ~80 lines for the Lightning module, and the rest is just the random dataloader and the main training loop via PyTorch Lightning.

The lines to debug are 115-127 of the snippet, where the forward function is called. I’ve verified that the effect described above takes place there.

################################################################################
### CUSTOM LAYER ###############################################################
################################################################################
import torch
import numpy as np
import torch.nn.functional as F
N = [200, 221, 14]
split_channel = 7

class HardSelfAttentionLayer(torch.nn.Module):
    def __init__(self, Q_shape, K_shape):
        """
        In the constructor we instantiate three parameters and assign them as
        member parameters.
        """
        super().__init__()
        assert len(Q_shape) == 3, f'Unsupported Q tensor dim {len(Q_shape)}'
        assert len(K_shape) == 3, f'Unsupported K tensor dim {len(K_shape)}'
        assert Q_shape[1:] == K_shape[1:], 'Last two dimensions must match'
        self.Q_shape = Q_shape
        self.K_shape = K_shape
        self.k = self.K_shape[0]
        self.q = self.Q_shape[0]
        self.n = self.Q_shape[1] * self.Q_shape[2]

        self.w1 = torch.nn.Parameter(torch.randn(size=[self.n],
                                                 requires_grad=True))
        self.w2 = torch.nn.Parameter(torch.randn(size=[self.q],
                                                 requires_grad=True))
        self.att_th = torch.nn.Parameter(torch.tensor(0.5))

        self.sigmoid = torch.nn.Sigmoid()

    def mats2vecs(self, inp):
        return inp.transpose(1, 2).flatten(start_dim=2, end_dim=3)

    def vecs2mats(self, inp):
        return torch.stack(inp.split(self.K_shape[1], dim=2), dim=1)

    def forward(self, Q, K):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        # we first convert the 3d tensors to 2d tensors, by flattening
        Q = self.mats2vecs(Q)  # (BS, self.q, self.n)
        K = self.mats2vecs(K)  # (BS, self.k, self.n)

        # we create D, a diagonal matrix with w1 values
        D = self.w1.diag_embed()  # (self.n, self.n)

        # we multiply Q*D*K^T
        QDK = Q.matmul(D).matmul(K.transpose(1, 2))  # (BS, self.q, self.k)

        # we calculate the Queries norms
        Q_norms = Q.norm(dim=2).reshape(-1, self.q, 1)  # (BS, self.q, 1)

        # we normalize QDK to obtain the Similarity matrix
        S = QDK.div(np.sqrt(self.n))  # (BS, self.q, self.k)

        act_S = F.relu(S)  # (BS, self.q, self.k)

        # we calculate the Importance matrix
        I = (act_S.transpose(1, 2)).matmul(self.w2).div(np.sqrt(self.q))
        # (BS, self.k)

        w3 = self.sigmoid(I).reshape(-1, self.k, 1)  # (BS, self.k, 1)

        # we attend using w3: entries where w3 is below the threshold are masked
        K[torch.broadcast_tensors(w3 < self.att_th, K)[0]] = -1.0
        # (BS, self.k, self.n)

        return self.vecs2mats(K)


################################################################################
### PYTORCH LIGHTNING MODULE ###################################################
################################################################################

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch


class HoobariLightningModule(pl.LightningModule):
    """ A class that implements the pytorch_lightning.LightningModule
    interface, for the purpose of training Deep Hoobari models. """

    def __init__(self):
        super(HoobariLightningModule, self).__init__()

        # a loss function
        self.loss_func = nn.CrossEntropyLoss()

        fc1_input_size = 265200
        fc1_outputs_size = 64
        self.fc1 = nn.Linear(fc1_input_size, fc1_outputs_size)
        self.fc2 = nn.Linear(fc1_outputs_size, 3)

        # we define a hard self attention layer
        self.hard_self_attention_layer = HardSelfAttentionLayer(
            Q_shape=(400, 221, 3), K_shape=(200, 221, 3))


    def forward(self, image):
        cfDNA_reads, parental_reads = torch.split(image, 3, dim=1)

        attention_output = self.hard_self_attention_layer(
            K=cfDNA_reads,
            Q=torch.cat([cfDNA_reads, parental_reads], dim=2)
        )
        cnn_input = torch.cat([attention_output, parental_reads], dim=1)

        flattened_input = torch.flatten(cnn_input, start_dim=1)
        x1 = F.relu(self.fc1(flattened_input))
        x2 = self.fc2(x1)
        return x2


    def configure_optimizers(self):
        """ returns an optimizer for the backpropagation process"""
        return optim.Adam(
            self.parameters(),
            lr=4e-5,
            weight_decay=0.04,
            amsgrad=False,
        )

    def make_step(self, batch, batch_index):
        # unpack batch
        images, labels = batch

        # forward propagation
        logits = self.forward(image=images)

        # loss and acc calculation
        loss = self.loss_func(logits, labels)
        acc = pl.metrics.functional.accuracy(logits, labels)

        return dict(loss=loss, acc=acc)

    def training_step(self, train_batch, batch_index):
        results = self.make_step(train_batch, batch_index)
        return results['loss']

################################################################################
### PYTORCH RANDOM DATALOADER ##################################################
################################################################################

class RandomDataset(torch.utils.data.Dataset):
    def __init__(self):
        super(RandomDataset, self).__init__()

    def __len__(self):
        return 200

    def __getitem__(self, index):
        X = torch.randn(size=[6, 200, 221])
        y = torch.randint(low=0, high=2, size=[1])

        return X, y[0]

train_dataloader = torch.utils.data.DataLoader(RandomDataset(), batch_size=10)

################################################################################
### PYTORCH LIGHTNING MAIN #####################################################
################################################################################

if __name__ == '__main__':
    lightning_module = HoobariLightningModule()
    trainer = pl.Trainer()
    trainer.fit(lightning_module, train_dataloader=train_dataloader)

Thanks for the code.
The issue is caused by reusing K and assigning a constant to parts of it in:

K[torch.broadcast_tensors(w3 < self.att_th, K)[0]] = -1.0

Since K doesn’t require gradients, the return value of HardSelfAttentionLayer won’t require gradients either.
I’m not familiar with your approach, but the thresholded indexing operation won’t be differentiable.
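
If you want gradients to reach w1 and w2, one option is to use w3 multiplicatively instead of thresholding it. This is just a sketch and turns your hard selection into a soft one, so it may not match your intended approach; note that self.att_th would then be unused, since a hard threshold is inherently non-differentiable:

        # instead of the in-place masking
        #     K[torch.broadcast_tensors(w3 < self.att_th, K)[0]] = -1.0
        # a differentiable (soft) variant weights K with w3 directly:
        K = K * w3  # (BS, self.k, self.n); w3 broadcasts over the last dim
        return self.vecs2mats(K)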