I have a dataset of images, each paired with a continuous value, and I'm using a CNN to predict that value from the image. The images and values form a time series, so I'd like to feed the CNN features into a GRU or LSTM, but I'm having trouble wiring the two together. Each image is (1, 360, 360) since they're black and white, and I'd like to input a sequence of 12 images from the series at a time.
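For reference, here is roughly how I construct a batch (the batch size of 8 is just an example):

import torch

batch_size, seq_len = 8, 12
# each sample is a sequence of 12 single-channel 360x360 frames
x = torch.randn(batch_size, seq_len, 1, 360, 360)  # shape: (8, 12, 1, 360, 360)

This is my model so far: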
class CNN_GRU(nn.Module):
    def __init__(self):
        ## Pool, ReLU, Dropout
        super().__init__()
        ## CNN
        self.cnn = torch.nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=9, stride=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=(2, 2)),
            nn.Conv2d(in_channels=4, out_channels=8, kernel_size=5, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=(2, 2)),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=(2, 2)),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        ## LINEAR
        self.linear = torch.nn.Sequential(
            nn.Linear(in_features=784, out_features=200),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=200, out_features=100),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(in_features=100, out_features=20),
            nn.ReLU(),
            nn.Linear(in_features=20, out_features=1)
        )
        ## GRU layer
        self.GRU = nn.GRU(input_size=392, hidden_size=1, num_layers=3, batch_first=True)

    def forward(self, x):
        print("0: ", x.shape)
        cnn_output = self.cnn(x)
        print("Cnn output: ", cnn_output.shape)
        cnn_output = cnn_output.flatten()
        print("Cnn output flattened: ", cnn_output.shape)
        gru_output, _h_o = self.GRU(cnn_output)
        print("gru_output: ", gru_output.shape)
        linear_output = self.linear(gru_output)
        return linear_output
The error I get says the GRU expects 3 dimensions (batch, sequence length, input size), but my input has 5 dimensions. How can I reshape my input, or am I going about this in completely the wrong way?
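For what it's worth, the direction I was considering is to fold the sequence dimension into the batch dimension before the CNN and unfold it again for the GRU, as in the sketch below. The hidden size of 64 and the single linear head are placeholder guesses rather than anything I've validated; 392 = 8 * 7 * 7 is the per-frame feature size my conv stack produces on a 360x360 input.

import torch
import torch.nn as nn

class CNN_GRU_Sketch(nn.Module):
    def __init__(self):
        super().__init__()
        # essentially the same conv stack as above
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=9, stride=5), nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),
            nn.Conv2d(4, 8, kernel_size=5, padding='same'), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(8, 16, kernel_size=3, padding='same'), nn.ReLU(),
            nn.Conv2d(16, 8, kernel_size=3, padding='same'), nn.ReLU(),
            nn.Conv2d(8, 8, kernel_size=3, padding='same'), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Dropout(0.5),
        )
        # 392 = 8 channels * 7 * 7 spatial map from the stack above
        self.gru = nn.GRU(input_size=392, hidden_size=64, batch_first=True)
        self.head = nn.Linear(64, 1)  # placeholder head, not my full linear stack

    def forward(self, x):
        # x: (batch, seq, 1, 360, 360)
        b, t, c, h, w = x.shape
        feats = self.cnn(x.view(b * t, c, h, w))  # (b*t, 8, 7, 7)
        feats = feats.view(b, t, -1)              # (b, t, 392): one vector per frame
        gru_out, _ = self.gru(feats)              # (b, t, 64)
        return self.head(gru_out[:, -1])          # (b, 1): predict from last step

x = torch.randn(2, 12, 1, 360, 360)
print(CNN_GRU_Sketch()(x).shape)  # torch.Size([2, 1])

Is collapsing batch and sequence like this the right pattern, or should the recurrence be handled some other way?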