# Here is my model.
# coding: utf-8
import torch
import torch.nn as nn
class ScaleAndShift(nn.Module):
    """Learnable element-wise affine map ``exp(scale) * x + shift``.

    Both parameters have length ``vocab_size`` and start at zero, so the
    module begins as the identity. The scale is stored in log space, which
    keeps the multiplicative factor strictly positive.

    Args:
        vocab_size: length of the ``scale`` and ``shift`` vectors; ``x`` must
            broadcast against ``[vocab_size]`` (the caller in this file passes
            ``[batch_size, vocab_size]``).
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(vocab_size))
        self.shift = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, x):
        factor = self.scale.exp()
        scaled = factor * x
        return scaled + self.shift
class SelfModulatedLayerNormalization(nn.Module):
    """Layer normalisation whose gain and bias are predicted from a condition.

    ``inputs`` are normalised with ``nn.LayerNorm`` and then modulated as
    ``norm(inputs) * (1 + gamma(cond)) + beta(cond)``. ``gamma`` and ``beta``
    are each a two-layer MLP: ``Linear -> ReLU -> Linear``.

    Args:
        embed_dim: size of the last dimension of ``inputs``.
        num_hidden: ``cond`` must have ``num_hidden * 2`` features. The MLP
            hidden layer has ``num_hidden // 2`` units.
    """

    def __init__(self, embed_dim: int, num_hidden: int):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)
        # Multiplied by 2 because the encoder is a bidirectional LSTM.
        self.beta_dense_1 = nn.Linear(num_hidden * 2, num_hidden // 2)
        self.beta_dense_2 = nn.Linear(num_hidden // 2, embed_dim)
        self.gamma_dense_1 = nn.Linear(num_hidden * 2, num_hidden // 2)
        self.gamma_dense_2 = nn.Linear(num_hidden // 2, embed_dim)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Normalise ``inputs[0]`` and modulate it with ``inputs[1]``.

        Args:
            inputs: an ``(inputs, cond)`` pair. ``cond`` has one fewer axis
                per missing sequence dimension than ``inputs``, e.g.
                ``inputs`` ``[batch, len, embed_dim]`` with ``cond``
                ``[batch, num_hidden * 2]``.

        Returns:
            A tensor with the same shape as ``inputs``.
        """
        inputs, cond = inputs
        normed = self.layer_norm(inputs)
        # The ReLU between the projections is what makes each branch an MLP.
        # Without it, the two Linear layers compose into one affine map.
        beta = self.beta_dense_2(self.relu(self.beta_dense_1(cond)))
        gamma = self.gamma_dense_2(self.relu(self.gamma_dense_1(cond)))
        # Insert singleton axes after the batch axis so the per-example
        # modulation broadcasts over the extra (sequence) axes of `inputs`.
        for _ in range(inputs.dim() - cond.dim()):
            beta = beta.unsqueeze(1)
            gamma = gamma.unsqueeze(1)
        return normed * (gamma + 1) + beta
class BiDirectionalLSTM(nn.Module):
    """Batch-first bidirectional ``nn.LSTM`` that returns only per-step outputs.

    The output's last dimension is ``2 * hidden_size``: the forward and
    backward directions concatenated by ``nn.LSTM``. The final hidden and
    cell states are discarded.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); keep only `output`.
        return self.lstm(x)[0]
class MyMultiHeadAttention(nn.Module):
    """Batch-first wrapper around ``nn.MultiheadAttention``.

    ``q``, ``k`` and ``v`` each pass through their own ``nn.Linear``
    (``input_size -> num_heads * size_per_head``) before the attention call.
    ``nn.MultiheadAttention`` applies its own in-projections on top of these.

    Args:
        input_size: last-dimension size of ``q``, ``k`` and ``v``.
        num_heads: number of attention heads.
        size_per_head: per-head feature size. The output width is
            ``num_heads * size_per_head``.
    """

    def __init__(self, input_size, num_heads, size_per_head):
        super().__init__()
        self.q_dense = nn.Linear(input_size, size_per_head * num_heads)
        self.k_dense = nn.Linear(input_size, size_per_head * num_heads)
        self.v_dense = nn.Linear(input_size, size_per_head * num_heads)
        self.multi_head_attn = nn.MultiheadAttention(size_per_head * num_heads, num_heads)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: ``[batch_size, q_len, input_size]``.
            k: ``[batch_size, kv_len, input_size]``.
            v: ``[batch_size, kv_len, input_size]``.
            mask: optional key mask, shaped ``[batch_size, kv_len]`` or
                ``[batch_size, kv_len, 1]``. Non-zero (or ``True``) marks keys
                that may be attended to; zero marks padding. With ``None``
                (the default), every key is attended to. A row with no valid
                key yields NaN, as in ``nn.MultiheadAttention``.

        Returns:
            ``[batch_size, q_len, num_heads * size_per_head]``.
        """
        q = self.q_dense(q)
        k = self.k_dense(k)
        v = self.v_dense(v)

        key_padding_mask = None
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.squeeze(-1)
            # nn.MultiheadAttention expects True at the positions to IGNORE.
            key_padding_mask = mask == 0

        # This nn.MultiheadAttention is sequence-first: [len, batch, features].
        # need_weights=False: the attention weights were discarded anyway.
        outputs, _ = self.multi_head_attn(
            q.transpose(0, 1),
            k.transpose(0, 1),
            v.transpose(0, 1),
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )  # [q_len, batch_size, size_per_head * num_heads]
        return outputs.transpose(0, 1)
