Log softmax probabilities all equal in rnn decoder because pointer network scores are all < -90.0

I have a strange problem. I am training an RNN seq2seq encoder decoder model for text generation. 5-grams in, text out. So, first, 5-grams are encoded and gated, the encoded plan r_cs is a matrix of size →[batch x number-of-5-grams x d_hidden]. This is input to an rnn decoder, which utilizes a pointer network (Vinyals et al.) in order to select items from its input, namely from the record representations along dimension 1 (number-of-5-grams). The record which the log-softmaxed scores point to is used as input to the next step as well as appended to the output. The implementation bases on this paper: https://arxiv.org/pdf/1809.00582.

Now, after this plan decoder, there is another rnn encoder for the chosen records, which are then input to an lstm-based text decoder which produces text. The loss is composed of the text decoder loss as well as the planner decoder loss (comparing the predicted record with the one in the target plan). During training, the text loss decreases, until it plateaus at about 0.4. the planner decoder loss does only change for the first 2-3 batches, then it is constant at 4.39. I looked into the issue and found out that after the first batches, soft-maximized probabilities for the records are all identical, so choice is random. The reason for this is that a matrix multiplication in this formula:

p(zk = rj |z<k, r) ∝ exp(hᵀ_k W_c r_cs_j )

where r_cs is the encoded record, h^T_k is the rnn decoder hidden state and W_c is a learnable matrix.
This computation is done in 2 steps (right_product and left_product).
In the following implementation, the right product is ok, but the left product starts out with resulting tensor elements between -3 to 3, then after a few batches, becomes a matrix of values all below -90.0, except for the very last score, which is always positive (the end pf plan marker, which is also used to pad, it occurs really often.) The resulting predicted probability soon get to -4.3944 for all records, also for the last padding element. Also, one observation I made is that W_c does not change after backpropagating the loss. Attention is Bahdanau attention.

class PlannerDecoder(torch.nn.Module):

  def __init__(self, d_hidden, d_input, num_layers, d_record, device=torch.device("cuda:0")):

    super(PlannerDecoder, self).__init__()

    self.device = device
    self.d_hidden = d_hidden
    self.d_input = d_input
    self.num_layers = num_layers
    self.d_record = d_record


    self.lstm_layer = torch.nn.LSTM(2*self.d_hidden,
                                self.d_hidden,
                                num_layers=self.num_layers,
                                bias=True,
                                batch_first=True,
                                dropout=0.0,
                                bidirectional=False, device=device)

    self.attention_layer = Attention(self.d_hidden, device=device)

    # specific uniform distribution needed to prevent strange effects
    # of activation function and softmax

    stdv = 1. / math.sqrt(self.d_hidden)
    self.W_c = torch.empty([self.d_hidden, self.d_hidden], device=device).uniform_(-stdv, stdv)
    self.W_c.requires_grad=True



  def forward(self, r_cs, r_index, max_plan_length, prev_state=False, target_tensor=False):
    

    #print("W_c : ", self.W_c)
    #print("planner target: ", target_tensor, target_tensor.shape)
    
    # target tensor must be a tensor of indices over the whole set of records in the
    # data wrt to the target records realized in the text
    # e.g. the set of all records is 50, then a possible target tensor or gold plan
    # could be [14, 41, 2, 7, 19, 8, <EOP>, ... <EOP>] 
    
    # compute h0 as average of input records
    # sum along axis d_input and divide for average
    state_h = torch.div(r_cs.sum(dim=1), self.d_input) #/ self.d_input
    state_c = torch.div(r_cs.sum(dim=1), self.d_input)
    
    state_h = state_h.unsqueeze(0)
    state_c = state_c.unsqueeze(0)

    decoder_hidden_state = (state_h, state_c)
    

    document_plan = []
    idx_document_plan = []
    idx_plan_predictions = []
    decoder_outputs = []
    attention_weights = []
    pointer_outputs = []

    batch_size = r_cs.shape[0]

    # get start record as the average of all records

    #decoder_input = torch.randn([batch_size, 1, self.d_hidden]).to(device)
    #  print(decoder_input.shape)

    decoder_input = torch.div(r_cs.sum(dim=1), self.d_input).unsqueeze(1).detach()
    #print(decoder_input.shape)
    teacher_forcing_ratio = 0.5
    
    # iterate over allowed number of inference steps

    for i in range(max_plan_length):

        #print(f"planner step {i}")

        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

        # compute one iteration step
        decoder_output, decoder_hidden_state, attn_weights = self.step(
            decoder_input, decoder_hidden_state, r_cs)

        # collect decoder output and attention weights for this step
        decoder_outputs.append(decoder_output)
        attention_weights.append(attn_weights)

