I have defined a custom linear layer, a class named 'Linear' that subclasses nn.Module, as follows:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from prot_map import *

class Linear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        name=None,
        num_prot=50,
        C=16
    ):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.name = name
        self.num_prot = num_prot
        self.C = C
        # Define weight and bias parameters (torch.Tensor alone allocates
        # uninitialised memory, so they are initialised explicitly below)
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features), requires_grad=True)
        self.bias = nn.Parameter(torch.Tensor(out_features), requires_grad=True) if bias else None
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        # Learnable prototypes: num_prot vectors of dimension in_features
        self.prototype = nn.Parameter(torch.Tensor(num_prot, in_features), requires_grad=True)
        nn.init.normal_(self.prototype, mean=0, std=0.01)
    def append_name(self, postfix):
        self.name += postfix
    def forward(self, input):
        # Get dimensions
        batch, N, D = input.shape
        K, D_prot = self.prototype.shape
        assert D == D_prot, "Input and prototype dimensions must match"
        assert D % self.C == 0, "Feature dimension must be divisible by the number of codespaces C"
        # Split input and prototypes across C codespaces
        input_split = input.view(batch, N, self.C, D // self.C)        # Shape: (batch, N, C, D//C)
        prototype_split = self.prototype.view(K, self.C, D // self.C)  # Shape: (K, C, D//C)
        # Map the input to the prototypes independently in each codespace
        x_map = []
        for c in range(self.C):
            input_data = input_split[:, :, c, :]   # Shape: (batch, N, D//C)
            prototypes = prototype_split[:, c, :]  # Shape: (K, D//C)
            # Distances between inputs and prototypes in this codespace
            distances = torch.cdist(input_data, prototypes, p=2)  # Shape: (batch, N, K)
            epsilon = 1e-8
            # Soft assignment: closer prototypes get higher probability
            prob = F.softmax(1 / (distances + epsilon), dim=-1)  # Shape: (batch, N, K)
            assert prob.shape == (batch, N, K), "Incorrect probability tensor shape"
            # Map input to a convex combination of prototypes for this codespace
            mapped_input = prob @ prototypes  # Shape: (batch, N, D//C)
            x_map.append(mapped_input)
        # Concatenate the mapped inputs from all codespaces
        mapped_input = torch.cat(x_map, dim=-1)  # Shape: (batch, N, D)
        output = F.linear(mapped_input, self.weight, self.bias)  # Shape: (batch, N, out_features)
        return output
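For context, here is a minimal standalone sketch of the kind of check I can run on the layer in isolation (the shapes, the dummy input and the squared-error loss below are placeholders, not my real training setup):

# Standalone check with dummy data (placeholder shapes, not my real model)
layer = Linear(in_features=64, out_features=32, num_prot=50, C=16)
x = torch.randn(8, 10, 64)           # (batch, N, D), with D divisible by C
out = layer(x)                        # (batch, N, out_features)
loss = out.pow(2).mean()              # placeholder loss
loss.backward()
print(layer.prototype.grad is None)   # whether any gradient reached the prototypes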
I am trying to map the inputs to learnable prototypes as outlined in the code above. However, when I train a model that uses this Linear class, I find that prototype.grad is None. I check the gradients with the following code:
for name, module in model.named_modules():
    if isinstance(module, Linear):
        if module.prototype.grad is not None:
            print(f"Gradient of prototype parameter in {name}:", module.prototype.grad)
        else:
            print(f"No gradient for prototype parameter in {name}!")
Could you suggest how to ensure that gradients propagate to the prototype parameter, and point out where I am going wrong?