Training a many-to-many recurrent neural network

Hi all,

I’m trying to train a many-to-many recurrent neural network, and since this is my first attempt at training RNNs I have a few questions.

So here’s the situation: I have roughly 400 datapoints that are sequential in time. Each datapoint is a vector of 1025 elements, and each corresponding output is a 256x256-pixel image. The goal is to feed the network one datapoint and have it reconstruct the image, then at some later time feed it the next datapoint and reconstruct the next image. Since the datapoints are acquired over time and the inputs and outputs are interrelated, my idea was to train an RNN to reconstruct an image for every datapoint, so that the time-dependent factors are encoded in the hidden state.
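
In other words, the usage I’m aiming for at inference time looks roughly like this (just a sketch; stream_of_datapoints is a placeholder for wherever the vectors come from):

model.reset_hidden()
for datapoint in stream_of_datapoints:       # each a tensor of 1025 elements
    image = model(datapoint.view(1, 1, -1))  # returns a (1, 1, 256, 256) image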

So I’ve implemented a model like this:

import torch

class RNN(torch.nn.Module):
    def __init__(self, datapoint_size, image_size, input_size=512, hidden_size=4096):
        super(RNN, self).__init__()
        self.image_size = image_size
        self.datapoint_size = datapoint_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hidden = None  # GRU hidden state, carried across forward() calls
        self.fc = torch.nn.Linear(in_features=2 * self.datapoint_size + 1, out_features=self.input_size)
        self.gru = torch.nn.GRU(self.input_size, self.hidden_size, num_layers=2)

        self.conv1 = DoubleConv(1, 32)
        self.up1 = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(16, 8)
        self.up2 = torch.nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2)

    def reset_hidden(self):
        self.hidden = None

    def forward(self, X):
        batch, seq_len, i_shape = X.shape

        # (batch, seq_len, 1025) -> (seq_len, batch, input_size); note this
        # view only reorders the axes correctly when batch == 1
        X = self.fc(X).view(seq_len, batch, self.input_size)

        X, self.hidden = self.gru(X, self.hidden)  # (seq_len, batch, hidden_size)
        X = X.view(seq_len, 1, 64, 64)  # hidden_size = 4096 = 64 * 64; assumes batch == 1
        X = self.conv1(X)               # (seq_len, 32, 64, 64)
        X = self.up1(X)                 # (seq_len, 16, 128, 128)
        X = self.conv2(X)               # (seq_len, 8, 128, 128)
        X = self.up2(X)                 # (seq_len, 1, 256, 256)
        return X

model = RNN(512, 256).to(device)  # datapoint_size=512, so fc expects 2 * 512 + 1 = 1025 features
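
For reference, this is the shape trace I expect through forward with batch size 1 and a single timestep (which is how I call it during training); a quick sanity check with a dummy input:

x = torch.randn(1, 1, 1025).to(device)  # (batch, seq_len, 2 * datapoint_size + 1)
print(model(x).shape)                   # torch.Size([1, 1, 256, 256])
model.reset_hidden()                    # forward() updated self.hidden, so reset it

# Inside forward:
#   fc:    (1, 1, 1025)      -> (1, 1, 512), viewed as (seq_len, batch, input_size)
#   gru:   (1, 1, 512)       -> (1, 1, 4096)
#   view:  (1, 1, 4096)      -> (1, 1, 64, 64)
#   conv1: (1, 1, 64, 64)    -> (1, 32, 64, 64)
#   up1:   (1, 32, 64, 64)   -> (1, 16, 128, 128)
#   conv2: (1, 16, 128, 128) -> (1, 8, 128, 128)
#   up2:   (1, 8, 128, 128)  -> (1, 1, 256, 256)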

Where DoubleConv is implemented like this:

class DoubleConv(torch.nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
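
Note that DoubleConv changes the number of channels but preserves the spatial size (3x3 convolutions with padding=1), so all the upsampling from 64x64 to 256x256 comes from the two transposed convolutions. For example:

block = DoubleConv(1, 32)
print(block(torch.randn(4, 1, 64, 64)).shape)  # torch.Size([4, 32, 64, 64])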

And train it like this:

from tqdm import tqdm

l2_loss = torch.nn.MSELoss()
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))
points = 403  # number of datapoints
for epoch in tqdm(range(100)):
    model.reset_hidden()  # fresh hidden state at the start of each epoch
    mean_err = []
    for point in tqdm(range(points)):
        single_data = data[:, point, ...].unsqueeze(1)   # (1, 1, 1025)
        single_target = target[point, ...].unsqueeze(0)  # (1, 1, 256, 256)
        optimizer.zero_grad()
        output = model(single_data)
        loss = l2_loss(output, single_target)
        # retain_graph=True because self.hidden keeps the graph of all
        # previous datapoints alive
        loss.backward(retain_graph=True)
        optimizer.step()
        mean_err.append(loss.item())

Here data has shape torch.Size([1, 403, 1025]) and target has shape torch.Size([403, 1, 256, 256]).
I’ve structured the training loop this way to make sure the hidden state is not reset between datapoints.

Now, I’ve got a few questions:

  1. Is this the way to implement it? It seems quite complicated for something I would expect to be fairly simple.
  2. It seems to work, but it is very slow: training one epoch on one timeseries of ~400 datapoints takes 25 minutes on the GPU. This is probably because every backward pass has to propagate through all previous timepoints, but I think it must be possible to do this faster, because the current approach will not scale when training on the full dataset (see the sketch after this list for what I have in mind).
  3. How does this extend to batched training? Am I right that a batch would then consist of the same time point taken from multiple time series?
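
Regarding question 2, what I have in mind is something like truncated backpropagation through time: keep the hidden state values across datapoints but detach them from the autograd graph after each step, so every backward pass only spans the current datapoint and retain_graph=True is no longer needed. A minimal sketch of what I mean (untested, same data layout as above):

for epoch in range(100):
    model.reset_hidden()
    for point in range(points):
        single_data = data[:, point, ...].unsqueeze(1)   # (1, 1, 1025)
        single_target = target[point, ...].unsqueeze(0)  # (1, 1, 256, 256)
        optimizer.zero_grad()
        output = model(single_data)
        loss = l2_loss(output, single_target)
        loss.backward()  # no retain_graph needed anymore
        optimizer.step()
        # keep the hidden state values but cut the graph so it doesn't
        # grow with every datapoint
        if model.hidden is not None:
            model.hidden = model.hidden.detach()

Is that a reasonable way to do it, or does detaching throw away too much of the temporal dependence? As for question 3, my guess is that the batch dimension would hold the same time step from multiple series (so single_data becomes (n_series, 1, 1025)), but then the view calls in forward would need to become permute/reshape operations that handle batch > 1 correctly.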

If you have any other ideas/suggestions on how to do what I’m doing better please let me know!
Thanks for reading.

Cheers