I have this Transformer here:
self.src_word_embed = nn.Embedding(num_embeddings=num_words, embedding_dim=dim_model)
self.pos_embed = PositionalEncoding(dim_model=dim_model, max_len=max_seq_len, dropout=DROPOUT)
self.tgt_word_embed = nn.Embedding(num_embeddings=num_words, embedding_dim=dim_model)
self.transformer = nn.Transformer(d_model=dim_model, nhead=heads, num_encoder_layers=num_layers, num_decoder_layers=num_layers, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
self.out = nn.Linear(dim_model, num_words)
When training, I use
outputs = model(src_seq=sec_seq, tgt_seq=tgt_seq, tgt_mask=tgt_mask)
# size = [batch_size, tgt_len, num_words]
The training loss rapidly down to 1e-5, however when using greedy decoding to inference, the result isn’t that good. So how could I use greedy decoding to train the model where the torch.argmax() used in greedy_decoding won’t keep grad_fn
I have my own implement here but it’s so slow!!!
if use_teacher_forcing:
tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz=tgt_len, device=device)
# size = [batch_size, max_tgt_seq_len]
outputs = model(src_seq=src_seq, tgt_seq=tgt_seq, tgt_mask=tgt_mask)
else:
for iii in range(inputs.size(0)):
# Get encoder output
src_seq = torch.unsqueeze(inputs[iii, :], 0)
src_key_padding_mask = (src_seq == PAD_TOKEN).to(device)
src_word_embed = model.src_word_embed(src_seq)
src = model.pos_embed(src_word_embed)
enc_outputs = model.transformer.encoder(src=src, src_key_padding_mask=src_key_padding_mask)
# Initialize decoder output
dec_result = torch.Tensor([[START_TOKEN]]).to(torch.int64).to(device)
for iiii in range(tgt_len):
tgt_word_embed = model.tgt_word_embed(dec_result)
dec_input = model.pos_embed(tgt_word_embed)
dec_outputs = model.transformer.decoder(tgt=dec_input, memory=enc_outputs)
projected = model.out(dec_outputs)
# size = [1, dec_result.size(0), num_words]
prob = F.softmax(torch.squeeze(projected, 0), dim=-1)
idx = torch.argmax(prob, dim=-1)
next_symbol = idx[-1]
dec_result = torch.cat([dec_result, torch.Tensor([[next_symbol]]).to(src_seq.dtype).to(device)], -1)
if iiii == 0:
output = projected[:, -1, :]
else:
output = torch.cat((output, projected[:, -1, :]), 0)
if iii == 0:
outputs = torch.unsqueeze(output, 0)
else:
outputs = torch.cat((outputs, torch.unsqueeze(output, 0)), 0)
loss = loss_fn(outputs.view(-1, outputs.size(-1)), tgt_y.view(-1)) # CrossEntropy
loss.backward()
optimizer.step()
Is there any methods to keep grad_fn after argmax or to prevent that many times of torch.cat?