How can I deal with sequences of different lengths in the decoder?

This is a pointer network model implemented by me.

In one mini-batch there are sequences of different lengths, and these sequences are padded to a max length.

How should the decoder handle these padded sequences? I don't know what effect the padding has on the decoder in the backward pass.

Suppose that at decoder time step t1 there are seq_1 and seq_2 in a mini-batch;
seq_1 already stopped at an earlier step, while seq_2 still needs to be computed.

I select only seq_2 to compute the loss and drop seq_1.
Is this right?
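
Concretely, the masking I have in mind looks like the sketch below (just a sketch; `lengths` is assumed to be a (BATCH,) tensor holding the true decode length of each sequence, it is not a variable in my code):

import torch
import torch.nn.functional as F

def masked_step_loss(logits, targets, lengths, t):
    # logits: (BATCH, ARRAY_LEN) attention scores at decode step t
    # targets: (BATCH,) ground-truth pointer indices at step t
    step_loss = F.cross_entropy(logits, targets, reduction='none')  # per-sequence loss
    active = (lengths > t).float()  # 1.0 while a sequence is still decoding, else 0.0
    # finished (padded) steps contribute zero loss and therefore zero gradient
    return (step_loss * active).sum() / active.sum().clamp(min=1)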

from typing import Tuple
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class LSTMEncoder(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=False):
        super(LSTMEncoder, self).__init__()

        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.num_directions = 2 if self.bidirectional else 1
        self.hidden_size = int(hidden_size / self.num_directions)

        self.rnn = nn.LSTM(input_size=embedding_dim, hidden_size=self.hidden_size, num_layers=num_layers,
                           batch_first=batch_first, bidirectional=bidirectional)  # nn.LSTM(512, 256, 3)

    def forward(self, embedded_inputs, input_lengths,
                max_len):  # embedded_inputs:(64, 25, 512), input_lengths:(64,), max_len:25
        # Pack padded batch of sequences for RNN module
        packed = nn.utils.rnn.pack_padded_sequence(embedded_inputs, input_lengths.view(-1).cpu(),
                                                   batch_first=self.batch_first,
                                                   enforce_sorted=False)  # packing records the valid lengths; positions beyond them never pass through the RNN
        # Forward pass through RNN
        outputs, hidden = self.rnn(packed)  # hidden:(6, 64, 256)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=self.batch_first)  # outputs:(64, 22, 512)
        # pad_packed_sequence only pads back to the max length within this batch,
        # so we keep padding up to the global max length
        extra_padding_size = max_len - outputs.shape[1]
        outputs = nn.functional.pad(outputs, [0, 0, 0, extra_padding_size, 0, 0], mode="constant", value=0)  # pad at the end of the time dimension
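        # (note: pad_packed_sequence also accepts a total_length= argument,
        # which would pad straight to the global max length in one call)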

        # Return output and final hidden state
        # if self.bidirectional:
        #     # Optionally, Sum bidirectional RNN outputs
        #     # outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        #     outputs = torch.cat((outputs[:, :, :self.hidden_size], outputs[:, :, self.hidden_size:]), dim=2)
        # batch_size = embedded_inputs.size(0)
        # h_n, c_n = hidden
        # h_n = h_n.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
        # c_n = c_n.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
        # if self.bidirectional:
        #     f = (h_n[-1, 0, :, :].squeeze(), c_n[-1, 0, :, :].squeeze())  # f: tuple:2, ((64, 256), (64, 256))
        #     b = (h_n[-1, 1, :, :].squeeze(), c_n[-1, 1, :, :].squeeze())  # b: tuple:2, ((64, 256), (64, 256))
        #     hidden = (torch.cat((f[0], b[0]), dim=1), torch.cat((f[1], b[1]), dim=1))  # tuple:2, ((64, 512), (64, 512))
        # else:
        #     hidden = (h_n[-1, 0, :, :].squeeze(), c_n[-1, 0, :, :].squeeze())

        return outputs, hidden
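

def _demo_encoder_padding():
    # Hypothetical smoke test (not part of the model): checks that rows past
    # each true length come back as zeros after packing plus extra padding.
    enc = LSTMEncoder(embedding_dim=8, hidden_size=16)
    batch = torch.rand(4, 10, 8)             # batch padded to length 10
    lengths = torch.tensor([10, 7, 5, 3])    # true lengths per sequence
    out, hidden = enc(batch, lengths, max_len=12)
    assert out.shape == (4, 12, 16)
    assert out[3, 3:].abs().sum().item() == 0.0  # padded steps stay zero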


class Attention(nn.Module):
    def __init__(self, hidden_size, units):
        super(Attention, self).__init__()
        self.W1 = nn.Linear(hidden_size, units, bias=False)
        self.W2 = nn.Linear(hidden_size, units, bias=False)
        self.V = nn.Linear(units, 1, bias=False)

    def forward(self,
                encoder_out: torch.Tensor,
                decoder_hidden: torch.Tensor):
        # encoder_out: (BATCH, ARRAY_LEN, HIDDEN_SIZE)
        # decoder_hidden: (BATCH, HIDDEN_SIZE)

        # Add time axis to decoder hidden state
        # in order to make operations compatible with encoder_out
        # decoder_hidden_time: (BATCH, 1, HIDDEN_SIZE)
        decoder_hidden_time = decoder_hidden.unsqueeze(1)

        # uj: (BATCH, ARRAY_LEN, ATTENTION_UNITS)
        # Note: we can add the two linear outputs thanks to broadcasting
        uj = self.W1(encoder_out) + self.W2(decoder_hidden_time)
        uj = torch.tanh(uj)

        # uj: (BATCH, ARRAY_LEN, 1)
        uj = self.V(uj)

        # Attention mask over inputs
        # aj: (BATCH, ARRAY_LEN, 1)
        # aj = F.softmax(uj, dim=1)
        #
        # # di_prime: attention-weighted context -> (BATCH, HIDDEN_SIZE)
        # di_prime = aj * encoder_out
        #
        # di_prime = di_prime.sum(1)

        # NOTE: with the weighted sum above commented out, both return values
        # are the raw attention logits of shape (BATCH, ARRAY_LEN)
        return uj.squeeze(-1), uj.squeeze(-1)

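
# Related to the padding question: a hypothetical helper (not used above) that
# masks padded encoder positions out of the attention logits before the softmax,
# assuming an input_lengths tensor is available at that point.
def mask_attention_logits(uj: torch.Tensor, input_lengths: torch.Tensor) -> torch.Tensor:
    # uj: (BATCH, ARRAY_LEN) raw logits, input_lengths: (BATCH,)
    positions = torch.arange(uj.size(1), device=uj.device)
    pad_mask = positions[None, :] >= input_lengths[:, None]  # True at padding
    # -inf logits receive zero weight after softmax
    return uj.masked_fill(pad_mask, float('-inf'))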


class RawDecoder(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 attention_units: int = 10):
        super(RawDecoder, self).__init__()
        self.lstm = nn.LSTM(hidden_size + 1, hidden_size, batch_first=True)
        self.attention = Attention(hidden_size, attention_units)

    def forward(self,
                x: torch.Tensor,
                hidden: Tuple[torch.Tensor],
                encoder_out: torch.Tensor):
        # x: (BATCH, 1, 1)
        # hidden: (1, BATCH, HIDDEN_SIZE)
        # encoder_out: (BATCH, ARRAY_LEN, HIDDEN_SIZE)

        # ht: hidden state from the last time step
        ht = hidden[0][0]  # ht: (BATCH, HIDDEN_SIZE), i.e. h_n

        # di, att_w: raw attention logits over inputs -> (BATCH, ARRAY_LEN)
        # (the attention-weighted context is commented out in Attention)
        di, att_w = self.attention(encoder_out, ht)

        # Append the attention output to our input
        # x: (BATCH, 1, 1 + ARRAY_LEN)
        x = torch.cat([di.unsqueeze(1), x], dim=2)

        # Generate the hidden state for next timestep
        _, hidden = self.lstm(x, hidden)
        return hidden, att_w


class AttDecoder(nn.Module):
    def __init__(self,
                 hidden_size: int, emb_size: int,
                 attention_units: int = 10):
        super(AttDecoder, self).__init__()
        self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)
        self.attention = Attention(hidden_size, attention_units)

    def forward(self,
                x: torch.Tensor,
                hidden: Tuple[torch.Tensor],
                encoder_out: torch.Tensor):
        _, hidden = self.lstm(x, hidden)
        ht = hidden[0][0]  # ht: (BATCH, HIDDEN_SIZE), i.e. h_n

        # di, att_w: raw attention logits over inputs -> (BATCH, ARRAY_LEN)
        di, att_w = self.attention(encoder_out, ht)

        return hidden, att_w


class PointerNetwork(nn.Module):
    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module, max_len: int = 25):
        super(PointerNetwork, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.max_len = max_len

    def forward(self,
                x: torch.Tensor,
                y: torch.Tensor, x_start: torch.Tensor,
                batch_len: torch.Tensor,
                teacher_force_ratio=.5):
        max_len = self.max_len  # plain int so F.pad in the encoder gets integer sizes
        encoder_in = x  # (BATCH, LEN, EMB), e.g. (20, 25, 64)
        # out: (BATCH, LEN, HIDDEN)
        out, hs = self.encoder(encoder_in, batch_len, max_len)
        loss = 0
        # outputs: (LEN, BATCH)
        outputs = torch.zeros(out.size(1), out.size(0), dtype=torch.long)
        # First decoder input is the start-of-sequence embedding
        # dec_in: (BATCH, 1, EMB)
        dec_in = x_start

        for t in range(out.size(1)):
            # Run one decoder step
            hs, att_w = self.decoder(dec_in, hs, out)
            # Predicted pointer at this step
            predictions = F.softmax(att_w, dim=1).argmax(1)

            # Pick the next index:
            # with teacher forcing the next element is the ground truth,
            # otherwise it is the value predicted at the current timestep
            teacher_force = random.random() < teacher_force_ratio
            idx = y[:, t] if teacher_force else predictions
            # Fetch the input element pointed at for the next step
            dec_in = torch.stack([x[b, idx[b].item()] for b in range(x.size(0))])
            dec_in = dec_in.view(out.size(0), 1, -1).type(torch.float)

            # Accumulate the per-step loss
            loss += F.cross_entropy(att_w, y[:, t])
            outputs[t] = predictions
        # Weight losses, so every element in the batch
        # has the same 'importance'
        batch_loss = loss / y.size(0)

        return outputs, batch_loss
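

# Side note: the per-batch index selection in the loop above could also be a
# torch.gather (same behavior, just avoids the Python loop); hypothetical helper:
def gather_next_inputs(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # x: (BATCH, ARRAY_LEN, EMB), idx: (BATCH,) pointer per sequence
    idx_exp = idx.view(-1, 1, 1).expand(-1, 1, x.size(2))  # (BATCH, 1, EMB)
    return x.gather(1, idx_exp).float()                    # (BATCH, 1, EMB)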


def run_pnn():
    '''
    MVP version of the pointer network.
    '''

    emb_dim = 8
    len_seq = 25  # keep equal to max_len so pointer indices in the smoke test stay in range
    hidden_size = 256
    attention_units = 256
    decoder = AttDecoder(hidden_size, emb_dim, attention_units=attention_units)
    encoder = LSTMEncoder(emb_dim, hidden_size)
    lr = 0.000001
    pnn = PointerNetwork(encoder, decoder)
    opt = torch.optim.Adam(pnn.parameters(), lr=lr)

    batch_size = 20
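    # get_raw_data / get_train_path / RawDataSet come from my own data pipeline
    # and are not shown here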
    train_data = get_raw_data(get_train_path())
    dataset = RawDataSet(train_data)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
    for idx, item in enumerate(dataloader):
        [x, y, seq_lens, starts] = item
        starts = starts.view(batch_size, 1, -1).contiguous()
        seq_lens = seq_lens.view(batch_size, 1).contiguous()
        pnn(x, y, starts, seq_lens)
    # Synthetic smoke test with random data
    x = torch.rand(batch_size, len_seq, emb_dim)
    y = torch.randint(len_seq, (batch_size, len_seq))
    start = torch.rand(batch_size, 1, emb_dim)
    batch_len = torch.full((batch_size, 1), len_seq, dtype=torch.long)
    pnn(x, y, start, batch_len)


if __name__ == '__main__':
    run_pnn()