Hello, I’m experimenting with transformers right now, and I’m trying to modify the encoded representation with a modified LSTM (the goal is to continue text in a specific style). I’ve found an example of how to use T.nn.TransformerEncoder, but no examples of how to properly use T.nn.TransformerDecoder. How am I supposed to use it? I’ve read about how decoders work in general, but I can’t find anything about the specific PyTorch implementation. How should I use it for training vs. inference? During inference, do I have to manually feed the transformer’s output back into tgt, or is that done automatically? And what does tgt_is_causal do?
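To make the question concrete, here is my current understanding of the two modes, as a standalone sketch (all the dimensions and modules below are toy stand-ins, not my real model). For training I think you pass the whole shifted target at once with a causal tgt_mask (teacher forcing), and for inference I think you have to grow tgt yourself, one step at a time, because the module doesn't do autoregressive feedback for you. Is this right?

```python
import torch
import torch.nn as nn

# Toy dimensions, made up for illustration.
d_model, nhead, seq_len, batch = 16, 4, 5, 2

decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)

memory = torch.randn(batch, seq_len, d_model)  # stands in for the encoder output
tgt = torch.randn(batch, seq_len, d_model)     # embedded target tokens, shifted right

# Training (teacher forcing): feed the whole shifted target at once,
# with a causal mask so position i only attends to positions <= i.
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
out = decoder(tgt, memory, tgt_mask=causal_mask)
print(out.shape)  # torch.Size([2, 5, 16])

# Inference: no teacher forcing, so grow tgt one step at a time,
# feeding the decoder's own previous outputs back in manually.
generated = torch.randn(batch, 1, d_model)  # e.g. an embedded <bos> token
for _ in range(3):
    step_out = decoder(generated, memory)
    # In a real model: project step_out[:, -1] to vocab, sample a token, embed it.
    generated = torch.cat([generated, step_out[:, -1:, :]], dim=1)
print(generated.shape)  # torch.Size([2, 4, 16])
```

My reading of the docs is that tgt_is_causal is only a *hint* that the tgt_mask you pass is the standard causal mask (so fast paths can be used), not something that creates the mask for you — but I'd appreciate confirmation.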
I’ve included a snippet of my code if that’s useful at all.
def forward(self, x, mhx, tgt=None, src_mask=None):
    embedded_seq = self.embedding(x) * math.sqrt(self.emb_dim)
    embedded_seq = self.pos_encoder(embedded_seq)
    if src_mask is None:
        # Generate a square causal mask for the sequence: masked positions
        # are filled with float('-inf'), unmasked positions with 0.0.
        src_mask = nn.Transformer.generate_square_subsequent_mask(len(embedded_seq)).to(device)
    encoded_seq = self.transformer_encoder(embedded_seq, src_mask)
    # Repeat the encoded sequence once along the time axis and feed it to the
    # DNC, to give the DNC time to analyze & plan.
    processed_seq, (chx, mhx, rv) = self.vector_machine(
        T.cat((encoded_seq, encoded_seq), 1), (None, mhx, None),
        reset_experience=True, pass_through_memory=True)
    # Split the processed sequence into two parts, keeping the second half and
    # adding the encoded sequence to it as a skip connection.
    processed_seq = T.chunk(processed_seq, 2, dim=1)[1] + encoded_seq
    # TODO: put decoder here (decoded_seq is not defined yet).
    return decoded_seq, (chx, mhx, rv)
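For the TODO, my current guess is something like the following, shown here as a self-contained sketch (the layer sizes and modules are stand-ins for my real ones, and I'm assuming I'd reuse the same embedding for the target side). Is this the right pattern for producing decoded_seq?

```python
import math
import torch as T
import torch.nn as nn

# Toy stand-ins for my model's pieces, made up for illustration.
emb_dim, nhead, batch, src_len, tgt_len, vocab = 16, 4, 2, 7, 5, 100

embedding = nn.Embedding(vocab, emb_dim)
transformer_decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(emb_dim, nhead, batch_first=True), num_layers=2)

processed_seq = T.randn(batch, src_len, emb_dim)  # stands in for the encoder+DNC output
tgt = T.randint(0, vocab, (batch, tgt_len))       # target token ids, shifted right

# Embed the target and decode against processed_seq as memory,
# with a causal mask over the target positions.
embedded_tgt = embedding(tgt) * math.sqrt(emb_dim)
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len)
decoded_seq = transformer_decoder(embedded_tgt, processed_seq, tgt_mask=tgt_mask)
print(decoded_seq.shape)  # torch.Size([2, 5, 16])
```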
Is there any example of proper use of the PyTorch transformer decoder?