I have an NMT model trained as shown in the article below:
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
I’m trying to fine-tune this model on the BLEU metric (it was previously trained using BCE loss) and have defined the training functions below:
Main Function:
def fine_tune(input_tensor, target_tensor, encoder, decoder, encoder_optimizer,
              decoder_optimizer, criterion, pair, max_length=MAX_LENGTH):
    encoder_optimizer.zero_grad()  # Zero out the gradients for the encoder update.
    decoder_optimizer.zero_grad()  # Zero out the gradients for the decoder update.

    true_value = pair[1].split()
    prediction_sentence = evaluate(encoder, decoder, pair[0])
    prediction_sentence = prediction_sentence[0][:-1]  # evaluate() returns (words, attentions); drop the trailing '<EOS>'.

    # BLEU as the reward, scaled to [0, 100] and pushed towards a target of 100 via MSE.
    reward = sentence_bleu([true_value], prediction_sentence, smoothing_function=smoothie)
    reward = torch.tensor(reward, dtype=torch.float)
    reward_n = reward * 100
    target = torch.tensor(100.0)
    loss = criterion(reward_n, target)
    loss.requires_grad = True  # Without this, backward() errors out because loss does not require grad.
    print(loss)
    loss.backward()

    # Inspect the decoder gradients after the backward pass.
    for name, param in decoder.named_parameters():
        if 'weight' in name:
            print(param.grad)

    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item() / len(true_value)
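(sentence_bleu and smoothie are not shown above: they come from NLTK, with smoothie being one of the standard smoothing methods; method4 below is just the one I happen to use.)

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smoothie = SmoothingFunction().method4  # any of the standard smoothing methods works here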
Call:
def Fine_tune_trainIters(encoder, decoder, n_iters, print_every=1000,
                         plot_every=100, learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0   # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # Unlike the tutorial, I keep the raw string pairs (instead of tensorsFromPair),
    # since evaluate() expects a sentence.
    training_pairs = [random.choice(train_pairs) for i in range(n_iters)]
    criterion = nn.MSELoss()

    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = fine_tune(input_tensor, target_tensor, encoder, decoder,
                         encoder_optimizer, decoder_optimizer, criterion,
                         training_pair)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100,
                                         print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    # showPlot(plot_losses)
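I launch fine-tuning with, for example:

Fine_tune_trainIters(encoder, decoder, n_iters=10000,
                     print_every=1000, plot_every=100)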
Prediction Function:
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    input_tensor = tensorFromSentence(input_lang, sentence)
    input_length = input_tensor.size()[0]
    encoder_hidden = encoder.initHidden()

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                 encoder_hidden)
        encoder_outputs[ei] += encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
    decoder_hidden = encoder_hidden
    decoded_words = []
    decoder_attentions = torch.zeros(max_length, max_length)

    # Greedy decoding: take the argmax token at every step.
    for di in range(max_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        decoder_attentions[di] = decoder_attention.data
        topv, topi = decoder_output.data.topk(1)
        if topi.item() == EOS_token:
            decoded_words.append('<EOS>')
            break
        else:
            decoded_words.append(output_lang.index2word[topi.item()])
        decoder_input = topi.squeeze().detach()

    return decoded_words, decoder_attentions[:di + 1]
However, when I print the gradients, they all come out as zero, even though the loss itself is non-zero.
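My suspicion is that the graph is already broken before backward() runs: the BLEU reward is computed from decoded strings, so torch.tensor(reward) is a fresh leaf with no history, and setting loss.requires_grad = True afterwards only marks that leaf as a new graph root; it doesn’t reconnect it to the model. A minimal, model-free snippet that shows the same behaviour:

import torch

w = torch.randn(3, requires_grad=True)
y = (w * 2).sum()

loss = torch.tensor(y.item())  # .item() returns a Python float: the graph is gone
loss.requires_grad = True      # makes loss a new leaf; does not reconnect it to w
loss.backward()                # only sets loss.grad

print(w.grad)  # None (or all zeros if a previous step had populated and zeroed it)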
Am I missing something here?
Additionally, I wanted to ask: is it possible to fine-tune a model on an arbitrary score such as BLEU? If so, is this the right way to do it?
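From what I’ve read, non-differentiable metrics like BLEU are usually optimised with policy-gradient methods such as REINFORCE: sample tokens during decoding, keep each sampled token’s log-probability attached to the graph, and weight their sum by the reward. Is something like the sketch below (hypothetical; adapted from my evaluate() loop, and assuming the tutorial’s attention decoder, which returns log-softmax scores) the right direction?

def reinforce_fine_tune(encoder, decoder, pair, encoder_optimizer,
                        decoder_optimizer, max_length=MAX_LENGTH):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Encode exactly as in evaluate(), but without detaching anything.
    input_tensor = tensorFromSentence(input_lang, pair[0])
    encoder_hidden = encoder.initHidden()
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    for ei in range(input_tensor.size(0)):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] += encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden
    log_probs, words = [], []
    for di in range(max_length):
        decoder_output, decoder_hidden, _ = decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        # Sample instead of taking the argmax, keeping the log-probability
        # of the sampled token in the graph.
        dist = torch.distributions.Categorical(logits=decoder_output[0])
        token = dist.sample()
        log_probs.append(dist.log_prob(token))
        if token.item() == EOS_token:
            break
        words.append(output_lang.index2word[token.item()])
        decoder_input = token.detach().view(1, 1)

    reward = sentence_bleu([pair[1].split()], words, smoothing_function=smoothie)
    # REINFORCE: minimising -reward * sum(log p) maximises the expected reward.
    loss = -reward * torch.stack(log_probs).sum()
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return reward

My worry is whether sampling like this, with no baseline, is too high-variance to be useful in practice, or whether this is the standard approach.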