################ Problem starts here
################

        # pointer network produces probability distribution over input length
        # instead of vocab for selecting records for the document plan
        # h_k^T ° Wc ° r_cs_j

        # h_k = [1 x batch_size x d_hidden] or [batch_size x 1 x d_hidden]
        # r_cs_ = [batch_size x d_input x d_hidden]
        # W_c must be [d_hidden x d_hidden] in order for the output to produce
        # tensor [batch_size x d_input x 1]
        
        right_product = torch.matmul(r_cs, self.W_c)
        # [batch_size x d_input x d_hidden]
        print("right: ", right_product, right_product.shape)
        
        left_product = torch.matmul(decoder_output, right_product.transpose(2,1))

        print("left: ", left_product, left_product.shape)
       
        scores = torch.sigmoid(left_product)
        print("scores: ", scores)
        
        # [batch_size x 1 x d_input]
        pointer_outputs.append(scores)

        # softmax over input

        pred = torch.nn.functional.log_softmax(scores, dim=-1)

#################
#################

        print("pred: ", pred, pred.shape)

        if target_tensor is not None and use_teacher_forcing:

            # Teacher forcing: Feed the target as the next input
            topi = target_tensor[:, i].unsqueeze(1) # Teacher forcing
           
        else:
            # Without teacher forcing: use its own predictions as the next input
            _, topi = pred.topk(1)
           
            topi = topi.squeeze(-1)
           
        # just a tensor with batch indices
        batch_i = torch.arange(0, batch_size, dtype=torch.long)
       

        # record index
        record_i = topi.squeeze(-1)
       

        # use batch and record index to retrieve tensor representation for
        # the respectively predicted document plan record

        decoder_input = r_cs[batch_i, record_i, :].detach()  # detach from history as input
        
        decoder_input = decoder_input.unsqueeze(1)
     
        # use batch and record index to retrieve the token id vectors
        # for the respectively predicted document plan record

        idx_decoder_input = r_index[batch_i, record_i, :]
        
        idx_decoder_input = idx_decoder_input.unsqueeze(1)

       


        # collect all predictions and document plan elements
        #print("prediction: ", topi)
        idx_document_plan.append(idx_decoder_input)
        idx_plan_predictions.append(topi)
        document_plan.append(decoder_input)

    pointer_outputs = torch.cat(pointer_outputs, dim=1)
   
    idx_plan_predictions = torch.cat(idx_plan_predictions, dim=1)

    pointer_outputs = torch.nn.functional.log_softmax(pointer_outputs, dim=-1)
   
    attention_weights = torch.cat(attention_weights, dim=1)

    document_plan = torch.cat(document_plan, dim=1)
    
    idx_document_plan = torch.cat(idx_document_plan, dim=1)
   
    return document_plan, idx_document_plan, idx_plan_predictions, pointer_outputs, decoder_hidden_state, attention_weights


  def step(self, input_record, prev_state, r_cs):

    decoder_h, decoder_c = prev_state

    #print("STEP decoder h: ", decoder_h.shape)
    s_tneg1 = decoder_h.permute(1, 0, 2)

    #print("STEP r_cs: ", r_cs.shape)


    context, attn_weights = self.attention_layer(s_tneg1, r_cs)

    #print("STEP: context: ", context.shape)
    #print("STEP: input_record: ", input_record.shape)

    lstm_input = torch.cat((input_record, context), dim=2)


    decoder_output, (decoder_h, decoder_c) = self.lstm_layer(lstm_input, (decoder_h, decoder_c))



    return decoder_output, (decoder_h, decoder_c), attn_weights