No gradient found for a parameter in custom Linear class

I have defined a custom linear layer class named Linear (a subclass of nn.Module) as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from prot_map import *

class Linear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        name=None,
        num_prot=50,
        C=16
    ):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.name = name
        self.num_prot = num_prot
        self.C = C

        # Define weight, bias, and prototype parameters
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=True)
        self.bias = nn.Parameter(torch.Tensor(out_features), requires_grad=True) if bias else None
        self.prototype = nn.Parameter(torch.Tensor(num_prot, in_features), requires_grad=True)
        # Initialize the parameters (torch.Tensor allocates uninitialized memory)
        nn.init.normal_(self.weight, mean=0, std=0.01)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        nn.init.normal_(self.prototype, mean=0, std=0.01)

    def append_name(self, postfix):
        self.name += postfix

    def forward(self, input):
        # Get dimensions
        batch, N, D = input.shape
        K, D_prot = self.prototype.shape
        assert D == D_prot, "Input and prototype dimensions must match"

        # Split input and prototype across C codespaces
        input_split = input.view(batch, N, self.C, D // self.C)  # Shape: (batch, N, C, D//C)
        prototype_split = self.prototype.view(K, self.C, D // self.C)  # Shape: (K, C, D//C)

        # Store the mapped input from each codespace
        x_map = []

        for c in range(self.C):
            input_data = input_split[:, :, c, :]  # Shape: (batch, N, D//C)
            prototypes = prototype_split[:, c, :]  # Shape: (K, D//C)

            # Compute distances between inputs and prototypes for this codespace
            distances = torch.cdist(input_data, prototypes, p=2)  # Shape: (batch, N, K)
            epsilon = 1e-8
            prob = F.softmax(1 / (distances + epsilon), dim=-1)  # Shape: (batch, N, K)
            assert prob.shape == (batch, N, K), "Incorrect probability tensor shape"

            # Map input to prototypes for each codespace
            mapped_input = prob @ prototypes  # Shape: (batch, N, D//C)
            x_map.append(mapped_input)

        # Concatenate the mapped inputs from all codespaces
        mapped_input = torch.cat(x_map, dim=-1)  # Shape: (batch, N, D)

        output = F.linear(mapped_input, self.weight, self.bias)  # Shape: (batch, N, out_features)

        return output

I am trying to map the inputs to learnable prototypes, as outlined in the code above. However, when I use this Linear class during training, I find that prototype.grad is None when I check it with the following code:

for name, module in model.named_modules():
    if isinstance(module, Linear):
        if module.prototype.grad is not None:
            print(f"Gradient of prototype parameter in {name}:", module.prototype.grad)
        else:
            print(f"No gradient for prototype parameter in {name}!")

Could you suggest how to ensure that the gradients are propagated, and where I am going wrong?

I cannot reproduce any issue and see a valid gradient in .prototype.grad:

lin = Linear(10, 10, C=10)
x = torch.randn(1, 10, 10)

out = lin(x)
out.mean().backward()

print(lin.prototype.grad.abs().sum())
# tensor(40590.8516)

but I also don’t know how you are using this model.
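
If the gradient is still None on your end, a few things worth ruling out (guesses on my part, since I can't see your training loop): checking .grad before the first loss.backward(), checking it after optimizer.zero_grad() (which sets .grad back to None by default in recent PyTorch versions), or not passing the prototype parameter to the optimizer at all. A quick sanity check along those lines, assuming model, loss, and optimizer are your existing objects:

# Confirm the prototype parameters are registered and require gradients
for name, param in model.named_parameters():
    if "prototype" in name:
        print(name, param.requires_grad, tuple(param.shape))

# Inspect .grad only after backward() and before the next zero_grad()
loss.backward()
for name, param in model.named_parameters():
    if "prototype" in name:
        print(name, "grad is None:", param.grad is None)
optimizer.zero_grad()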