Using separate encoder & decoder for transformer

Hello, I’m messing around with transformers right now, and I’m trying to modify the encoded representation with a modified LSTM (the goal is to continue text in a specific style). I’ve found an example of how to use T.nn.TransformerEncoder, but no examples of how to properly use T.nn.TransformerDecoder. How am I supposed to use it? I’ve read about how decoders work in general, but I can’t find anything about the specific PyTorch implementation. How should I use it for training vs. inference? Do I have to manually feed the transformer’s output back into tgt during inference, or is that done automatically? And what does tgt_is_causal do?

I’ve included a snippet of my code if that’s useful at all.

    def forward(self, x, mhx, tgt=None, src_mask=None):
        embedded_seq = self.embedding(x) * math.sqrt(self.emb_dim)
        embedded_seq = self.pos_encoder(embedded_seq)
        if src_mask is None:
            # Square causal mask for the sequence: masked positions are
            # float('-inf'), unmasked positions are 0.0.
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(embedded_seq)).to(device)
        encoded_seq = self.transformer_encoder(embedded_seq, src_mask)

        # Take the encoded sequence, repeat it once over the time axis, and feed
        # that into the DNC (to give the DNC time to analyze and plan).
        processed_seq, (chx, mhx, rv) = self.vector_machine(
            T.cat((encoded_seq, encoded_seq), 1),
            (None, mhx, None),
            reset_experience=True,
            pass_through_memory=True,
        )

        # Split the processed sequence into two parts, taking the second half and
        # adding the encoded sequence to it as a skip connection.
        processed_seq = T.chunk(processed_seq, 2, dim=1)[1] + encoded_seq

        # TODO: put decoder here.
        decoded_seq = processed_seq  # placeholder until the decoder is added

        return decoded_seq, (chx, mhx, rv)

Is there any example for proper use of the pytorch transformer decoder?