Normalization between stacked ConvLSTM layers

Hello :) I'm trying to add normalization layers between several stacked ConvLSTM layers, but it's hard to find existing implementations to learn from. (The ConvLSTM code is from the GitHub repository ndrplz/ConvLSTM_pytorch, "Implementation of Convolutional LSTM in PyTorch".)

The ConvLSTM cell code is below.

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    All four gates (input, forget, output, candidate) come from one
    ``nn.Conv2d`` applied to the channel-wise concatenation of the input and
    the previous hidden state. The padding is ``k // 2``, so odd kernels keep
    the spatial size unchanged.

    Optional normalization (``norm=True``) follows the layer-normalized LSTM
    of Ba et al., "Layer Normalization" (2016), adapted to feature maps with
    ``nn.GroupNorm``:

    * each gate's pre-activation is normalized over its own
      (hidden_dim, H, W) slab (GroupNorm with 4 groups, one per gate);
    * the new cell state is normalized over (hidden_dim, H, W) before the
      output ``tanh``. The *un-normalized* cell state is what is returned and
      carried to the next time step.

    GroupNorm statistics are computed per sample. Unlike BatchNorm, the result
    therefore does not depend on the batch size, or on how a batch is split
    for gradient accumulation, and it behaves the same in train and eval mode.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias, norm=False):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel. Both sizes must be odd so that
            ``padding = k // 2`` preserves the spatial size.
        bias: bool
            Whether or not to add the bias.
        norm: bool, default False
            If True, add per-sample layer normalization (see class docstring).
            If False the cell is the plain ConvLSTM cell with no extra
            parameters.

        Raises
        ------
        ValueError
            If either kernel dimension is even.
        """
        super(ConvLSTMCell, self).__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if any(k % 2 == 0 for k in kernel_size):
            # An even kernel with padding k // 2 grows the feature map by one
            # pixel. The element-wise c/h updates in forward() would then fail.
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.use_norm = norm

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

        if norm:
            # 4 groups over the 4*hidden_dim conv channels gives one group per
            # gate, because forward() splits the channels into four contiguous
            # hidden_dim-sized chunks.
            self.gate_norm = nn.GroupNorm(4, 4 * self.hidden_dim)
            # A single group is LayerNorm over (C, H, W) of the cell state.
            self.cell_norm = nn.GroupNorm(1, self.hidden_dim)
        else:
            # Parameter-free no-ops: forward() stays branch-free and the
            # state_dict is identical to the un-normalized cell.
            self.gate_norm = nn.Identity()
            self.cell_norm = nn.Identity()

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: torch.Tensor
            Input of shape (batch, input_dim, height, width).
        cur_state: (torch.Tensor, torch.Tensor)
            Previous ``(h, c)``, each of shape (batch, hidden_dim, height, width).

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Next ``(h, c)``, with the same shapes as ``cur_state``.
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.gate_norm(self.conv(combined))

        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        # Only the copy fed to tanh is normalized; the carried cell state is not.
        h_next = o * torch.tanh(self.cell_norm(c_next))

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-initialized ``(h, c)`` for the start of a sequence.

        Both tensors have shape (batch_size, hidden_dim, height, width). They
        are created on the conv weights' device and with the weights' dtype,
        so they stay consistent with the module after ``.to()``, ``.half()``
        or ``.double()``.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

What I'm trying to do is stack 3 ConvLSTM layers in both the encoder and the decoder of an encoder-decoder architecture for video prediction, as shown below. Given the past 5 frames of a video, the model should predict the next 5 frames.

class EncoderDecoderConvLSTM(nn.Module):
    """Encoder-decoder video predictor with three stacked ConvLSTM cells per side.

    Architecture:

    * Encoder: 3 stacked ConvLSTM cells read the input frames one by one.
    * Encoder vector: final hidden state of the top encoder cell.
    * Decoder: 3 stacked ConvLSTM cells. At every future step the first
      decoder cell receives the previous top decoder output; on the first
      step this is the encoder vector.
    * Output head: a Conv3d with a (1, 5, 5) kernel (i.e. per time step),
      followed by tanh.

    Optional between-layer normalization (``norm=True``): the hidden output
    that a cell passes *up* to the next cell in the same stack goes through
    ``nn.GroupNorm(1, nf)``. This is LayerNorm over channels and space,
    computed per sample. The recurrent (h, c) states themselves are left
    untouched. Because the statistics are per sample, results do not depend
    on the batch size, which makes the option safe with gradient
    accumulation over small micro-batches.
    """

    def __init__(self, nf, in_chan, out_chan=3, norm=False):  # e.g. (64, 3)
        """
        Parameters
        ----------
        nf: int
            Number of hidden channels of every ConvLSTM cell.
        in_chan: int
            Number of channels of each input frame.
        out_chan: int, default 3
            Number of channels of each predicted frame.
        norm: bool, default False
            Normalize the hidden output passed between stacked cells (see
            class docstring). With False there are no extra parameters.
        """
        super(EncoderDecoderConvLSTM, self).__init__()

        def make_cell(input_dim):
            return ConvLSTMCell(input_dim=input_dim,
                                hidden_dim=nf,
                                kernel_size=(5, 5),
                                bias=True)

        self.encoder_1_convlstm = make_cell(in_chan)
        self.encoder_2_convlstm = make_cell(nf)
        self.encoder_3_convlstm = make_cell(nf)
        self.decoder_1_convlstm = make_cell(nf)  # nf + 1 if a skip input is concatenated
        self.decoder_2_convlstm = make_cell(nf)
        self.decoder_3_convlstm = make_cell(nf)

        def make_norm():
            # Identity has no parameters, so norm=False leaves the state_dict
            # and the parameter order unchanged.
            return nn.GroupNorm(1, nf) if norm else nn.Identity()

        # Index 0: layer 1 -> layer 2, index 1: layer 2 -> layer 3.
        self.encoder_norms = nn.ModuleList([make_norm(), make_norm()])
        self.decoder_norms = nn.ModuleList([make_norm(), make_norm()])

        self.decoder_CNN = nn.Conv3d(in_channels=nf,
                                     out_channels=out_chan,
                                     kernel_size=(1, 5, 5),
                                     padding=(0, 2, 2))

    def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4,
                    h_t5, c_t5, h_t6, c_t6):
        """Encode ``seq_len`` frames of ``x``, then decode ``future_step`` frames.

        ``h_t`` .. ``c_t6`` are the initial (h, c) states of encoder cells 1-3
        followed by decoder cells 1-3.

        Returns
        -------
        torch.Tensor
            Shape (batch, out_chan, future_step, height, width). Note that the
            channel and time axes are in the opposite order from the input.

        Raises
        ------
        ValueError
            If ``future_step`` < 1 (there would be no frames to stack).
        """
        if future_step < 1:
            raise ValueError(f"future_step must be >= 1, got {future_step}")

        enc_norm_12, enc_norm_23 = self.encoder_norms
        dec_norm_12, dec_norm_23 = self.decoder_norms

        # encoder
        for t in range(seq_len):
            h_t, c_t = self.encoder_1_convlstm(input_tensor=x[:, t, :, :],
                                               cur_state=[h_t, c_t])
            h_t2, c_t2 = self.encoder_2_convlstm(input_tensor=enc_norm_12(h_t),
                                                 cur_state=[h_t2, c_t2])
            h_t3, c_t3 = self.encoder_3_convlstm(input_tensor=enc_norm_23(h_t2),
                                                 cur_state=[h_t3, c_t3])

        # encoder vector: final hidden state of the top encoder cell
        decoder_input = h_t3

        # decoder (autoregressive: each step's top output feeds the next step)
        outputs = []
        for _ in range(future_step):
            h_t4, c_t4 = self.decoder_1_convlstm(input_tensor=decoder_input,
                                                 cur_state=[h_t4, c_t4])
            h_t5, c_t5 = self.decoder_2_convlstm(input_tensor=dec_norm_12(h_t4),
                                                 cur_state=[h_t5, c_t5])
            h_t6, c_t6 = self.decoder_3_convlstm(input_tensor=dec_norm_23(h_t5),
                                                 cur_state=[h_t6, c_t6])
            decoder_input = h_t6
            outputs.append(h_t6)  # predictions

        # (b, T, nf, H, W) -> (b, nf, T, H, W): Conv3d expects channels before time.
        outputs = torch.stack(outputs, 1).permute(0, 2, 1, 3, 4)
        return torch.tanh(self.decoder_CNN(outputs))

    def forward(self, x, future_seq=0, hidden_state=None):
        """Predict ``future_seq`` frames following the input sequence.

        Parameters
        ----------
        x: torch.Tensor
            5-D tensor of shape (b, t, c, h, w)  # batch, time, channel, height, width
        future_seq: int
            Number of frames to predict. Must be >= 1; the default of 0 is kept
            for signature compatibility and raises ValueError.
        hidden_state:
            Currently ignored; every call starts from zero states.

        Returns
        -------
        torch.Tensor
            Shape (b, out_chan, future_seq, h, w).
        """
        b, seq_len, _, height, width = x.size()
        image_size = (height, width)

        # Order must match autoencoder()'s parameters: (h, c) per cell,
        # encoder 1-3, then decoder 1-3.
        cells = (self.encoder_1_convlstm, self.encoder_2_convlstm, self.encoder_3_convlstm,
                 self.decoder_1_convlstm, self.decoder_2_convlstm, self.decoder_3_convlstm)
        states = []
        for cell in cells:
            states.extend(cell.init_hidden(batch_size=b, image_size=image_size))

        return self.autoencoder(x, seq_len, future_seq, *states)

First question: which normalization technique (BatchNorm, LayerNorm, GroupNorm, InstanceNorm, …) is appropriate for this RNN-family model when training with gradient accumulation (an effective batch of 16 split into 4 accumulation steps)?
As far as I know, LayerNorm is a good choice for RNN models, but I'm curious whether LayerNorm works correctly when training with gradient accumulation.

Second question: I have no idea where or how to add normalization layers in this code, so any advice would be appreciated.