Wow, thanks for the quick reply.
Which attn_mask is that? Both the source and target masks should be pretty standard.

Here’s how I’m using it, where self.base is just a model that returns embeddings for inp (src) and tgt, where src_mask and tgt_mask are the standard upper-triangle matrices, and where src/tgt_key_padding_mask are as I described previously (I’ve sketched how I build them below the snippet):
inp_emb, tgt_emb = self.base(inputs, targets)
# self.base returns embeddings in (N, S, E) and (N, T, E); nn.Transformer expects (S, N, E) and (T, N, E), so we transpose them
inp_emb = inp_emb.transpose(0, 1)
tgt_emb = tgt_emb.transpose(0, 1)
hdn = self.transformer(inp_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask, src_key_padding_mask=inp_padding_mask, tgt_key_padding_mask=tgt_padding_mask)
out = self.head(hdn)       # (T, N, vocab_size)
out = out.transpose(0, 1)  # back to (N, T, vocab_size)
loss_fct = nn.CrossEntropyLoss()
# Flatten to (N*T, vocab_size) logits and (N*T,) target ids for CrossEntropyLoss
out_view = out.contiguous().view(-1, self.vocab_size)
tgt_view = targets.view(-1)
loss = loss_fct(out_view, tgt_view)
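And for completeness, this is roughly how I build the masks I mentioned above (just a sketch; PAD_IDX and the make_masks helper are placeholders, not my exact code):

import torch

PAD_IDX = 0  # placeholder for the real padding token id

def make_masks(inputs, targets):
    # inputs: (N, S) and targets: (N, T) token-id tensors
    S, T = inputs.size(1), targets.size(1)
    # "Standard upper triangle" attention masks: -inf strictly above the diagonal, 0 elsewhere
    src_mask = torch.triu(torch.full((S, S), float("-inf"), device=inputs.device), diagonal=1)
    tgt_mask = torch.triu(torch.full((T, T), float("-inf"), device=targets.device), diagonal=1)
    # Key padding masks: True at padding positions, shapes (N, S) and (N, T)
    inp_padding_mask = inputs.eq(PAD_IDX)
    tgt_padding_mask = targets.eq(PAD_IDX)
    return src_mask, tgt_mask, inp_padding_mask, tgt_padding_mask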
Could the transposes be throwing it off?