I have a Encoder-Decoder model
import torch
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.):
super(Encoder, self).__init__()
self.num_layers = num_layers
self.rnn = nn.GRU(input_size, hidden_size, num_layers,
batch_first=True, bidirectional=True, dropout=dropout)
def forward(self, x, mask, lengths):
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
output, final = self.rnn(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
fwd_final = final[0:final.size(0):2]
bwd_final = final[1:final.size(0):2]
final = torch.cat([fwd_final, bwd_final], dim=2)
return output, final
class BahdanauAttention(nn.Module):
def __init__(self, hidden_size, key_size=None, query_size=None):
super(BahdanauAttention, self).__init__()
key_size = 2 * hidden_size if key_size is None else key_size
query_size = hidden_size if query_size is None else query_size
self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
self.energy_layer = nn.Linear(hidden_size, 1, bias=False)
def forward(self, query=None, proj_key=None, value=None, mask=None):
query = self.query_layer(query)
scores = self.energy_layer(torch.tanh(query + proj_key))
scores = scores.squeeze(2).unsqueeze(1)
scores.data.masked_fill_(mask == 0, -float('inf'))
alphas = F.softmax(scores, dim=-1)
context = torch.bmm(alphas, value)
return context, alphas
class Generator(nn.Module):
"""Define standard linear + softmax generation step."""
def __init__(self, hidden_size, vocab_size):
super(Generator, self).__init__()
self.proj = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(self, x):
return F.log_softmax(self.proj(x), dim=-1)
class Decoder(nn.Module):
def __init__(self, emb_size, hidden_size, attention, num_layers=1, dropout=0.5, bridge=True):
super(Decoder, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.attention = attention
self.dropout = dropout
self.rnn = nn.GRU(emb_size + 2*hidden_size, hidden_size, num_layers,
batch_first=True, dropout=dropout)
self.bridge = nn.Linear(2*hidden_size, hidden_size, bias=True) if bridge else None
self.dropout_layer = nn.Dropout(p=dropout)
self.pre_output_layer = nn.Linear(hidden_size + 2*hidden_size + emb_size,
hidden_size, bias=False)
def forward_step(self, prev_embed, encoder_hidden, src_mask, proj_key, hidden):
query = hidden[-1].unsqueeze(1)
context, attn_probs = self.attention(
query=query, proj_key=proj_key,
value=encoder_hidden, mask=src_mask)
rnn_input = torch.cat([prev_embed, context], dim=2)
output, hidden = self.rnn(rnn_input, hidden)
pre_output = torch.cat([prev_embed, output, context], dim=2)
pre_output = self.dropout_layer(pre_output)
pre_output = self.pre_output_layer(pre_output)
return output, hidden, pre_output
def forward(self, trg_embed, encoder_hidden, encoder_final,
src_mask, trg_mask, hidden=None, max_len=None):
if max_len is None:
max_len = trg_mask.size(-1)
if hidden is None:
hidden = self.init_hidden(encoder_final)
proj_key = self.attention.key_layer(encoder_hidden)
decoder_states = []
pre_output_vectors = []
for i in range(max_len):
prev_embed = trg_embed[:, i].unsqueeze(1)
output, hidden, pre_output = self.forward_step(
prev_embed, encoder_hidden, src_mask, proj_key, hidden)
decoder_states.append(output)
pre_output_vectors.append(pre_output)
decoder_states = torch.cat(decoder_states, dim=1)
pre_output_vectors = torch.cat(pre_output_vectors, dim=1)
return decoder_states, hidden, pre_output_vectors
def init_hidden(self, encoder_final):
if encoder_final is None:
return None
return torch.tanh(self.bridge(encoder_final))
class EncoderDecoder(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
super(EncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.generator = generator
def forward(self, src, trg, src_mask, trg_mask, src_lengths, trg_lengths):
encoder_hidden, encoder_final = self.encode(src, src_mask, src_lengths)
return self.decode(encoder_hidden, encoder_final, src_mask, trg, trg_mask)
def encode(self, src, src_mask, src_lengths):
return self.encoder(self.src_embed(src), src_mask, src_lengths)
def decode(self, encoder_hidden, encoder_final, src_mask, trg, trg_mask,
decoder_hidden=None):
return self.decoder(self.tgt_embed(trg), encoder_hidden, encoder_final,
src_mask, trg_mask, hidden=decoder_hidden)
def make_model(src_vocab, tgt_vocab, emb_size=256, hidden_size=512, num_layers=1, dropout=0.1):
"Helper: Construct a model from hyperparameters."
attention = BahdanauAttention(hidden_size)
model = EncoderDecoder(
Encoder(emb_size, hidden_size, num_layers=num_layers, dropout=dropout),
Decoder(emb_size, hidden_size, attention, num_layers=num_layers, dropout=dropout),
nn.Embedding(src_vocab, emb_size),
nn.Embedding(tgt_vocab, emb_size),
Generator(hidden_size, tgt_vocab))
return model.cuda() if torch.cuda.is_available() else model
data preparing:
import os
from pathlib import Path
from collections import Counter
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
def build_vocab(file, size):
with open(file, 'r') as f:
words = f.read().split()
# Count the frequency of each word
word_counts = Counter(words)
# Get the most common words
common_words = word_counts.most_common(size)
# Create the vocab
vocab = {word: i+4 for i, (word, _) in enumerate(common_words)}
# Add special tokens
vocab['<pad>'] = 0
vocab['<unk>'] = 1
vocab['<start>'] = 2
vocab['<eos>'] = 3
return vocab
class CommitMessageDataset(Dataset):
def __init__(self, diff_file, msg_file, diff_vocab, msg_vocab):
self.diff_file = diff_file
self.msg_file = msg_file
self.diff_vocab = diff_vocab
self.msg_vocab = msg_vocab
# Read the files
with open(diff_file, 'r') as f:
self.diffs = f.readlines()
with open(msg_file, 'r') as f:
self.msgs = f.readlines()
def __len__(self):
return len(self.diffs)
def __getitem__(self, idx):
diff = ['<start>'] + self.diffs[idx].strip().split() + ['<eos>']
msg = ['<start>'] + self.msgs[idx].strip().split() + ['<eos>']
# Convert words to indices
diff = [self.diff_vocab[word] if word in self.diff_vocab else self.diff_vocab['<unk>'] for word in diff]
msg = [self.msg_vocab[word] if word in self.msg_vocab else self.msg_vocab['<unk>'] for word in msg]
return torch.tensor(diff), torch.tensor(msg)
def collate_fn(batch):
diffs, msgs = zip(*batch)
diffs = pad_sequence(diffs, batch_first=True, padding_value=0)
msgs = pad_sequence(msgs, batch_first=True, padding_value=0)
return diffs, msgs
current_path = Path(__file__).parent.absolute()
# Build the vocabularies
msg_vocab = build_vocab(os.path.join(current_path, 'data/cleaned.train.msg'), 16000)
diff_vocab = build_vocab(os.path.join(current_path, 'data/cleaned.train.diff'), 50000)
msg_vocab_itos = dict({(v,k) for k,v in msg_vocab.items()})
diff_vocab_itos = dict({(v,k) for k,v in diff_vocab.items()})
# Create the dataset and dataloader
train_dataset = CommitMessageDataset(os.path.join(current_path, 'data/cleaned.train.diff'), os.path.join(current_path, 'data/cleaned.train.msg'), diff_vocab, msg_vocab)
valid_dataset = CommitMessageDataset(os.path.join(current_path, 'data/cleaned.valid.diff'), os.path.join(current_path, 'data/cleaned.valid.msg'), diff_vocab, msg_vocab)
test_dataset = CommitMessageDataset(os.path.join(current_path, 'data/cleaned.test.diff'), os.path.join(current_path, 'data/cleaned.test.msg'), diff_vocab, msg_vocab)
train_dataloader = DataLoader(train_dataset, batch_size=80, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=80, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=80, shuffle=False, collate_fn=collate_fn)
And the training loop:
from model import Encoder, Decoder, BahdanauAttention, make_model
from dataset_buider2 import msg_vocab, diff_vocab, train_dataloader, test_dataloader, valid_dataloader, diff_vocab_itos, msg_vocab_itos
from torch.optim import Adadelta
from torchtext.data.metrics import bleu_score
import numpy as np
import torch
import os
import logging
from pathlib import Path
current_path = Path.cwd()
logging.basicConfig(filename='log.txt',
filemode='a',
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
datefmt='%H:%M:%S',
level=logging.DEBUG)
model_dir = os.path.join(current_path, 'models')
saved_models = [os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith('.pth')]
print(saved_models)
model = make_model(len(diff_vocab), len(msg_vocab), emb_size=512, hidden_size=1024, num_layers=6, dropout=0.1)
criterion = torch.nn.CrossEntropyLoss(ignore_index=diff_vocab['<pad>']) # игнорирование <pad>
optimizer = Adadelta(model.parameters())
model_dir = os.path.join(current_path, 'models')
if saved_models:
latest_model = max(saved_models, key=os.path.getctime)
print(f"Loading weights from {latest_model}...")
checkpoint = torch.load(os.path.join(model_dir, latest_model))
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
else:
print("No saved models found, initializing a new model...")
start_epoch = 0
num_epochs = 5000
print_every = 50
best_bleu = 0
batch_count = 0
if torch.cuda.is_available():
model = model.to(device='cuda')
for epoch in range(start_epoch, num_epochs):
model.train()
total_loss = 0
for i, (diffs, msgs) in enumerate(train_dataloader):
if torch.cuda.is_available():
diffs = diffs.to(device='cuda')
msgs = msgs.to(device='cuda')
batch_count += 1
src_mask = (diffs != diff_vocab['<pad>']).unsqueeze(-2)
trg_mask = (msgs != msg_vocab['<pad>']).unsqueeze(-2)
src_lengths = torch.sum(src_mask.squeeze(-2), dim=-1).cpu()
trg_lengths = torch.sum(trg_mask.squeeze(-2), dim=-1).cpu()
lengths, indices = torch.sort(src_lengths, descending=True)
diffs = diffs[indices]
msgs = msgs[indices]
src_mask = src_mask[indices]
trg_mask = trg_mask[indices]
src_lengths = src_lengths[indices]
trg_lengths = trg_lengths[indices]
outputs, _, _ = model(diffs, msgs, src_mask, trg_mask, src_lengths, trg_lengths)
outputs = model.generator(outputs)
loss = criterion(outputs.view(-1, len(msg_vocab)), msgs.view(-1))
total_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i + 1) % print_every == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {total_loss/print_every}')
total_loss = 0
if batch_count % 2_000 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, os.path.join(model_dir, f'model_{epoch}_{i}.pth'))
model.eval()
with torch.inference_mode():
candidate_corpus = []
reference_corpus = []
for diffs, msgs in valid_dataloader:
if torch.cuda.is_available():
diffs = diffs.to(device='cuda')
msgs = msgs.to(device='cuda')
src_mask = (diffs != diff_vocab['<pad>']).unsqueeze(-2)
trg_mask = (msgs != msg_vocab['<pad>']).unsqueeze(-2)
src_lengths = torch.sum(src_mask.squeeze(-2), dim=-1).cpu()
trg_lengths = torch.sum(trg_mask.squeeze(-2), dim=-1).cpu()
lengths, indices = torch.sort(src_lengths, descending=True)
diffs = diffs[indices]
msgs = msgs[indices]
src_mask = src_mask[indices]
trg_mask = trg_mask[indices]
src_lengths = src_lengths[indices]
trg_lengths = trg_lengths[indices]
outputs, _, _ = model(diffs, msgs, src_mask, trg_mask, src_lengths, trg_lengths)
outputs = outputs.argmax(dim=-1).cpu().numpy().tolist()
for output, msg in zip(outputs, msgs):
output = output[1:output.index(msg_vocab['<eos>']) if msg_vocab['<eos>'] in output else None]
msg = msg[1:msg.tolist().index(msg_vocab['<eos>']) if msg_vocab['<eos>'] in msg.tolist() else None]
output = [msg_vocab_itos[i] for i in output]
msg = [msg_vocab_itos[i] for i in msg.tolist()]
candidate_corpus.append(output)
reference_corpus.append([msg])
logging.info(f"""
Source message: {" ".join(msg)}\n
Generated message: {" ".join(output)}\n
{len(output)}
""")
bleu = bleu_score(candidate_corpus, reference_corpus)
if bleu > best_bleu:
best_bleu = bleu
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, os.path.join(model_dir, f'best_model.pth'))
print(f'Epoch [{epoch+1}/{num_epochs}], Bleu score: {bleu*100:.2f}%')
However,the BLEU at the first batch decreases to zero and stay at this level during all the training. At the same time loss is drastically decreasing. What may be the problem?