nn.Transformer explanation

Wow, thanks for the quick reply.

Which attn_mask is that? Both the source and target masks should be pretty standard.

Here’s how I’m using it. self.base is just a model that returns embeddings for inp (src) and tgt, src_mask and tgt_mask are the standard upper-triangular (subsequent) masks, and src/tgt_key_padding_mask are as I described previously (there’s a sketch of how masks like these can be built right after the transformer call below):

inp_emb, tgt_emb = self.base(inputs, targets)
# self.base returns inp_emb and tgt_emb as (N, S, E) and (N, T, E); nn.Transformer expects (S, N, E) and (T, N, E), so we transpose them
inp_emb = inp_emb.transpose(0, 1)
tgt_emb = tgt_emb.transpose(0, 1)

hdn = self.transformer(inp_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask, src_key_padding_mask=inp_padding_mask, tgt_key_padding_mask=tgt_padding_mask)
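For reference, here’s roughly how masks in these shapes can be built. This is just a sketch with made-up tensors; pad_id is a placeholder for the actual padding index, not a name from my code:

import torch
import torch.nn as nn

pad_id = 0                                           # placeholder padding index
inputs = torch.tensor([[5, 7, 2, 0], [3, 9, 0, 0]])  # (N=2, S=4) token ids, 0 = pad
targets = torch.tensor([[4, 6, 0], [8, 1, 2]])       # (N=2, T=3) token ids

transformer = nn.Transformer(d_model=16, nhead=4)

# Causal masks: (S, S) and (T, T) float tensors with -inf above the diagonal
src_mask = transformer.generate_square_subsequent_mask(inputs.size(1))
tgt_mask = transformer.generate_square_subsequent_mask(targets.size(1))

# Key-padding masks: bool, (N, S) and (N, T), True where the token is padding
inp_padding_mask = inputs.eq(pad_id)
tgt_padding_mask = targets.eq(pad_id)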

out = self.head(hdn)       # (T, N, vocab_size)
out = out.transpose(0, 1)  # back to (N, T, vocab_size)

loss_fct = nn.CrossEntropyLoss()
out_view = out.contiguous().view(-1, self.vocab_size)  # (N*T, vocab_size)
tgt_view = targets.view(-1)                            # (N*T,)
loss = loss_fct(out_view, tgt_view)
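And for completeness, this is the shape bookkeeping I’m relying on for the loss. It’s a toy check with made-up sizes, not the real model:

import torch
import torch.nn as nn

N, T, V = 2, 5, 11                     # batch size, target length, vocab size (made up)
out = torch.randn(N, T, V)             # logits, already transposed back to batch-first
targets = torch.randint(0, V, (N, T))  # integer class ids

loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(out.contiguous().view(-1, V), targets.view(-1))  # (N*T, V) vs (N*T,)
print(loss.item())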

Could the transposes be throwing it off?
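In case the transposes are the problem, one thing I could try (assuming a PyTorch version where nn.Transformer accepts batch_first, which newer releases do) is to drop them entirely. A minimal sketch with dummy embeddings:

import torch
import torch.nn as nn

# With batch_first=True the transformer consumes (N, S, E) and (N, T, E) directly,
# so no transposes are needed before or after the call.
transformer = nn.Transformer(d_model=16, nhead=4, batch_first=True)

N, S, T, E = 2, 4, 3, 16
inp_emb = torch.randn(N, S, E)
tgt_emb = torch.randn(N, T, E)

hdn = transformer(inp_emb, tgt_emb)  # (N, T, E), already batch-first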