Loss does not decrease during training, and the gradients for all MLPs are 0

Hi,

I find that my customized loss does not decrease during training, so I checked the gradients of all the MLPs and they are zero all the time. The network does not learn at all. I don't know which part is wrong. Thanks for any help.
The structure of my network:

import torch
from Compared_layer import Compared_layer

class ComparedGNN(torch.nn.Module):
    def __init__(self, edge_feature_size, num_antenna, num_BS, num_layers, power, noise):
        super(ComparedGNN, self).__init__()

        # if you have cuda
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)  # move the entire model to cuda

        # initialize the first connection layer
        self.layer1 = Compared_layer(edge_feature_size, num_antenna, num_BS, power, noise)

        #initialize the middle layers
        self.middle_layers = []
        for i in range(num_layers - 1):
            layer = Compared_layer(edge_feature_size, num_antenna, num_BS, power, noise)
            self.middle_layers.append(layer)

    def forward(self, F_ue0, E, P, Noise):
        F_ue = self.layer1(F_ue0, E, P, Noise) #F_ue is random for the input of the first layer
        if self.device.type == "cuda":
            F_ue = F_ue.to(self.device)
        for layer in self.middle_layers:
            F_ue = layer(F_ue, E, P, Noise)
            if self.device.type == "cuda":
                F_ue = F_ue.to(self.device)

        return F_ue

The customized loss and train function:

import torch
from torch.utils.data import DataLoader, Subset

def Loss(W, H, Noise):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    K = Noise.size()
    K = K[0]
    N = H.size(1) // K
    H = H.to(torch.complex128)
    W = W.to(torch.complex128)
    R = torch.zeros(K)
    interference = torch.zeros(K)
    SINR = torch.zeros(K)
    if device.type == "cuda":
        Noise = Noise.to(device)
    for k in range(K):        
        h_k = H[:, k * N:(k + 1) * N]  # of size [M, N]
        Signal = 0.0
        h_kT = torch.transpose(torch.conj(h_k), 0, 1) # of size [N, M]
        h_kT = torch.flatten(h_kT).requires_grad_(True)  # flattened to size [M*N]
        hw_k = torch.matmul(h_kT, W[:, k])
        hw_k.retain_grad()      
        signal = torch.sum(hw_k)  # h_kT * w_k
        Signal = abs(signal) ** 2
        for l in range(K):
            interference_k = 0.0  # h_kT * w_l
            if l!=k:
                hw_l = torch.matmul(h_kT.view(-1), W[:, l])
                #hw_l.retain_grad()
                interference_k = torch.sum(hw_l)
                #interference_k = torch.matmul(h_kT.view(-1), W[:, l]).sum()
            interference[k] = interference[k] + abs(interference_k) ** 2

        SINR[k] = Signal / (interference[k] + Noise[k])
        R[k] = torch.log2(1+SINR[k])

    Rsum = -torch.sum(R)

    return Rsum

def getW(F_ue, P, K):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    M = P.size()
    P = P.to(device)
    M = M[0]
    N = F_ue.size(0) // (2*M)
    F_ue_complex = F_ue[:M * N, :] + 1j * F_ue[M * N:, :]   #before normalization
    W = torch.zeros(M*N, K, device=device) + 1j * torch.zeros(M*N, K, device=device)
    W_new = torch.zeros(W.size(), device=device) + 1j*torch.zeros(W.size(), device=device)
    W = torch.autograd.Variable(W, requires_grad=True)
    for m in range(M):
        W_temp = F_ue_complex[m*N:(m+1)*N, :].clone().to(device) # of size N*K
        norm_sum = 0.0
        for k in range(K):
            norm_k = torch.norm(W_temp[:,k], p=2) ** 2
            norm_sum += norm_k
        W_new[m*N:(m+1)*N, :] = P[m] * W_temp/norm_sum
    W = W_new
    W = W.to(device)
    return W

def train(P, Noise, dataset, num_epochs, lr):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    M = P.size()
    M = M[0]
    K = Noise.size()
    K = K[0]
    KN = dataset.size(dim=2)
    N = KN // K
    model = ComparedGNN(edge_feature_size=2*N, num_antenna=N, num_BS=M, num_layers=2, power=P[0], noise=Noise[0])
    model.to(device)
    model.train()
  
    #optimizer = torch.optim.Adam(model.parameters(), lr)
    optimizer = torch.optim.RMSprop(model.parameters(), lr)

    train_size = int(1 * len(dataset))
    print(f'train size: {train_size}')
    # print(f'train size: {train_size}')
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    batch_size = 64
    num_batches = (train_size + batch_size - 1) // (batch_size * num_epochs)
    print(f'num_batches: {num_batches}')

    # define lists to save loss and rate for each epoch
    losses = []
    rate = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx in range(num_batches):  # number of minibatches = num_batches
            model.train()
            start_idx = (epoch * num_batches + batch_idx) * batch_size
            end_idx = start_idx + batch_size
            subset_indices = range(start_idx, end_idx)
            subset = Subset(train_dataset, subset_indices)
            subset_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
            # obtain data from each batch
            for data in subset_loader:
                channel_complex = data
                if device.type == "cuda":
                    channel_complex = channel_complex.to(device)
                optimizer.zero_grad()
                loss = 0.0
                batch_n = channel_complex.size(dim=0)

                for b in range(batch_n):
                    H_complex = channel_complex[b,:,:]
                    # transfer the complex into real to be the edge feature
                    edge_feature = torch.zeros(M, K, 2*N)
                    for m in range(M):
                        for k in range(K):
                            h_complex = H_complex[m, k*N:(k+1)*N]
                            h_real = h_complex.real
                            h_imag = h_complex.imag
                            edge_feature[m,k,:] = torch.cat((h_real, h_imag), dim=0)

                    F_ue0 = torch.rand(2*M*N, K)
                    if device.type == "cuda":
                        F_ue0 = F_ue0.to(device)
                        edge_feature = edge_feature.to(device)

                    F_ue = model(F_ue0, edge_feature, P, Noise)
                    W = getW(F_ue, P, K)
                    loss = loss + Loss(W, H_complex, Noise)

                loss /= batch_n

                if device.type == "cuda":
                    loss = loss.cpu()
                running_loss += loss.item()
          
                for name, param in model.named_parameters():
                    if param.grad is None:
                        print(f'Parameter: {name}, Gradient: {param.grad}')

                loss.backward()
                optimizer.step()

        epoch_loss = running_loss/num_batches
        losses.append(epoch_loss)
        rate.append(-epoch_loss)
        print(f'Epoch: {epoch + 1:03d}, Training Loss: {epoch_loss:.4f}')

Now I'm thinking maybe my MLP settings are wrong. I use clone().detach() to make sure the graph is retained, but I'm not sure whether that is the problem. @ptrblck, hi, could you please have a look? I'm stuck here. :smiling_face_with_tear:
Thanks!!!

import torch
from torch.nn import Linear
import numpy as np

class Compared_layer(torch.nn.Module):
    def __init__(self, edge_feature_size, num_antenna, num_BS, power, noise):
        '''
        :param edge_feature_size: int, the dimension of edge_feature, e_mk of size 2*N
        :param num_antenna: int, number of antennas at each BS
        :param num_BS: int, number of BS
        :param power: int, P[m]
        :param noise: int, Noise[k]
        '''
        super(Compared_layer, self).__init__()

        # MLP for generating message at BS
        self.mlp1 = torch.nn.Sequential(
            Linear(edge_feature_size + 2*num_BS*num_antenna + 2, 512), # power and noise should also be input
            torch.nn.ReLU(),
            Linear(512, 512),
            torch.nn.ReLU(),
            Linear(512, 512),
            torch.nn.ReLU(),
            Linear(512, 2*num_BS*num_antenna)
        )

        # MLP for generating message at UE
        self.mlp2 = torch.nn.Sequential(
            Linear(2*num_BS*num_antenna + edge_feature_size + 2*num_BS*num_antenna + 2, 512),
            torch.nn.ReLU(),
            Linear(512, 512),
            torch.nn.ReLU(),
            Linear(512, 512),
            torch.nn.ReLU(),
            Linear(512, 2*num_BS*num_antenna)
        )

        # MLP for updating the UE representation
        self.mlp3 = torch.nn.Sequential(
            Linear(2*num_BS*num_antenna + num_BS*edge_feature_size + 2*num_BS*num_antenna + 2, 512),
            torch.nn.ReLU(),
            Linear(512, 512),
            torch.nn.ReLU(),
            Linear(512, 512),
            torch.nn.ReLU(),
            Linear(512, 2*num_BS*num_antenna)
        )

        self.M = num_BS
        self.N = num_antenna
        self.K = 2
        
        self.message_bs = torch.zeros(self.M, self.K, 2 * self.M * self.N)
        self.message_ue = torch.zeros(self.K, self.M, 2 * self.M * self.N)
       
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, Linear):
                torch.nn.init.xavier_uniform_(m.weight, gain=1.0)  # set gain to a non-zero value
                torch.nn.init.constant_(m.bias, 0.4)  # set a non-zero bias value
        
    def forward(self, F_ue, E, P, Noise):
        agg_bs, agg_ue = self.aggregation()
        self.message(F_ue, E, agg_ue, P, Noise)
        agg_bs, agg_ue = self.aggregation()
        F_ue_update = self.update(E, F_ue, agg_bs, P, Noise)
        return F_ue_update

    def message(self, F_ue, E, agg_ue, P, Noise): 
        if self.device.type == "cuda":
            P = P.to(self.device)
            Noise = Noise.to(self.device)
        for m in range(self.M):
            for k in range(self.K):
                self.message_bs[m, k, :] = self.mlp1(torch.cat((E[m,k,:], agg_ue[m,:], P[m], Noise[k]), 0)).clone().detach()

        agg_bs = torch.mean(self.message_bs, dim=0)
        agg_bs = agg_bs.to(self.device) # of size [K, 2MN]
        for k in range(self.K):
            for m in range(self.M):
                self.message_ue[k,m,:] = self.mlp2(torch.cat((F_ue[:, k], E[m,k,:], agg_bs[k,:], P[m], Noise[k]), 0)).clone().detach()
                
        # if cuda is available
        if self.device.type == "cuda":
            self.message_bs = self.message_bs.to(self.device)
            self.message_ue = self.message_ue.to(self.device)
        #print("message_bs", self.message_bs)
        #print("message_ue", self.message_ue)

    def aggregation(self):
        agg_bs = torch.mean(self.message_bs, dim=0) # of size [K, 2MN]
        agg_ue = torch.mean(self.message_ue, dim=0) # of size [M, 2MN]

        agg_bs = agg_bs.to(self.device)
        agg_ue = agg_ue.to(self.device)    

        return agg_bs, agg_ue

    def update(self, E, F_ue, agg_bs, P, Noise):
        if self.device.type == "cuda":
            P = P.to(self.device)
            Noise = Noise.to(self.device)
        F_ue_update = torch.zeros(F_ue.shape)
        for k in range(self.K):
            edge_cat_ue = torch.cat([E[m, k, :] for m in range(self.M)], dim=0)
            F_ue_update[:, k] = self.mlp3(torch.cat((edge_cat_ue, F_ue[:, k], agg_bs[k,:], P[0], Noise[k]), 0)).clone().detach()
        F_ue_update = F_ue_update.to(self.device)
        #print("F_ue_update", F_ue_update)
    
        return F_ue_update

I haven’t completely understood your code, but here are a few issues:

  • Initialize self.middle_layers as an nn.ModuleList to properly register the submodules. A plain Python list will not register the trainable parameters and the Compared_layers won't be trained (see the sketch after this list).
  • Creating a new tensor (or the deprecated Variable) creates a new leaf tensor without any gradient history, so you are detaching the computation graph at this point. getW creates a new tensor and I’m unsure if W should be trained or not. Note that W won’t be attached to any computation graph.
  • I don’t know why you want to retain the graph but explicitly calling detach() on a tensor will of course detach it from the computation graph.
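For the first two points, something like this should work. It is an untested sketch, assuming the constructor arguments and the normalization from your getW stay the same (only the relevant parts are shown):

import torch
from Compared_layer import Compared_layer

class ComparedGNN(torch.nn.Module):
    def __init__(self, edge_feature_size, num_antenna, num_BS, num_layers, power, noise):
        super().__init__()
        self.layer1 = Compared_layer(edge_feature_size, num_antenna, num_BS, power, noise)
        # nn.ModuleList registers the sub-layers, so model.parameters()
        # (and therefore the optimizer) will actually see their weights
        self.middle_layers = torch.nn.ModuleList(
            Compared_layer(edge_feature_size, num_antenna, num_BS, power, noise)
            for _ in range(num_layers - 1)
        )

def getW(F_ue, P, K):
    # build W only from differentiable ops on F_ue (no fresh zero tensor, no Variable),
    # so it stays attached to the computation graph of the model output
    M = P.size(0)
    N = F_ue.size(0) // (2 * M)
    P = P.to(F_ue.device)
    F_ue_complex = F_ue[:M * N, :] + 1j * F_ue[M * N:, :]  # [M*N, K], complex
    blocks = []
    for m in range(M):
        W_temp = F_ue_complex[m * N:(m + 1) * N, :]        # [N, K]
        norm_sum = (W_temp.abs() ** 2).sum()               # same normalizer as your loop over k
        blocks.append(P[m] * W_temp / norm_sum)
    return torch.cat(blocks, dim=0)                        # [M*N, K]

As long as F_ue itself is not detached inside the layers, W.grad_fn should then be populated and gradients can flow back into the MLPs.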

Thanks for your reply.

  • My goal is to reduce the loss by optimizing W based on P, Noise, and the dataset (edge_feature, or H, or channel). Since W should be complex and the MLPs can only take real numbers, I need to convert the outputs (F_ue) from the Compared_layers into complex (W) in getW. In this case, I think W should be trained to decrease the loss.

  • If I want to train W, how should I create the new tensor while making sure that W stays attached to the computation graph?

  • I call detach() on every MLP output to avoid in-place operations. The error message reports that an in-place operation happened if I only use

self.message_bs[m, k, :] = self.mlp1(torch.cat((E[m,k,:], agg_ue[m,:], P[m], Noise[k]), 0))

instead of

self.message_bs[m, k, :] = self.mlp1(torch.cat((E[m,k,:], agg_ue[m,:], P[m], Noise[k]), 0)).clone().detach()
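Regarding the in-place error: instead of writing the MLP outputs into the pre-allocated self.message_bs / self.message_ue buffers (which forces the detach() workaround and kills the gradients), you could collect them in Python lists and stack them into fresh tensors on every forward pass. A rough, untested sketch of message along those lines (it returns the tensors instead of storing them on self, so forward and aggregation would need a matching small change):

def message(self, F_ue, E, agg_ue, P, Noise):
    P = P.to(self.device)
    Noise = Noise.to(self.device)
    # no pre-allocated buffer, no slice assignment, no detach():
    # stack the per-(m, k) outputs into new tensors on each call
    msg_bs = []
    for m in range(self.M):
        rows = [self.mlp1(torch.cat((E[m, k, :], agg_ue[m, :], P[m], Noise[k]), 0))
                for k in range(self.K)]
        msg_bs.append(torch.stack(rows, dim=0))      # [K, 2MN]
    message_bs = torch.stack(msg_bs, dim=0)          # [M, K, 2MN]

    agg_bs = message_bs.mean(dim=0)                  # [K, 2MN]
    msg_ue = []
    for k in range(self.K):
        rows = [self.mlp2(torch.cat((F_ue[:, k], E[m, k, :], agg_bs[k, :], P[m], Noise[k]), 0))
                for m in range(self.M)]
        msg_ue.append(torch.stack(rows, dim=0))      # [M, 2MN]
    message_ue = torch.stack(msg_ue, dim=0)          # [K, M, 2MN]

    return message_bs, message_ue

The same pattern (build a list, then torch.stack / torch.cat) applies to F_ue_update in update. That way no in-place write touches a tensor that autograd still needs, so the clone().detach() calls can be removed and the gradients should reach mlp1, mlp2 and mlp3.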