I was trying to use nn.Transformer without decoding layers and I found the behavior of the model with num_decoder_layers=0 to be a bit confusing.
Perhaps it’s just my point of view.
When I set num_decoder_layers to 0, I would expect the model to compute:
- out = encoders(src)
but what the model seems to be doing is:
- mem = encoders(src)
- out = norm(tgt)
Source: torch.nn.modules.transformer, PyTorch 2.0 documentation. With no decoder layers, decoder(tgt, mem) seems to return norm(tgt).
It seems that the “issue” has already been reported: [docs] `nn.Transformer` is not possible to be used to implement BERT (Issue #68053, pytorch/pytorch on GitHub). See also: How to implement BERT using torch.nn.Transformer?
I guess the official way to do this is by using nn.TransformerEncoder and nn.TransformerEncoderLayer.
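For anyone landing here, this is a minimal sketch of that encoder-only setup. The sizes (d_model=16, nhead=2, two layers, dim_feedforward=32) are arbitrary values I picked for illustration, and batch_first=True is just my preference:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; they are assumptions, not recommendations.
d_model, nhead, num_layers = 16, 2, 2

# One encoder layer, cloned num_layers times by nn.TransformerEncoder.
layer = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=32,
    dropout=0.0,
    batch_first=True,  # inputs are (batch, seq, feature)
)
encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

src = torch.randn(3, 5, d_model)  # batch of 3 sequences, 5 tokens each
with torch.no_grad():
    out = encoder(src)  # this is the out = encoders(src) I wanted

print(out.shape)
```

The output keeps the input shape, so it can feed a task-specific head directly, with no tgt tensor involved.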
It would perhaps be less confusing if this line of code was added to the init of nn.Transformer?
- self.decoder = self.decoder if num_decoder_layers else (lambda tgt, memory, **kwargs: memory)
(or a cleaner version of this)
Or, if num_decoder_layers=0 isn’t a supported setting, an assert could reject it up front and avoid this confusing behavior.
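Until something like that exists upstream, the check can live in user code. Here is a rough sketch of the idea, where build_transformer is a hypothetical helper of my own (not part of PyTorch) and I raise ValueError rather than using a bare assert:

```python
import torch.nn as nn


def build_transformer(num_decoder_layers, **kwargs):
    """Hypothetical helper: refuse the decoder-less configuration.

    All other keyword arguments are passed through to nn.Transformer.
    """
    if num_decoder_layers < 1:
        raise ValueError(
            "num_decoder_layers must be >= 1; "
            "use nn.TransformerEncoder for an encoder-only model"
        )
    return nn.Transformer(num_decoder_layers=num_decoder_layers, **kwargs)


# Normal use: at least one decoder layer, so this builds as usual.
model = build_transformer(
    1, d_model=16, nhead=2, num_encoder_layers=1, dim_feedforward=32
)
```

Calling build_transformer(0, d_model=16, nhead=2) raises the ValueError instead of silently building a model whose decoder has no layers.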