GRU: many-to-one / one-to-many

Hi everyone,

I would like to implement a GRU that encodes a sequence of vectors into a single vector (many-to-one), and another GRU that decodes that vector back into a sequence of vectors (one-to-many). The size of the vectors would not change. I would appreciate opinions on what I implemented.

Here is the code:

import torch
import torch.nn as nn


class GRU(nn.Module):
    def __init__(self, opt):
        super(GRU, self).__init__()

        self.input_size = 512 * 3        # dimension of each vector (1536)
        self.length_sequence = 30        # length of the decoded sequence
        self.hidden_size = self.input_size
        self.num_layers = 1

        self.GRU_enc = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True)

        self.GRU_dec = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True)
    
    # many-to-one
    def enc(self, x):
        # x has shape (B, L, 512*3)

        h0 = torch.zeros(self.num_layers, x.shape[0], self.hidden_size, device=x.device)

        out, _ = self.GRU_enc(x, h0)

        # keep only the output at the last time step: (B, 512*3)
        out_last = out[:, -1, :]

        return out_last

    # one-to-many
    def dec(self, x):
        # x has shape (B, 512*3); add a time dimension -> (B, 1, 512*3)
        x = x.unsqueeze(1)

        h = torch.zeros(self.num_layers, x.shape[0], self.hidden_size, device=x.device)

        outputs = []

        # feed the same input vector at every step; only the hidden state evolves
        for i in range(self.length_sequence):
            out, h = self.GRU_dec(x, h)
            outputs.append(out)

        # each out is (B, 1, 512*3), so concatenating along dim=1 already gives
        # (B, length_sequence, 512*3); no reshape is needed
        output = torch.cat(outputs, dim=1)

        return output

    def forward(self, x):
        # (B, L, 512*3) -> (B, 512*3)
        one = self.enc(x)

        # (B, 512*3) -> (B, length_sequence, 512*3)
        many = self.dec(one)

        return many
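
For reference, this is how I call it (the batch size of 4 is arbitrary):

model = GRU(opt=None)
x = torch.randn(4, 30, 512 * 3)   # (B, L, 512*3)
y = model(x)
print(y.shape)                    # torch.Size([4, 30, 1536])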

I am not sure whether this is the right way to implement a one-to-many GRU. Could I get some opinions on it?
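
The main alternative I can think of is an autoregressive decoder that feeds each step's output back in as the next input, instead of repeating the encoded vector at every step. A minimal sketch of what I mean (the name dec_feedback is just for illustration; it would replace dec above):

    # one-to-many, feedback variant: each output becomes the next input
    def dec_feedback(self, x):
        inp = x.unsqueeze(1)  # (B, 512*3) -> (B, 1, 512*3)
        h = torch.zeros(self.num_layers, x.shape[0], self.hidden_size, device=x.device)

        outputs = []
        for i in range(self.length_sequence):
            out, h = self.GRU_dec(inp, h)
            outputs.append(out)
            inp = out  # feed the previous output back in

        return torch.cat(outputs, dim=1)  # (B, length_sequence, 512*3)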

Also, do you think I need to add Linear layers after the encoder / decoder, or an activation function?
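
For example, I was considering something like this on the encoded vector (just a sketch, with the sizes kept identical to the current ones and using the imports from above):

fc_enc = nn.Linear(512 * 3, 512 * 3)
encoded = torch.randn(4, 512 * 3)      # stand-in for enc(x), batch of 4
encoded = torch.tanh(fc_enc(encoded))  # still (B, 512*3), squashed to [-1, 1]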

Finally, the number of parameters is quite large (28M with one layer, 55M with two layers). How many samples would my dataset need for the model to learn anything useful?
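
This is how I counted them:

model = GRU(opt=None)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M")  # ~28.3M with num_layers=1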

Thanks for reading!