Seq2Seq regression problem with attention


I wrote the following code to solve a Seq2Seq regression problem. My implementation is based on the GRU and multi-head attention. The performance is horrible. I tried playing with the hyperparameters, but nothing changed. This led me to think it was a network architecture issue.

import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self, input_size, output_size, hidden, num_heads):
        super(Seq2Seq, self).__init__()
        self.encoder = nn.GRU(input_size, hidden, 2)
        self.decoder = nn.GRU(hidden, hidden, 2)
        self.multihead_attn = nn.MultiheadAttention(hidden, num_heads)
        self.linear  = nn.Linear(hidden, output_size)

    def init_weights(self):
        # The body of this method was cut off in the post; only a trailing ", 0.1)" survives.
        pass

    def forward(self, x):
        encoded, _ = self.encoder(x)
        decoded, _ = self.decoder(encoded)
        attention_output, _ = self.multihead_attn(decoded, decoded, decoded)
        out = self.linear(attention_output)
        return out

D_in    = 4
D_out   = 1
hidden = 16
num_heads = 4
seq2seq = Seq2Seq(input_size=D_in, output_size=D_out, hidden=hidden, num_heads=num_heads)
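# nn.GRU and nn.MultiheadAttention default to batch_first=False,
# so this tensor is read as (seq_len=7, batch=100, features=D_in).
inputs = torch.rand((7, 100, D_in))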
outputs = seq2seq(inputs)

Any suggestions are highly appreciated.

I don’t have a real answer, just some food for thought:

  • I’m not sure how intuitive it is to use nn.MultiheadAttention on the output of an nn.GRU. nn.MultiheadAttention basically implements self-attention, which generally assumes that the sequence elements are “independent”, like word vectors. However, the output of an nn.GRU is different, as the output at step T already captures to some extent the outputs from all previous steps (up to T-1).

  • At least from my basic experience, transformers are difficult to train from scratch; usually you use pretrained models.

  • Strictly speaking, your model does not implement a Seq2Seq task but a sequence labeling task, i.e., you get an output for each input word/item. I actually can’t see what kind of regression problem you’re trying to solve.

  • Have you tried a more basic model by just using the nn.GRU? How do the results compare? It’s often better to first try a simple architecture and then extend it.
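The first point above can be checked numerically. The following is only a minimal sketch, assuming the default batch_first=False layout and illustrative sizes that are not taken from any real dataset: it perturbs the last time step of a random input and checks which earlier outputs change. A unidirectional GRU leaves earlier outputs untouched, while unmasked self-attention lets every position see the changed step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, chosen only for this demo)
seq_len, batch, hidden, num_heads = 7, 2, 16, 4

gru = nn.GRU(hidden, hidden)                     # batch_first=False by default
attn = nn.MultiheadAttention(hidden, num_heads)  # expects (seq_len, batch, embed_dim)

x = torch.rand(seq_len, batch, hidden)
x_changed = x.clone()
x_changed[-1] += 1.0  # perturb only the last time step

with torch.no_grad():
    gru_a, _ = gru(x)
    gru_b, _ = gru(x_changed)
    att_a, _ = attn(x, x, x)
    att_b, _ = attn(x_changed, x_changed, x_changed)

# GRU: outputs before the last step depend only on earlier inputs.
gru_early_same = torch.allclose(gru_a[:-1], gru_b[:-1])
# Self-attention without a mask: earlier positions attend to the changed step.
attn_early_same = torch.allclose(att_a[:-1], att_b[:-1])

print("GRU early outputs unchanged:      ", gru_early_same)
print("Attention early outputs unchanged:", attn_early_same)
```

If the attention layer is kept, passing a causal attn_mask to nn.MultiheadAttention would be one way to stop earlier positions from seeing later ones; whether that is desirable depends on the regression task.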


@vdw, many thanks for the insights. When I did not use multi-head attention, I got better results. Thanks again for the useful points.
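For anyone wanting to repeat the comparison with and without attention, here is a rough sketch. The class name GRURegressor and the use_attention flag are hypothetical, and the model uses a single 2-layer GRU instead of the encoder/decoder pair from the question, so it is a simplified baseline and not the original architecture. It also shows that the model returns one prediction per input time step, which is the sequence labeling behaviour mentioned above.

```python
import torch
import torch.nn as nn


class GRURegressor(nn.Module):
    """Hypothetical ablation model: GRU with optional self-attention on top."""

    def __init__(self, input_size, output_size, hidden, num_heads, use_attention=False):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden, num_layers=2)
        # Attention is only created when requested, so the baseline has no extra weights.
        self.attn = nn.MultiheadAttention(hidden, num_heads) if use_attention else None
        self.linear = nn.Linear(hidden, output_size)

    def forward(self, x):
        # x: (seq_len, batch, input_size) because batch_first=False by default
        h, _ = self.gru(x)
        if self.attn is not None:
            h, _ = self.attn(h, h, h)  # self-attention over the GRU outputs
        return self.linear(h)          # one prediction per time step


inputs = torch.rand(7, 100, 4)  # same shape as in the question
baseline = GRURegressor(4, 1, 16, 4, use_attention=False)
with_attn = GRURegressor(4, 1, 16, 4, use_attention=True)

out_base = baseline(inputs)
out_attn = with_attn(inputs)
print(out_base.shape, out_attn.shape)
```

Training both variants with the same data, loss, and optimizer settings keeps the comparison fair; the attention variant simply has more parameters to fit.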