Well, I was right: I was indeed missing something very obvious. To anyone who comes after me and has a similar problem: the reason my network was only copying results was that my training strategy was wrong. I was passing the targets into the decoder and calculating the loss based on how similar its output was to those same targets. If you think about it, I was asking the decoder to behave like an autoencoder, to reproduce exactly what I passed in. That’s not very difficult for a transformer decoder to do, so it learned to copy very quickly, even with masks. Training this way also makes inference impossible, since the decoder never learns how to generate anything new.
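In hindsight, the mistake is easy to show in code. Here’s a minimal PyTorch sketch (every name and size here is a toy stand-in, not my actual code); the only difference between the two losses is what the decoder’s output is compared against:

```python
import torch
import torch.nn as nn

# Toy stand-ins; none of these names or sizes are from my real code.
vocab, d = 30, 16
dec = nn.TransformerDecoder(nn.TransformerDecoderLayer(d, 4, batch_first=True), num_layers=1)
emb, head = nn.Embedding(vocab, d), nn.Linear(d, vocab)
loss_fn = nn.CrossEntropyLoss()

memory = torch.randn(1, 5, d)            # pretend encoder output
seq = torch.randint(0, vocab, (1, 7))    # stands in for [<start>, h, e, l, l, o, <end>]
T = seq.size(1) - 1
causal = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)  # causal mask

out = head(dec(emb(seq[:, :-1]), memory, tgt_mask=causal))

# Wrong (what I was doing): the loss target is the decoder's own input,
# so the cheapest solution is to copy each token straight through.
bad_loss = loss_fn(out.reshape(-1, vocab), seq[:, :-1].reshape(-1))

# Right: the loss target is the sequence shifted left by one, so each
# position has to predict the *next* token instead.
good_loss = loss_fn(out.reshape(-1, vocab), seq[:, 1:].reshape(-1))
```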
How, you might ask, do you fix this? The solution for me was a couple of steps:
- To add special start and end tokens to every target; e.g. `['h', 'e', 'l', 'l', 'o']` became `[<start>, 'h', 'e', 'l', 'l', 'o', <end>]` (since it’s a character model, my start and end tokens are actually Unicode characters).
- To add an inner loop to the training loop that starts with a target of length 1 and passes incrementally larger prefixes until the entire target has been passed, calculating the loss each time on how similar the output is to the target shifted left by one. (I also do backpropagation each time; I’m not sure if that’s correct or if the losses should be aggregated over the whole sub-loop.) E.g. `[<start>]` goes in and `['h']` is expected; then `[<start>, 'h']` goes in and `['h', 'e']` is expected; and so on. The last iteration passes in `[<start>, 'h', 'e', 'l', 'l', 'o']`, with `['h', 'e', 'l', 'l', 'o', <end>]` expected. There’s a sketch of this loop right after this list.

This particular way of training is called teacher forcing. It also sets us up nicely to perform inference.
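Concretely, the inner loop looks something like this (again a sketch with the same toy stand-ins, not my exact code):

```python
import torch
import torch.nn as nn

# Same toy stand-ins as in the sketch above.
vocab, d = 30, 16
dec = nn.TransformerDecoder(nn.TransformerDecoderLayer(d, 4, batch_first=True), num_layers=1)
emb, head = nn.Embedding(vocab, d), nn.Linear(d, vocab)
loss_fn = nn.CrossEntropyLoss()
params = list(dec.parameters()) + list(emb.parameters()) + list(head.parameters())
opt = torch.optim.Adam(params, lr=1e-3)

memory = torch.randn(1, 5, d)            # pretend encoder output
seq = torch.randint(0, vocab, (1, 7))    # [<start>, 'h', 'e', 'l', 'l', 'o', <end>]

# Incrementally longer prefixes: [<start>] -> ['h'],
# [<start>, 'h'] -> ['h', 'e'], ... up to the full target.
for t in range(1, seq.size(1)):
    dec_in = seq[:, :t]                  # prefix of length t
    labels = seq[:, 1:t + 1]             # the same prefix shifted left by one
    causal = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
    logits = head(dec(emb(dec_in), memory, tgt_mask=causal))
    loss = loss_fn(logits.reshape(-1, vocab), labels.reshape(-1))
    opt.zero_grad()
    loss.backward()                      # per-step backprop, as mentioned above;
    opt.step()                           # summing the losses and stepping once may work too
```

(One thing I noticed while writing this out: thanks to the causal mask, the final iteration already computes a loss at every position, so a single pass over the full sequence may provide the same supervision as the whole loop. I haven’t verified whether that changes the results, though.)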
Inference (answering this issue now) then happens by simply passing the hidden state from the encoder and the `[<start>]` token to the decoder. Since the model has been trained to output a single token when a single `<start>` token is passed in, it should (hopefully) output the correct first token of our output sequence. Then we take that token, append it to our `<start>` token, and pass that back in as input to the decoder. Now it outputs two tokens, and we keep the last one as the next token of the sequence. We repeat this process until the `<end>` token is generated, and then we stop. This is known as greedy decoding. Both teacher forcing and greedy decoding are used by Google’s T5 (teacher forcing during training, greedy decoding at inference), so they’re viable today. There is, however, a method called beam search that gets better results but takes much longer to generate.
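In code, greedy decoding looks something like this (same toy stand-ins as before; `START` and `END` are made-up token ids, and an untrained model will of course emit gibberish until the length cap):

```python
import torch
import torch.nn as nn

# Same toy stand-ins as above; START and END are made-up token ids.
vocab, d, START, END = 30, 16, 1, 2
dec = nn.TransformerDecoder(nn.TransformerDecoderLayer(d, 4, batch_first=True), num_layers=1)
emb, head = nn.Embedding(vocab, d), nn.Linear(d, vocab)

memory = torch.randn(1, 5, d)                        # pretend encoder output

seq = torch.tensor([[START]])                        # begin with just [<start>]
with torch.no_grad():
    for _ in range(50):                              # length cap in case <end> never appears
        t = seq.size(1)
        causal = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
        logits = head(dec(emb(seq), memory, tgt_mask=causal))
        next_tok = logits[:, -1].argmax(dim=-1)      # greedy: take the single most
        seq = torch.cat([seq, next_tok.unsqueeze(1)], dim=1)  # likely next token
        if next_tok.item() == END:
            break

print(seq)  # token ids: [<start>, ..., <end>]
```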