class Encoder(nn.Module):
    """Two stacked bidirectional LSTMs, each preceded by layer normalisation.

    ``forward`` returns the per-step states of the second LSTM together with
    their masked max-pool over the time axis (via ``seq_maxpool``).

    Args:
        embed_dim: feature size of the input embeddings.
        num_layers: ``num_layers`` passed to each ``BiDirectionalLSTM``.
        hidden_size: per-direction hidden size. Outputs have
            ``2 * hidden_size`` features.
    """

    def __init__(self, embed_dim: int, num_layers: int, hidden_size: int):
        super().__init__()
        self.lstm_1 = BiDirectionalLSTM(
            input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers
        )
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.lstm_2 = BiDirectionalLSTM(
            input_size=hidden_size * 2, hidden_size=hidden_size, num_layers=num_layers
        )
        self.layer_norm_2 = nn.LayerNorm(hidden_size * 2)

    def forward(self, x, x_mask):
        """Encode ``x``.

        Args:
            x: ``[batch_size, max_len, embed_size]``.
            x_mask: ``[batch_size, max_len, 1]``.

        Returns:
            ``(states, pooled)`` with shapes
            ``[batch_size, max_len, 2 * hidden_size]`` and
            ``[batch_size, 2 * hidden_size]``.
        """
        # NOTE(review): x_mask is used only for pooling. Both LSTMs also run
        # over padded positions, including the backward direction.
        states = self.lstm_1(self.layer_norm_1(x))
        states = self.lstm_2(self.layer_norm_2(states))
        # NOTE(review): `states` is returned as-is after being handed to
        # seq_maxpool. Confirm that helper does not modify its input in place.
        pooled = seq_maxpool([states, x_mask])
        return states, pooled
