ConvLSTM providing unreliable predictions

Hello good people,
I am learning PyTorch. Currently, I am trying to predict air quality over a grid of 2394 cells. I have 7 variables collected from 11 locations in the grid. My batched input is (batch_size, timestamps, no_locations, no_of_variables) and the target is (batch_size, 2394, 1). I want the model to consider the 18 previous timestamps. The shapes of my training inputs and targets are as follows:

x_train.shape, y_train.shape
((8438, 18, 11, 7), (8438, 2394, 1))

Each batch is like:

((128, 18, 11, 7), (128, 2394, 1))

Here is my model:

import torch
import torch.nn as nn

"""
Reference: https://github.com/spacejake/convLSTM.pytorch/blob/master/convlstm.py

"""


class ConvLSTMCell(nn.Module):
    """One convolutional LSTM step.

    The input and the previous hidden state are concatenated along the
    channel axis and passed through a single convolution that produces the
    four gate pre-activations (input, forget, output, candidate) at once.
    """

    def __init__(self, in_size, in_channels, h_channels, kernel_size, bias=True):
        """
        params:
            in_size (int, int) - height and width of input tensor as (height, width)
            in_channels (int) - number of channels in the input image
            h_channels (int) - number of channels of hidden state
            kernel_size (int, int) - size of the convolution kernel
            bias (bool, optional) - default: True
        """

        super().__init__()

        self.height, self.width = in_size
        self.h_channels = h_channels
        # Half-kernel padding keeps the spatial size unchanged for odd kernel
        # sizes, which is required so the new state can be fed back in.
        padding = kernel_size[0] // 2, kernel_size[1] // 2

        self.conv = nn.Conv2d(in_channels=in_channels + h_channels,
                              out_channels=4 * h_channels,
                              kernel_size=kernel_size,
                              padding=padding,
                              bias=bias)

    def forward(self, input_data, prev_state):
        """
        params:
            input_data - tensor of shape (batch, in_channels, height, width)
            prev_state - pair (h_prev, c_prev), each (batch, h_channels, height, width)

        return:
            (h_cur, c_cur) - the updated hidden and cell states
        """
        h_prev, c_prev = prev_state
        combined = torch.cat((input_data, h_prev), dim=1)  # concatenate along channel axis

        # One convolution yields all four gates; split them back apart.
        gates = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(gates, self.h_channels, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_cur = f * c_prev + i * g
        h_cur = o * torch.tanh(c_cur)

        return h_cur, c_cur

    def init_hidden(self, batch_size, device):
        """ initialize the first hidden state as zeros

        The tensors are allocated directly on ``device`` and use the dtype of
        the convolution weights, so they stay compatible with the cell after
        it has been cast (e.g. ``.double()``).

        return:
            (h0, c0) - two distinct zero tensors of shape
                       (batch_size, h_channels, height, width)
        """
        shape = (batch_size, self.h_channels, self.height, self.width)
        dtype = self.conv.weight.dtype
        return (torch.zeros(shape, device=device, dtype=dtype),
                torch.zeros(shape, device=device, dtype=dtype))


class ConvLSTM(nn.Module):
    """Stacked ConvLSTM encoder followed by a fully connected regression head.

    The sequence is run through ``num_layers`` ConvLSTM cells; the final
    hidden state of the last layer is flattened and mapped by an MLP to
    ``out_features`` values per sample.
    """

    def __init__(self, in_size, in_channels, h_channels, kernel_size, num_layers, **kwargs):
        """
        params:
            in_size (int, int) - height and width of each input frame
            in_channels (int) - number of channels of each input frame
            h_channels (sequence of int) - hidden channels per layer; only the
                first ``num_layers`` entries are used
            kernel_size (int, int) - convolution kernel size, must be a tuple
            num_layers (int) - number of stacked ConvLSTM cells (>= 1)
            batch_first (bool, optional) - input is (b, t, c, h, w); default: True
            output_last (bool, optional) - stored for compatibility; default: True
            device (optional) - fallback device for initial states; default: 'cpu'
            out_features (int, optional) - size of the prediction vector; default: 2394

        raises:
            ValueError - if kernel_size is not a tuple, num_layers < 1, or
                h_channels has fewer than num_layers entries
        """

        super().__init__()

        self._check_kernel_size_consistency(kernel_size)
        if num_layers < 1:
            raise ValueError('num_layers must be at least 1')
        if len(h_channels) < num_layers:
            raise ValueError('h_channels must provide one entry per layer')

        self.height, self.width = in_size
        self.num_layers = num_layers
        self.batch_first = kwargs.get('batch_first', True)
        self.output_last = kwargs.get('output_last', True)
        self.device = kwargs.get('device', 'cpu')
        out_features = kwargs.get('out_features', 2394)

        self.cell_list = nn.ModuleList()
        for i in range(self.num_layers):
            cur_in_channels = in_channels if i == 0 else h_channels[i - 1]
            self.cell_list.append(ConvLSTMCell(in_size=(self.height, self.width),
                                               in_channels=cur_in_channels,
                                               h_channels=h_channels[i],
                                               kernel_size=kernel_size))

        # The head consumes the flattened final hidden state of the last
        # layer, so its input width follows from that layer's channel count
        # and the spatial size instead of being hard-coded.
        flat_features = h_channels[self.num_layers - 1] * self.height * self.width
        self.linear_layers = nn.Sequential(nn.Linear(flat_features, 128),
                                           nn.ELU(inplace=True),
                                           nn.Linear(128, 128),
                                           nn.ELU(inplace=True),
                                           nn.Linear(128, 128),
                                           nn.ELU(inplace=True),
                                           nn.Linear(128, 256),
                                           nn.ELU(inplace=True),
                                           nn.Linear(256, out_features)
                                           )

    def forward(self, input_data, hidden_state=None):
        """
        params:
            input_data (batch_size, seq_len, num_channels, height, width)
                       or (seq_len, batch_size, ...) when batch_first is False
            hidden_state: None, or a list with one (h, c) pair per layer

        return:
            tensor of shape (batch_size, out_features)
        """

        if not self.batch_first:  # (t, b, c, h, w) -> (b, t, c, h, w)
            input_data = input_data.permute(1, 0, 2, 3, 4)

        if hidden_state is None:
            # Create the states where the data lives so the model keeps
            # working after being moved with ``.to(...)``.
            hidden_state = self.get_init_states(batch_size=input_data.size(0),
                                                device=input_data.device)

        seq_len = input_data.size(1)
        cur_layer_input = input_data

        for i, cell in enumerate(self.cell_list):
            h, c = hidden_state[i]
            hidden_states = []
            for t in range(seq_len):
                h, c = cell(input_data=cur_layer_input[:, t, :, :, :], prev_state=[h, c])
                hidden_states.append(h)

            # the output of (n) hidden layer is the input of (n+1) hidden layer
            cur_layer_input = torch.stack(tuple(hidden_states), dim=1)

        # ``h`` is now the last time step's hidden state of the top layer.
        x = h.reshape(h.shape[0], -1)
        return self.linear_layers(x)

    def get_init_states(self, batch_size, device=None):
        """ build zero (h, c) states for every layer

        params:
            batch_size (int) - leading dimension of the states
            device (optional) - where to allocate; falls back to self.device
        """
        if device is None:
            device = self.device
        return [cell.init_hidden(batch_size, device) for cell in self.cell_list]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """ kernel size should be in the tuple mode (k, k) """
        if not isinstance(kernel_size, tuple):
            raise ValueError('kernel_size must be tuple')


def train_one_epoch(batch_size, train_loader):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Uses the module-level ``model``, ``loss_func`` and ``optimizer``.

    params:
        batch_size - unused; kept so existing callers keep working
        train_loader - iterable yielding (inputs, labels) batches;
            inputs are assumed to be (batch, seq_len, sites, variables)
            — TODO confirm against the dataset

    return:
        mean loss over the batches that contained at least one non-NaN
        target, or NaN if no batch could be used
    """
    # Feed the data to wherever the model's parameters live.
    model_device = next(model.parameters()).device

    total_loss = 0.
    used_batches = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(model_device)
        labels = labels.to(model_device)

        # (batch, seq, sites, vars) -> (batch, seq, vars, 1, sites), i.e. the
        # (batch_size, seq_len, num_channels, height, width) layout the model expects.
        inputs = torch.permute(inputs, (0, 1, 3, 2)).unsqueeze(3)

        labels = labels.view(labels.shape[0], -1)

        # Targets may contain NaN; only valid cells contribute to the loss.
        valid = ~torch.isnan(labels)
        if not valid.any():
            continue  # nothing to learn from in this batch

        outputs = model(inputs)
        train_loss = loss_func(outputs[valid], labels[valid])

        # Zero your gradients for every batch!
        optimizer.zero_grad()
        train_loss.backward()
        # Adjust learning weights
        optimizer.step()

        total_loss += train_loss.item()
        used_batches += 1

    # Average over the batches that were actually trained on; skipped
    # batches must not dilute the reported loss.
    if used_batches == 0:
        return float('nan')
    return total_loss / used_batches

# ConvLSTM(in_size, in_channels, h_channels, kernel_size, num_layers, **kwargs)
# NOTE(review): only the first ``num_layers`` entries of ``h_channels`` are
# used, so the 64 below has no effect while num_layers=1.
model = ConvLSTM(in_size=(1, insites),
                 in_channels=input_dim,
                 h_channels=[32, 64],
                 kernel_size=(1, 3),
                 num_layers=1,
                 device=device
                 )
model.to(device)


lr = 0.001

savemodel = True  # write a checkpoint whenever the validation loss improves

num_epochs = 1500
best_vloss = 1_000_000.
best_epoch = -1
epoch_number = 0

loss_func = nn.L1Loss(reduction='mean')
# NOTE(review): plain SGD at this learning rate may converge very slowly with
# an L1 objective; if predictions stay nearly constant, trying
# torch.optim.Adam and normalised inputs/targets is worth an experiment
# — TODO confirm empirically.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

for epoch in tqdm(range(num_epochs)):
    model.train()
    avg_loss = train_one_epoch(batch_size, train_loader)

    # validation step: no gradients are needed, so skip building the graph
    model.eval()
    with torch.no_grad():
        avg_vloss = valid_one_epoch(batch_size, valid_loader)

    trainloss.append(avg_loss)
    validloss.append(avg_vloss)

    print(f'EPOCH {epoch_number+1}: LOSS train {avg_loss:.2f} valid {avg_vloss:.2f}')

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        best_epoch = epoch_number
        if savemodel:
            model_path = f'model_{epoch_number}'
            torch.save(model.state_dict(), model_path)
            print(f"Saved at epoch: {epoch_number}")

    epoch_number += 1

The problem is that whatever architecture I use, the output is almost the same. I am really confused and have spent a lot of time debugging. Any help will be highly appreciated. Thank you in advance.

Tagging a few good people.
@ptrblck @vdw