NaNs appear immediately with custom loss function

Hi there,
I am attempting to implement a custom loss function that uses a novel measure of information. The loss function is below:

def MI_loss(self, inputs, E0 = 188, E = 1000):   # TODO: MAKE THIS FASTER in pc_info
    '''
    Args:
        inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].

    Returns:
        MI_loss: Avg. MI information loss for this batch
    '''
    preds = self.predict(inputs)
    Nx, batch_size, Np = preds.shape

    dist = 1 / Nx * torch.ones(Nx)  # Option to change this for more interesting distributions
    M = pc_info.MI_spike(preds, dist)

    M = M.to(self.device)

    loss = torch.norm(M + eps, p='fro')

    # ENERGY CONSTRAINT
    total_ave_frs = dist @ preds.mean(dim=1)

    energy_constraint_loss = relu(E0 * total_ave_frs - E)
    energy_loss = torch.mean(energy_constraint_loss)

    return loss + energy_loss

Where pc_info.MI_spike is defined by the following few snippets of code:

def I_spike(pc, dist):
    
    pc = torch.relu(pc)
    pc_nonzero_indices = pc.nonzero(as_tuple = False)

    info_matrix = torch.zeros(pc.shape)

    pc_nonzero = pc[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]]

    dist_nonzero = dist[pc_nonzero_indices[:, 0]]

    spike_rate_per_cell = (pc * dist.view(-1, 1, 1)).sum(dim = 0)
    
    spike_rate_matrix = spike_rate_per_cell.unsqueeze(0).expand(pc.shape[0], -1, -1)
    
    l = spike_rate_matrix[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]]
    
    info = pc_nonzero * torch.log2(pc_nonzero / l + eps) * dist_nonzero
    info_matrix[pc_nonzero_indices[:, 0], pc_nonzero_indices[:, 1], pc_nonzero_indices[:, 2]] = info

    info_per_cell = info_matrix.sum(dim = 0)

    norm_info = info_per_cell / spike_rate_per_cell

    norm_info[spike_rate_per_cell == 0] = 0

    # Average the Skaggs information across batches
    batch_ave_norm_info = norm_info.mean(dim = 0)
    
    return batch_ave_norm_info
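
(For context: I_spike is meant to compute the per-spike Skaggs information of each cell, I = (1/λ̄) · Σ_x p(x) λ(x) log2(λ(x)/λ̄) with λ̄ = Σ_x p(x) λ(x), where p(x) is dist, λ(x) is the rectified activation, and the result is averaged over the batch dimension.)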

def I_spike_joint(pc, dist): 
    Nx = pc.shape[0]
    batch_size = pc.shape[1]
    Np = pc.shape[2]
    all_Js = torch.zeros((Np, Np))
    
    for obs in range(batch_size):
        
        J = torch.zeros((Np, Np))
                     
        for i in range(Np):
            for j in range(i+1, Np):
                pc1 = torch.relu(pc[:, obs, i])
                pc2 = torch.relu(pc[:, obs, j])
                
                if torch.std(pc1) == 0:
                    # pc1 is constant, so the joint information reduces to I_spike of pc2
                    J[i, j] = I_spike(pc2.reshape((Nx, 1, 1)), dist)
                elif torch.std(pc2) == 0:
                    # pc2 is constant, so the joint information reduces to I_spike of pc1
                    J[i, j] = I_spike(pc1.reshape((Nx, 1, 1)), dist)
                    
                else:
                    r = pc_corrcoef(pc1, pc2, dist)
                    lab = pc1 * pc2
                    lab_tilde = (dist * torch.sqrt(pc1*pc2)).sum()
                    la = (dist * pc1).sum()
                    lb = (dist * pc2).sum()
                   
                    
                    info_1 = dist * (r * (torch.sqrt(lab + eps) / lab_tilde) * torch.log2(torch.sqrt(lab + eps) / lab_tilde + eps))
                    info_2 = dist * ((pc1 - r * torch.sqrt(lab + eps)) / (la - r * lab_tilde) * torch.log2((pc1 - r * torch.sqrt(lab + eps)) / (la - r * lab_tilde) + eps))
                    info_3 = dist * ((pc2 - r * torch.sqrt(lab + eps)) / (lb - r * lab_tilde) * torch.log2((pc2 - r * torch.sqrt(lab + eps)) / (lb - r * lab_tilde) + eps))

                    # TODO: Question validity of below
                    info_1[(lab == 0) | (torch.sqrt(lab) / lab_tilde < 0)] = 0
                    info_2[((pc1 - r * torch.sqrt(lab)) == 0) | ((pc1 - r * torch.sqrt(lab)) / (la - r * lab_tilde) < 0)] = 0
                    info_3[((pc2 - r * torch.sqrt(lab)) == 0) | ((pc2 - r * torch.sqrt(lab)) / (lb - r * lab_tilde) < 0)] = 0
                    
                    
                    info = info_1 + info_2 + info_3
                    
                    J[i,j] = info.sum()
                
        all_Js += J
        
    return 1/batch_size * all_Js

def MI_spike(pc, dist):
    Np = pc.shape[2]
    batch_size = pc.shape[1]
    Nx = pc.shape[0]
    all_Ms = torch.zeros((Np, Np))
    
    for obs in range(batch_size):
        M = torch.zeros((Np, Np))
        J = I_spike_joint(pc[:,obs,:].reshape((Nx, 1, Np)), dist)
        for i in range(Np):
            for j in range(i+1, Np):
                pc1 = pc[:, obs, i].reshape((Nx, 1, 1))
                pc2 = pc[:, obs, j].reshape((Nx, 1, 1))
                
                M[i,j] = I_spike(pc1, dist) + I_spike(pc2,dist) - J[i,j]
                
        all_Ms += M
            
    return 1/batch_size * all_Ms
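
(In other words, M[i, j] = I_spike(cell i) + I_spike(cell j) − I_spike_joint(i, j) for i < j, averaged over the batch; the diagonal and lower triangle of M stay zero.)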

I am attempting to use an RNN to produce outputs of size [Nx, batch_size, Np], which are then used to calculate the information measure as described above. The architecture of my RNN is below:

class RNN(torch.nn.Module):
    def __init__(self, options):
        super(RNN, self).__init__()
        self.Ng = options.Ng
        self.Np = options.Np
        self.sequence_length = options.sequence_length
        self.weight_decay = options.weight_decay
        self.device = options.device

        # Input weights
        self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
        
        self.RNN = torch.nn.RNN(input_size=2,
                                hidden_size=self.Ng,
                                nonlinearity=options.activation,
                                batch_first=False,
                                bias=False)
        
        # Linear read-out weights
        self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)
        
        self.softmax = torch.nn.Softmax(dim=-1)

    def g(self, inputs):
        '''
        Compute grid cell activations.
        Args:
            inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].

        Returns: 
            g: Batch of grid cell activations with shape [batch_size, sequence_length, Ng].
        '''
        v, p0 = inputs  # TODO: Maybe change this, where does p0 come from?
        init_state = self.encoder(p0)[None]
        
        g,_ = self.RNN(v, init_state)
        
        return g
    

    def predict(self, inputs):
        '''
        Predict place cell code.
        Args:
            inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].

        Returns: 
            place_preds: Predicted place cell activations with shape 
                [batch_size, sequence_length, Np].
        '''
        
        place_preds = self.decoder(self.g(inputs))
        
        return place_preds
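
For reference, this is roughly how the predictions that go into MI_loss are produced (a minimal sketch with made-up sizes; options is assumed to carry Ng, Np, activation, etc. as in the constructor above). Since batch_first=False, the RNN output is time-major, so predict returns [sequence_length, batch_size, Np], which is what MI_loss unpacks as [Nx, batch_size, Np]:

import torch

seq_len, batch_size, Np = 20, 8, 64        # made-up sizes for illustration
v = torch.randn(seq_len, batch_size, 2)    # velocity inputs, time-major (batch_first=False)
p0 = torch.randn(batch_size, Np)           # initial place cell code fed to the encoder

model = RNN(options)                        # options as defined elsewhere in my setup
preds = model.predict((v, p0))
print(preds.shape)                          # torch.Size([20, 8, 64]) -> [Nx, batch_size, Np]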

Training results in nan after just one iteration: the first loss value is a reasonable number, but then everything immediately goes to nan. Running anomaly detection resulted in the following error:
RuntimeError: Function 'LinalgVectorNormBackward0' returned nan values in its 0th output.

Any pointers on what in my loss function or architecture could be causing this would be appreciated. Gradient clipping hasn't helped.
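
In case it is relevant, here is a minimal sketch (separate from my training code) of how this particular backward node can return nan: the gradient of the Frobenius norm is x / ‖x‖, so it is nan when the input is exactly zero, and any nan already present in M would surface in the same node:

import torch

x = torch.zeros(3, 3, requires_grad=True)
loss = torch.norm(x, p='fro')   # forward is fine (loss = 0)
loss.backward()                 # d||x|| / dx = x / ||x||  ->  0 / 0
print(x.grad)                   # all nan; anomaly detection flags this in the norm's backward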

Thanks

Hi @ptrblck, is there any reason you think this would be occurring with the given loss function? I think it has something to do with how the gradients are computed. I have tried to make the loss handle zeros and negative numbers.

I would start by checking torch.log2(pc_nonzero / l + eps) to see if this is returning invalid outputs. Also, do you want to add the eps to the divisor or to the quotient?
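
Something like this quick check (reusing the names from your I_spike snippet) might help narrow it down, and it shows the two possible eps placements:

# inside I_spike, just before computing info
ratio = pc_nonzero / l
print(torch.isfinite(ratio).all())   # False if the division itself produced inf / nan
print((ratio + eps <= 0).any())      # True if log2 would see a non-positive value

# note the difference between:
# torch.log2(pc_nonzero / l + eps)     # eps added to the quotient (current version)
# torch.log2(pc_nonzero / (l + eps))   # eps added to the divisor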