class Decoder(nn.Module):
    """Two stacked unidirectional LSTMs conditioned on a pooled encoder vector.

    A ``SelfModulatedLayerNormalization`` is applied before each LSTM and
    after the last one. Each takes its gain and bias from ``x_max``.

    Args:
        embed_dim: feature size of the target embeddings.
        hidden_size: hidden size of both LSTMs (and of the output).
        num_layers: ``num_layers`` passed to each ``nn.LSTM``.
    """

    def __init__(self, embed_dim: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.lstm_1 = nn.LSTM(
            embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=False,
            batch_first=True,
        )
        # Divided by 2 so the conditioning width matches the encoder's hidden
        # size; the modulation layer doubles num_hidden again.
        self.layer_norm_1 = SelfModulatedLayerNormalization(embed_dim, hidden_size // 2)
        self.lstm_2 = nn.LSTM(
            hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=False,
            batch_first=True,
        )
        self.layer_norm_2 = SelfModulatedLayerNormalization(hidden_size, hidden_size // 2)
        self.layer_norm_3 = SelfModulatedLayerNormalization(hidden_size, hidden_size // 2)

    def forward(self, y, x_max):
        """Decode ``y`` conditioned on ``x_max``.

        Args:
            y: ``[batch_size, max_len, embed_size]``.
            x_max: ``[batch_size, num_direction * hidden_size]``, where
                ``hidden_size`` is the encoder's.

        Returns:
            ``[batch_size, max_len, hidden_size]``.
        """
        states = self.layer_norm_1([y, x_max])
        states, _ = self.lstm_1(states)
        states = self.layer_norm_2([states, x_max])
        states, _ = self.lstm_2(states)
        return self.layer_norm_3([states, x_max])
class Seq2Seq(nn.Module):
    """Encoder/decoder model with attention and a source-token prior.

    The decoder states attend over the encoder states. The concatenation
    passes through ``Linear -> LeakyReLU(0.2) -> Linear`` to produce
    per-step vocabulary scores. These are averaged with a learned affine
    transform (``ScaleAndShift``) of ``x_one_hot``.

    Args:
        vocab_size: vocabulary size (shared by source and target embeddings).
        embed_dim: embedding size.
        num_layers: LSTM layers per encoder/decoder stage.
        hidden_size: model width. The encoder uses ``hidden_size // 2`` per
            direction.
        num_heads: attention heads.
        size_per_head: per-head attention width.
    """

    def __init__(self, vocab_size: int, embed_dim: int, num_layers: int, hidden_size: int,
                 num_heads: int, size_per_head: int):
        super().__init__()
        self.embed_x = nn.Embedding(vocab_size, embedding_dim=embed_dim)
        self.embed_y = nn.Embedding(vocab_size, embedding_dim=embed_dim)
        # Halved because the encoder is bidirectional (two directions concatenated).
        self.encoder = Encoder(embed_dim, num_layers, hidden_size // 2)
        self.decoder = Decoder(embed_dim, hidden_size, num_layers)
        self.attn = MyMultiHeadAttention(hidden_size, num_heads, size_per_head)
        self.dense_1 = nn.Linear(size_per_head * num_heads + hidden_size, embed_dim)
        self.relu = nn.LeakyReLU(0.2)
        self.dense_2 = nn.Linear(embed_dim, vocab_size)
        self.scale_shift = ScaleAndShift(vocab_size)

    def forward(self, x, x_mask, x_one_hot, y):
        """Score every target position over the vocabulary.

        Args:
            x: ``[batch_size, x_max_len]`` source token ids.
            x_mask: ``[batch_size, x_max_len, 1]``.
            x_one_hot: ``[batch_size, vocab_size]``.
            y: ``[batch_size, y_max_len]`` target token ids.

        Returns:
            ``[batch_size, y_max_len, vocab_size]`` unnormalised scores
            (presumably fed to a softmax / cross-entropy by the caller;
            TODO confirm).
        """
        src = self.embed_x(x)
        tgt = self.embed_y(y)
        enc_states, enc_pooled = self.encoder(src, x_mask)  # [B, x_len, hidden], [B, hidden]
        dec_states = self.decoder(tgt, enc_pooled)  # [B, y_len, hidden]
        # NOTE(review): x_mask is not passed to the attention, so decoder
        # states can attend to padded encoder positions.
        context = self.attn(dec_states, enc_states, enc_states)  # [B, y_len, heads * size]
        features = torch.cat((dec_states, context), dim=-1)
        logits = self.dense_2(self.relu(self.dense_1(features)))  # [B, y_len, vocab]
        x_prior = self.scale_shift(x_one_hot)  # [B, vocab]
        return (logits + x_prior.unsqueeze(1)) / 2
def seq_maxpool(x):
    """Masked max-pool over axis 1 (the time axis).

    Args:
        x: a ``[seq, mask]`` pair. ``mask`` must broadcast against ``seq``
            (e.g. ``seq`` ``[batch, len, dim]`` with ``mask``
            ``[batch, len, 1]``). It holds 1/``True`` for valid positions and
            0/``False`` for padding.

    Returns:
        The maximum of ``seq`` over axis 1 after padded positions are lowered
        by ``1e10``. Padding therefore never wins unless every position in a
        row is padding.

    ``seq`` itself is left unmodified.
    """
    seq, mask = x
    # Out-of-place on purpose. An in-place `seq -= ...` would overwrite the
    # caller's tensor (Encoder returns the very tensor it pools). It also
    # fails on leaf tensors that require grad and can trip autograd's
    # in-place checks.
    penalty = (1 - mask.to(seq.dtype)) * 1e10
    return torch.max(seq - penalty, 1).values
def seq_avgpool(x):
    """Masked mean over axis 1 (the time axis).

    Args:
        x: a ``[seq, mask]`` pair. ``mask`` broadcasts against ``seq`` and
            weights each position (1 = keep, 0 = padding).

    Returns:
        ``sum(seq * mask) / (sum(mask) + 1e-6)`` along axis 1. The epsilon
        avoids division by zero for all-padding rows.
    """
    seq, mask = x
    total = (seq * mask).sum(dim=1)
    count = mask.sum(dim=1) + 1e-6
    return total / count