Implementation of Coverage loss, loss is always 1

I am trying to implement the coverage loss from: https://arxiv.org/pdf/1704.04368.pdf
activations shape: [batch, timesteps, features]
mask shape: [batch, timesteps] to recognize the padding
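
As I understand the paper, the coverage vector c^t is the sum of the attention distributions of all previous timesteps, and the per-step coverage loss is the sum of the elementwise minimum with the current distribution a^t:

    c^t = \sum_{t'=0}^{t-1} a^{t'}, \qquad \text{covloss}_t = \sum_i \min(a_i^t, c_i^t)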

For example, if I want to calculate the coverage loss for the 5th element in the batch, I take the element [5, timesteps, features] and iterate over all the timesteps as follows:

  • for the 1st timestep, take the min between [zero_vector, activations[5][1]]
  • for the 10th timestep, take the min between [sum_of_all_previous_activations, activations[5][10]]

Every time I move to the next sample in the batch, I repeat the same procedure.
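
Per timestep, the operation I mean is essentially this (toy sizes, variable names just illustrative):

import torch

attn_t = torch.nn.functional.softmax(torch.rand(3518), dim=0)  # distribution at the current timestep
coverage = torch.zeros(3518)                                   # zero vector at the first timestep

step_loss = torch.sum(torch.min(coverage, attn_t))             # elementwise min, then sum
coverage = coverage + attn_t                                    # accumulate for the next timestep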

In the end, the coverage loss is always 1. Here is my code:
import torch

device = torch.device('cpu')  # or 'cuda'
torch.manual_seed(150)
torch.set_default_dtype(torch.float64)

mask = torch.ones(10, 10)
activations = torch.rand(10, 10, 3518)
activations_distribution = torch.nn.functional.softmax(activations, dim=2)

overall_coverage = 0  # initialize the overall loss for the batch

for index, activation in enumerate(activations_distribution):  # iterate over every sample in the batch

    loss = 0  # initialize the coverage loss for a sentence in the batch
    overall_weights = torch.zeros(activation.shape[1]).to(device)  # the first coverage vector equals zero

    for time, timestep in enumerate(activation):  # iterate over every single timestep

        if mask[index][time] == 0:  # first padding symbol found, so only padding follows until the end of the sequence; stop here
            break

        loss = torch.sum(torch.min(overall_weights, timestep))  # elementwise min with the current distribution, then sum
        overall_weights = torch.sum(activation[:time + 1, :], dim=0).to(device)  # (unnormalized) sum of the activations so far

    overall_coverage = overall_coverage + loss  # keep track of the overall coverage loss of the batch

cov_loss = overall_coverage / activations_distribution.shape[0]  # divide by the batch size
print('loss', cov_loss)

The problem was the number of classes (3518 in the example above): with that many classes, every softmax output is a very small number, so the sum of the previous softmax outputs was always bigger than the current output. The elementwise min therefore always picked the element of the current softmax output, and the loss reduced to the sum of all elements of the current softmax output, which is always 1.
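
A quick numeric check of this (same 3518 classes, purely illustrative):

import torch
torch.manual_seed(0)
a = torch.nn.functional.softmax(torch.rand(10, 3518), dim=1)  # 10 timesteps, 3518 classes
print(a[9].max())                            # every entry is tiny, on the order of 1e-4
prev_sum = a[:9].sum(dim=0)                  # running sum of the 9 previous timesteps
print((prev_sum > a[9]).all())               # tensor(True): the sum dominates the current output everywhere
print(torch.sum(torch.min(prev_sum, a[9])))  # 1.0, i.e. just the sum of the current softmax output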

To fix this, we take the mean of the previous softmax outputs instead of their sum:

overall_weights = (torch.sum(activation[:time + 1, :], dim=0) / (time + 1)).to(device)
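
For completeness, here is a rough vectorized sketch of the same idea (my own rewrite, function name is just illustrative): it uses the averaged coverage vector from the fix above and, unlike the loop above, sums the per-timestep terms over the whole sequence, as the paper does.

import torch

def batch_coverage_loss(activations, mask):
    # activations: [batch, timesteps, features], mask: [batch, timesteps] with 1 for real tokens
    dist = torch.nn.functional.softmax(activations, dim=2)
    prev_sum = torch.cumsum(dist, dim=1) - dist              # sum of strictly earlier timesteps
    steps = torch.arange(dist.shape[1], dtype=dist.dtype, device=dist.device).clamp(min=1).view(1, -1, 1)
    coverage = prev_sum / steps                              # averaged coverage vector (the fix above)
    per_step = torch.min(coverage, dist).sum(dim=2)          # covloss_t for every position: [batch, timesteps]
    per_step = per_step * mask                               # ignore padded positions
    return per_step.sum(dim=1).mean()                        # sum over timesteps, mean over the batch

print(batch_coverage_loss(activations, mask))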