This is a Pointer Network model I implemented. Within a mini-batch the sequences have different lengths and are padded to a common maximum length. How should these padded sequences be handled on the decoder side? I don't know what effect the padding has on the decoder during the backward pass.

Suppose that at decoder timestep t1 the mini-batch contains seq_1 and seq_2: seq_1 already stopped at the previous timestep, while seq_2 still needs to be decoded. I select only seq_2 when computing the loss and drop seq_1. Is this right?
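Roughly what I mean by "drop seq_1", as a sketch (the helper is hypothetical, not in my model yet; it assumes batch_len is a flat (BATCH,) tensor of valid decode lengths, and att_w / target match the variables in the code below):

import torch
import torch.nn.functional as F

def masked_step_loss(att_w, target, batch_len, t):
    # att_w: (BATCH, ARRAY_LEN) pointer logits at decoder step t
    # target: (BATCH,) ground-truth indices
    active = (batch_len > t).float()  # 1.0 for sequences still decoding at step t
    per_seq = F.cross_entropy(att_w, target, reduction='none')  # (BATCH,)
    # Finished sequences are multiplied by zero, so they contribute no gradient
    return (per_seq * active).sum() / active.sum().clamp(min=1.0)

My full model: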
from typing import Tuple
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
class LSTMEncoder(nn.Module):
def __init__(self, embedding_dim, hidden_size, num_layers=1, batch_first=True, bidirectional=False):
super(LSTMEncoder, self).__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.num_layers = num_layers
self.embedding_dim = embedding_dim
self.num_directions = 2 if self.bidirectional else 1
self.hidden_size = int(hidden_size / self.num_directions)
self.rnn = nn.LSTM(input_size=embedding_dim, hidden_size=self.hidden_size, num_layers=num_layers,
batch_first=batch_first, bidirectional=bidirectional) # nn.LSTM(512, 256, 3)
def forward(self, embedded_inputs, input_lengths,
max_len): # embedded_inputs:(64, 25, 512), input_lengths:(64,), max_len:25
        # Pack the padded batch so positions beyond each sequence's valid
        # length are masked out and never pass through the RNN
        packed = nn.utils.rnn.pack_padded_sequence(embedded_inputs, input_lengths.view(-1).cpu(),
                                                   batch_first=self.batch_first,
                                                   enforce_sorted=False)
# Forward pass through RNN
outputs, hidden = self.rnn(packed) # hidden:(6, 64, 256)
        # Unpack padding: pad_packed_sequence only pads back up to the longest
        # sequence in *this* batch, so pad further to the global max length
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=self.batch_first)  # outputs: (64, 22, 512)
        extra_padding_size = max_len - outputs.shape[1]
        outputs = nn.functional.pad(outputs, [0, 0, 0, extra_padding_size, 0, 0],
                                    mode="constant", value=0)  # pad at the end of the time dimension (dim 1)
# Return output and final hidden state
# if self.bidirectional:
# # Optionally, Sum bidirectional RNN outputs
# # outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
# outputs = torch.cat((outputs[:, :, :self.hidden_size], outputs[:, :, self.hidden_size:]), dim=2)
# batch_size = embedded_inputs.size(0)
# h_n, c_n = hidden
# h_n = h_n.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
# c_n = c_n.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
# if self.bidirectional:
# f = (h_n[-1, 0, :, :].squeeze(), c_n[-1, 0, :, :].squeeze()) # f: tuple:2, ((64, 256), (64, 256))
# b = (h_n[-1, 1, :, :].squeeze(), c_n[-1, 1, :, :].squeeze()) # b: tuple:2, ((64, 256), (64, 256))
# hidden = (torch.cat((f[0], b[0]), dim=1), torch.cat((f[1], b[1]), dim=1)) # tuple:2, ((64, 512), (64, 512))
# else:
# hidden = (h_n[-1, 0, :, :].squeeze(), c_n[-1, 0, :, :].squeeze())
return outputs, hidden
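# A quick sanity check (hypothetical sizes, not part of the model) that packing
# really masks the padding: encoder outputs beyond each sequence's valid length
# stay zero, so padded steps never enter the RNN's recurrence.
def _encoder_padding_check():
    enc = LSTMEncoder(embedding_dim=8, hidden_size=16)
    x = torch.rand(2, 5, 8)              # batch of 2, padded to length 5
    lengths = torch.tensor([5, 3])       # second sequence has only 3 valid steps
    out, _ = enc(x, lengths, max_len=5)
    print(out[1, 3:].abs().sum())        # sums to 0 -> padded steps are zero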
class Attention(nn.Module):
def __init__(self, hidden_size, units):
super(Attention, self).__init__()
self.W1 = nn.Linear(hidden_size, units, bias=False)
self.W2 = nn.Linear(hidden_size, units, bias=False)
self.V = nn.Linear(units, 1, bias=False)
def forward(self,
encoder_out: torch.Tensor,
decoder_hidden: torch.Tensor):
# encoder_out: (BATCH, ARRAY_LEN, HIDDEN_SIZE)
# decoder_hidden: (BATCH, HIDDEN_SIZE)
# Add time axis to decoder hidden state
# in order to make operations compatible with encoder_out
# decoder_hidden_time: (BATCH, 1, HIDDEN_SIZE)
decoder_hidden_time = decoder_hidden.unsqueeze(1)
        # uj: (BATCH, ARRAY_LEN, ATTENTION_UNITS)
        # Note: broadcasting lets us add the two linear outputs directly
        uj = self.W1(encoder_out) + self.W2(decoder_hidden_time)
uj = torch.tanh(uj)
# uj: (BATCH, ARRAY_LEN, 1)
uj = self.V(uj)
        # Attention weights over inputs
        # aj: (BATCH, ARRAY_LEN, 1)
        aj = F.softmax(uj, dim=1)
        # di_prime: attention-weighted context vector -> (BATCH, HIDDEN_SIZE)
        di_prime = (aj * encoder_out).sum(1)
        # Return the context vector and the raw pointer logits (BATCH, ARRAY_LEN)
        return di_prime, uj.squeeze(-1)
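# The debug shape noted while developing: pointer logits for a batch of 4 over
# 6 input positions come out as torch.Size([4, 6]). A minimal check with
# hypothetical sizes:
def _attention_shape_check():
    att = Attention(hidden_size=8, units=10)
    enc_out = torch.rand(4, 6, 8)        # (BATCH, ARRAY_LEN, HIDDEN_SIZE)
    dec_h = torch.rand(4, 8)             # (BATCH, HIDDEN_SIZE)
    context, logits = att(enc_out, dec_h)
    print(context.shape, logits.shape)   # torch.Size([4, 8]) torch.Size([4, 6])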
class RawDecoder(nn.Module):
def __init__(self,
hidden_size: int,
attention_units: int = 10):
super(RawDecoder, self).__init__()
self.lstm = nn.LSTM(hidden_size + 1, hidden_size, batch_first=True)
self.attention = Attention(hidden_size, attention_units)
def forward(self,
x: torch.Tensor,
                hidden: Tuple[torch.Tensor, torch.Tensor],
encoder_out: torch.Tensor):
# x: (BATCH, 1, 1)
# hidden: (1, BATCH, HIDDEN_SIZE)
# encoder_out: (BATCH, ARRAY_LEN, HIDDEN_SIZE)
        # ht: hidden state from the last timestep -> (BATCH, HIDDEN_SIZE)
        ht = hidden[0][0]
        # di: attention-weighted context vector -> (BATCH, HIDDEN_SIZE)
di, att_w = self.attention(encoder_out, ht)
# Append attention aware hidden state to our input
# x: (BATCH, 1, 1 + HIDDEN_SIZE)
x = torch.cat([di.unsqueeze(1), x], dim=2)
# Generate the hidden state for next timestep
_, hidden = self.lstm(x, hidden)
return hidden, att_w
class AttDecoder(nn.Module):
def __init__(self,
hidden_size: int, emb_size: int,
attention_units: int = 10):
super(AttDecoder, self).__init__()
self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)
self.attention = Attention(hidden_size, attention_units)
def forward(self,
x: torch.Tensor,
                hidden: Tuple[torch.Tensor, torch.Tensor],
encoder_out: torch.Tensor):
_, hidden = self.lstm(x, hidden)
ht = hidden[0][0] # ht: (BATCH, HIDDEN_SIZE) h_n
        # Only the pointer logits are needed here; the context vector is unused
        _, att_w = self.attention(encoder_out, ht)
        return hidden, att_w
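# Note on the two decoder variants: RawDecoder applies attention first and
# concatenates the context vector onto the LSTM input, while AttDecoder runs
# the LSTM first and uses attention only to produce the pointer logits.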
class PointerNetwork(nn.Module):
def __init__(self,
encoder: nn.Module,
decoder: nn.Module, max_len: int = 25):
super(PointerNetwork, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.max_len = max_len
def forward(self,
x: torch.Tensor,
y: torch.Tensor, x_start: torch.Tensor,
batch_len: torch.Tensor,
teacher_force_ratio=.5):
        # x: (BATCH, LEN, EMB) -> out: (BATCH, MAX_LEN, HIDDEN)
        out, hs = self.encoder(x, batch_len, self.max_len)
loss = 0
# Len,Batch
outputs = torch.zeros(out.size(1), out.size(0), dtype=torch.long)
        # The first decoder input is the start token
        # dec_in: (BATCH, 1, EMB)
        dec_in = x_start
for t in range(out.size(1)):
            # One decoding step: updated hidden state and pointer logits over the inputs
            hs, att_w = self.decoder(dec_in, hs, out)
            # Predicted index at this step (softmax is monotonic, so argmax is unchanged)
            predictions = F.softmax(att_w, dim=1).argmax(1)
            # Pick the next index: with teacher forcing the next element is the
            # ground truth, otherwise the value predicted at the current timestep
            teacher_force = random.random() < teacher_force_ratio
            idx = y[:, t] if teacher_force else predictions
            # Look up the selected input element to feed as the next decoder input
            dec_in = torch.stack([x[b, idx[b].item()] for b in range(x.size(0))])
            dec_in = dec_in.view(out.size(0), 1, -1).type(torch.float)
            # Step loss (F.cross_entropy already averages over the batch)
            loss += F.cross_entropy(att_w, y[:, t])
            outputs[t] = predictions
        # Average over decoding steps so every step has the same importance
        batch_loss = loss / out.size(1)
return outputs, batch_loss
def run_pnn():
    '''
    Minimal working example (MVP) of the pointer network.
    '''
emb_dim = 8
    len_seq = 25  # keep equal to PointerNetwork.max_len so pointed-to indices stay in range
hidden_size = 256
attention_units = 256
decoder = AttDecoder(hidden_size, emb_dim, attention_units=attention_units)
encoder = LSTMEncoder(emb_dim, hidden_size)
lr = 0.000001
pnn = PointerNetwork(encoder, decoder)
opt = torch.optim.Adam(pnn.parameters(), lr=lr)
batch_size = 20
    # get_raw_data / get_train_path / RawDataSet come from my own data pipeline
    train_data = get_raw_data(get_train_path())
    dataset = RawDataSet(train_data)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
    for idx, item in enumerate(dataloader):
        [x, y, seq_lens, starts] = item
        starts = starts.view(batch_size, 1, -1).contiguous()
        seq_lens = seq_lens.view(batch_size, 1).contiguous()
        opt.zero_grad()
        _, batch_loss = pnn(x, y, starts, seq_lens)
        batch_loss.backward()
        opt.step()
    # Smoke test with random data (all sequences at full length)
    x = torch.rand(batch_size, len_seq, emb_dim)
    y = torch.randint(len_seq, (batch_size, len_seq))
    start = torch.rand(batch_size, 1, emb_dim)
    batch_len = torch.full((batch_size, 1), len_seq, dtype=torch.long)
    pnn(x, y, start, batch_len)
if __name__ == '__main__':
run_pnn()