Hi there,
So I followed this tutorial to implement the transformer architecture from the “Attention Is All You Need” paper. I had to change the code in the tutorial a bit as it had some mistakes. I am using this model for a Neural Machine Translation task but my loss isn’t decreasing and is always staying within the range of 5 - 5.7. My input and target tensors are in the form of (batch_size, seq_len). Here is an example of a toy dataset:
src = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]])
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]])
where 1 & 2 are the <sos>
& <eos>
tokens respectively and 0 is the <pad>
token.
Here is my code for the transformer:
class AttentionHead(nn.Module):
def __init__(self, emb_dim, dim_kqv):
super(AttentionHead, self).__init__()
self.dim_kqv = dim_kqv
self.wq = nn.Linear(emb_dim, dim_kqv)
self.wk = nn.Linear(emb_dim, dim_kqv)
self.wv = nn.Linear(emb_dim, dim_kqv)
def forward(self, q, k, v, mask):
queries = self.wq(q)
keys = self.wk(k)
values = self.wv(v)
score = queries.bmm(keys.transpose(1, 2))
score = torch.div(score, self.dim_kqv ** 0.5, rounding_mode='floor')
if mask is not None:
score = score.masked_fill(mask == 0, -1e9)
softmax = F.softmax(score, dim = -1)
return softmax.bmm(values)
class MultiHeadAttention(nn.Module):
def __init__(self, num_heads, emb_dim, dim_kqv):
super(MultiHeadAttention, self).__init__()
self.heads = nn.ModuleList(
[AttentionHead(emb_dim, dim_kqv) for _ in range(num_heads)]
)
self.w0 = nn.Linear(num_heads * dim_kqv, emb_dim)
def forward(self, q, k, v, mask):
attentions = [h(q, k, v, mask) for h in self.heads]
attentions = torch.cat(attentions, dim = -1)
out = self.w0(attentions)
return out
class Residual(nn.Module):
def __init__(self, sublayer, dimension, dropout):
super(Residual, self).__init__()
self.sublayer = sublayer
self.norm = nn.LayerNorm(dimension)
self.dropout = nn.Dropout(dropout)
def forward(self, *tensors):
return self.dropout(self.norm(tensors[0] + self.sublayer(*tensors)))
class FeedForward(nn.Module):
def __init__(self, emb_dim, ff_dim):
super(FeedForward, self).__init__()
self.network = nn.Sequential(
nn.Linear(emb_dim, ff_dim),
nn.ReLU(),
nn.Linear(ff_dim, emb_dim)
)
def forward(self, residual_out):
return self.network(residual_out)
class EncoderLayer(nn.Module):
def __init__(self, emb_dim, num_heads, ff_dim, dropout):
super(EncoderLayer, self).__init__()
self.dim_kqv = emb_dim // num_heads
assert (self.dim_kqv * num_heads == emb_dim), "Embedding size must be divisible by number of heads"
self.attention = Residual(
MultiHeadAttention(num_heads, emb_dim, self.dim_kqv),
dimension=emb_dim,
dropout=dropout,
)
self.feed_forward = Residual(
FeedForward(emb_dim, ff_dim),
dimension=emb_dim,
dropout=dropout,
)
def forward(self, src, mask):
src = self.attention(src, src, src, mask)
out = self.feed_forward(src)
return out
class Encoder(nn.Module):
def __init__(self,
emb_dim,
num_heads,
ff_dim,
num_layers,
src_vocab_size,
padding_index,
dropout):
super(Encoder, self).__init__()
self.layers = nn.ModuleList([
EncoderLayer(emb_dim,
num_heads,
ff_dim,
dropout)
for _ in range(num_layers)
])
self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=0)
self.pe = PositionalEncoder(emb_dim)
def forward(self, src):
src = self.embedding(src)
src = self.pe(src)
for layer in self.layers:
src = layer(src, None)
return src
class DecoderLayer(nn.Module):
def __init__(self, emb_dim, num_heads, ff_dim, dropout):
super(DecoderLayer, self).__init__()
self.dim_kqv = emb_dim // num_heads
assert (self.dim_kqv * num_heads == emb_dim), "Embedding size must be divisible by number of heads"
self.attention_1 = Residual(
MultiHeadAttention(num_heads, emb_dim, self.dim_kqv),
dimension=emb_dim,
dropout=dropout
)
self.attention_2 = Residual(
MultiHeadAttention(num_heads, emb_dim, self.dim_kqv),
dimension=emb_dim,
dropout=dropout
)
self.feed_forward = Residual(
FeedForward(emb_dim, ff_dim),
dimension=emb_dim,
dropout=dropout
)
def forward(self, trg, memory, mask):
query = self.attention_1(trg, trg, trg, mask)
attentions = self.attention_2(query, memory, memory, None)
out = self.feed_forward(attentions)
return out
class Decoder(nn.Module):
def __init__(self,
emb_dim,
num_heads,
ff_dim,
num_layers,
out_size,
padding_index,
dropout):
super(Decoder, self).__init__()
self.layers = nn.ModuleList([
DecoderLayer(emb_dim, num_heads, ff_dim, dropout)
for _ in range(num_layers)
])
self.embedding = nn.Embedding(out_size, emb_dim, padding_idx=padding_index)
self.pe = PositionalEncoder(emb_dim)
def make_trg_mask(self, trg):
batch_size, seq_len = trg.shape[0], trg.shape[1]
mask = torch.tril(torch.ones(batch_size, seq_len, seq_len))
return mask
def forward(self, trg, encoder_out):
trg = self.embedding(trg)
trg = self.pe(trg)
mask = self.make_trg_mask(trg).to(trg.get_device())
for layer in self.layers:
trg = layer(trg, encoder_out, mask)
# return self.lin(trg)
return trg
class VanillaTransformer(nn.Module):
def __init__(self,
emb_dim,
num_heads,
ff_dim,
num_layers,
src_vocab_size,
trg_vocab_size,
device,
padding_index,
dropout):
super(VanillaTransformer, self).__init__()
self.encoder = Encoder(emb_dim,
num_heads,
ff_dim,
num_layers,
src_vocab_size,
padding_index,
dropout).to(device)
self.decoder = Decoder(emb_dim,
num_heads,
ff_dim,
num_layers,
trg_vocab_size,
padding_index,
dropout).to(device)
self.lin = nn.Linear(emb_dim, trg_vocab_size)
def forward(self, src, trg):
encoder_out = self.encoder(src)
decoder_out = self.decoder(trg, encoder_out)
out = self.lin(decoder_out)
return out
And here is the training loop:
model = VanillaTransformer(EMBEDDING_DIM = 256,
NUM_HEADS = 8,
FF_DIM = 2048,
NUM_LAYERS = 6,
SOURCE_VOCAB_SIZE = 10,
TARGET_VOCAB_SIZE=10,
device,
padding_index = 0,
DROPOUT= 0.1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss(ignore_index=padding_index)
for epoch in tqdm(range(NUM_EPOCHS), position = 0, leave = True):
running_loss = 0.0
model.train()
for batch_index, batch in enumerate(train_loader):
optimizer.zero_grad()
src = batch['x_source']
trg = batch['y_target']
y_pred = model(src.to(device), trg[:, :-1].to(device))
y_pred = y_pred.reshape(-1, y_pred.shape[2])
loss = loss_func(y_pred, trg[:, 1:].reshape(-1).to(device))
loss.backward()
running_loss += (loss.item() - running_loss) / (batch_index + 1)
optimizer.step()
if epoch == 0 or (epoch + 1) % PRINT_EVERY == 0:
print('Epoch: {:<2} Train loss: {:0.4f}'.format(epoch + 1 , running_loss))
I would also appreciate it if you could let me know if the way I am sending the input to the decoder and how its being compared to targets is correct. I understand that the amount of code here could be overwhelming therefore, I would suggest first reading the article in the link above. Its a great read and everything there is clearly explained and it will make understanding my code much easier.
Been at this for 2 days now so any help is appreciated. Thanks