I am trying to implement the coverage loss from See et al. (2017): https://arxiv.org/pdf/1704.04368.pdf
activations shape: [batch, timesteps, features]
mask shape: [batch, timesteps] to mark the padded positions
For example, to calculate the coverage loss for the 5th element in the batch, I take the slice [5, timesteps, features] and iterate over all timesteps as follows:
- for the 1st timestep, take the elementwise min between the zero vector and activations[5][1]
- for the 10th timestep, take the elementwise min between the sum of all previous activations and activations[5][10]
I repeat this procedure for every sample in the batch.
- The problem: in the end, the coverage loss always comes out as 1
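For reference, the paper defines the per-timestep coverage loss as covloss_t = sum_i min(a_i^t, c_i^t), where the coverage vector c^t = sum_{t' < t} a^{t'} is the (unnormalized) sum of the attention distributions over all previous decoder timesteps.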
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(150)
torch.set_default_dtype(torch.float64)

mask = torch.ones(10, 10, device=device)  # 1 for real tokens, 0 for padding
activations = torch.rand(10, 10, 3518, device=device)
activations_distribution = torch.nn.functional.softmax(activations, dim=2)

overall_coverage = 0  # initialize the overall loss for the batch
for index, activation in enumerate(activations_distribution):  # iterate over every sample in the batch
    loss = 0  # initialize the coverage loss for one sentence in the batch
    overall_weights = torch.zeros(activation.shape[1], device=device)  # the first coverage vector is zero
    for time, timestep in enumerate(activation):  # iterate over every single timestep
        if mask[index][time] == 0:  # found the first padding symbol; only padding follows, so we stop
            break
        # Accumulate the per-timestep terms. The original `loss = torch.sum(...)` overwrote
        # the previous timesteps, so only the last term survived, and that term is bounded
        # by 1 because each attention distribution sums to 1: hence the loss was always ~1.
        loss = loss + torch.sum(torch.min(overall_weights, timestep))  # elementwise min, then sum
        overall_weights = overall_weights + timestep  # coverage: (unnormalized) sum of all previous attention distributions
    overall_coverage = overall_coverage + loss  # keep track of the batch's total coverage loss
cov_loss = overall_coverage / activations_distribution.shape[0]  # divide by the batch size
print('loss', cov_loss)
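For what it's worth, the same quantity can also be computed without Python loops. This is a minimal sketch under the assumption (as in the loop above) that padding appears only at the end of each sequence; the name cov_loss_vectorized is mine:

# coverage at timestep t = sum of attention distributions over timesteps < t
coverage = torch.cumsum(activations_distribution, dim=1) - activations_distribution
per_timestep = torch.sum(torch.min(coverage, activations_distribution), dim=2)  # [batch, timesteps]
per_timestep = per_timestep * mask  # zero out the padded timesteps
cov_loss_vectorized = per_timestep.sum() / activations_distribution.shape[0]
print('vectorized loss', cov_loss_vectorized)

Multiplying by the mask zeroes out every padded position, which matches the break-at-first-padding behavior of the loop as long as padding only occurs at the tail of a sequence.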