Seq2Seq regression problem with attention

Hello,

I wrote the following code to solve a Seq2Seq regression problem. My implementation is based on the GRU and multi-head attention. The performance is horrible. I tried playing with the hyperparameters, but nothing changed. This led me to think it was a network architecture issue.

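```python
import torch
import torch.nn as nn


class Seq2Seq(nn.Module):
    def __init__(self, input_size, output_size, hidden, num_heads):
        super(Seq2Seq, self).__init__()
        self.encoder = nn.GRU(input_size, hidden, 2)
        self.decoder = nn.GRU(hidden, hidden, 2)
        # Multi-head self-attention layer used in forward()
        self.multihead_attn = nn.MultiheadAttention(hidden, num_heads)
        self.linear = nn.Linear(hidden, output_size)
        self.init_weights()

    def init_weights(self):
        self.linear.weight.data.normal_(0, 0.1)

    def forward(self, x):
        # x: (seq_len, batch, input_size), since batch_first defaults to False
        encoded, _ = self.encoder(x)
        decoded, _ = self.decoder(encoded)
        # Self-attention over the decoder outputs (query = key = value)
        attention_output, _ = self.multihead_attn(decoded, decoded, decoded)
        out = self.linear(attention_output)
        return out


D_in = 4
D_out = 1
hidden = 16
num_heads = 4  # assumed value, not given in the post; must divide hidden
seq2seq = Seq2Seq(D_in, D_out, hidden, num_heads)
inputs = torch.rand((7, 100, D_in))
outputs = seq2seq(inputs)
```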

Any suggestions are highly appreciated

I don't have a real answer, just some food for thought:

• I'm not sure how intuitive it is to use `nn.MultiheadAttention` on the output of an `nn.GRU`. `nn.MultiheadAttention` basically implements self-attention, which generally assumes that the sequence elements are "independent", like word vectors. The output of an `nn.GRU` is different, though: the output at step T already captures, to some extent, the outputs from all previous steps (up to T-1).

• At least from my basic experience, transformers are difficult to train from scratch; usually you use pretrained models.

• Strictly speaking, your model does not implement a Seq2Seq task but a sequence labeling task, i.e., you get an output for each input word/item. I actually can't see what kind of regression problem you're trying to solve.

• Have you tried a more basic model that uses just the `nn.GRU`? How do the results compare? It's often better to first try a simple architecture and then extend it.
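To make the last point concrete, here is a minimal sketch of a GRU-only baseline, assuming the same shapes as in the question (sequence-first input of shape `(7, 100, 4)`, one regression value per step). The class name `GRUBaseline` and the hyperparameters are illustrative choices of mine, not something from the original code. It also touches on two of the earlier points: the model emits one output per input step (sequence labeling), and the GRU output at a later step already depends on earlier inputs.

```python
import torch
import torch.nn as nn


class GRUBaseline(nn.Module):
    """Hypothetical GRU-only baseline: a GRU followed by a per-step linear head."""

    def __init__(self, input_size, output_size, hidden, num_layers=2):
        super().__init__()
        # batch_first=False (the default), so inputs are (seq_len, batch, features)
        self.gru = nn.GRU(input_size, hidden, num_layers)
        self.linear = nn.Linear(hidden, output_size)

    def forward(self, x):
        states, _ = self.gru(x)      # (seq_len, batch, hidden)
        return self.linear(states)   # (seq_len, batch, output_size)


torch.manual_seed(0)
model = GRUBaseline(input_size=4, output_size=1, hidden=16)
model.eval()

inputs = torch.rand(7, 100, 4)  # (seq_len, batch, features), as in the question
with torch.no_grad():
    outputs = model(inputs)

    # Perturb only the first time step and run the model again.
    perturbed = inputs.clone()
    perturbed[0] += 1.0
    outputs_perturbed = model(perturbed)

# One prediction per input step: sequence labeling rather than Seq2Seq.
print(outputs.shape)  # torch.Size([7, 100, 1])

# Check whether the last step's output moved although only step 0 was modified;
# the GRU hidden state carries information forward in time.
last_step_changed = not torch.allclose(outputs[-1], outputs_perturbed[-1])
print(last_step_changed)
```

If this baseline already trains reasonably on your data, you can add the attention layer back on top and see whether it helps or hurts.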


@vdw, many thanks for the insights. When I did not use multi-head attention, I got better results. Thanks again for the useful points.