Passing hidden states to ConvLSTM

I am new to PyTorch, here is my question. I have implemented a ConvLSTM in PyTorch, but I could not find a way to initialize the hidden states before unrolling the ConvLSTM. These are the ConvLSTM cell and layer:

import torch
import torch.nn as nn
from torch.autograd import Variable


class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()

        #assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_features = 4

        self.padding = int((kernel_size - 1) / 2)

        self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)

        self.Wci = None
        self.Wcf = None
        self.Wco = None

    def forward(self, x, h, c):
        # Gates with peephole terms; Wci/Wcf/Wco are created lazily in
        # init_hidden, so they remain None until it has been called.
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco)
        ch = co * torch.tanh(cc)
        return ch, cc

    def init_hidden(self, batch_size, hidden, shape):
        if self.Wci is None:
            self.Wci = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
            self.Wcf = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
            self.Wco = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
        else:
            assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!'
            assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!'
        return (nn.Parameter(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda(),
                nn.Parameter(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda())

class ConvLSTM_s(nn.Module):
    # input_channels corresponds to the first input feature map.
    # hidden_channels is a list with one entry per stacked LSTM layer.
    def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1]):
        super(ConvLSTM_s, self).__init__()
        self.input_channels = [input_channels] + hidden_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.step = step
        self.effective_step = effective_step
        self._all_layers = []
        for i in range(self.num_layers):
            name = 'cell{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
            setattr(self, name, cell)
            self._all_layers.append(cell)

    def forward(self, inputx, states=None):
        internal_state = []
        outputs = []
        for step in range(self.step):

            x = inputx[:,step,:,:,:] #slicing the time window
            #print(x.shape)
            for i in range(self.num_layers):
                # all cells are initialized in the first step
                name = 'cell{}'.format(i)
                if step == 0:
                    bsize, _, height, width = x.size()
                    if states is None:
                        (h, c) = getattr(self, name).init_hidden(
                            batch_size=bsize,
                            hidden=self.hidden_channels[i],
                            shape=(height, width))
                    else:
                        (h, c) = states
                    
                    internal_state.append((h, c))

                # do forward
                (h, c) = internal_state[i]
                x, new_c = getattr(self, name)(x, h, c)
                internal_state[i] = (x, new_c)
            # record every step's output (the effective_step filter is disabled)
            outputs.append(x)
        assert len(outputs) == inputx.shape[1]
        outputs = torch.stack(outputs,dim=1)
        return (outputs, internal_state[-1])

if __name__ == '__main__':

    inputx = Variable(torch.randn(1, 3, 4, 64, 32)).cuda()

    h1 = (torch.zeros(1, 2, 64, 32)).cuda()
    h2 = (torch.zeros(1, 2, 64, 32)).cuda()

    target = Variable(torch.randn(1, 3, 2, 64, 32)).double().cuda()

    convlstms = ConvLSTM_s(input_channels=4, hidden_channels=[2], kernel_size=3, step=3,
                        effective_step=[3]).cuda()
    loss_fn = torch.nn.MSELoss()




    output = convlstms(inputx,states=(h1,h2))

And the error is:

     32     def forward(self, x, h, c):
---> 33         ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
     34         cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)
     35         cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))

TypeError: mul(): argument 'other' (position 1) must be Tensor, not NoneType

It stems from the operations in the ConvLSTM cell.

Hi!

After sifting through some other posts, I reckon you already managed to overcome this problem. Judging from the code, you are multiplying the previous cell state c with the peephole weight self.Wci, which is only created inside init_hidden. When you pass external states, init_hidden is never called, so self.Wci is still None, and multiplying against None throws exactly that error.
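
For a minimal fix against your own code (a sketch, untested): always call init_hidden in the first step so that the peephole weights Wci/Wcf/Wco get created, and only then override the returned zero states with the ones you passed in:

# Inside ConvLSTM_s.forward, at the first time step:
if step == 0:
    bsize, _, height, width = x.size()
    # init_hidden creates Wci/Wcf/Wco as a side effect, so call it
    # unconditionally; discard its zero states if you supplied your own.
    (h, c) = getattr(self, name).init_hidden(
        batch_size=bsize, hidden=self.hidden_channels[i],
        shape=(height, width))
    if states is not None:
        (h, c) = states
    internal_state.append((h, c))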

If I may, here’s my take on this topic according to the paper by Shi et al. (2015).
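
For reference, the gate equations from the paper's Eq. 3 (where $*$ denotes convolution and $\circ$ the Hadamard product):

$$i_t = \sigma(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i)$$
$$f_t = \sigma(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f)$$
$$C_t = f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c)$$
$$o_t = \sigma(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o)$$
$$H_t = o_t \circ \tanh(C_t)$$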

My approach was to first implement a convolutional LSTM cell (in a similar fashion to yours), which is then utilized in a complete model. In other words, the cell holds only the inner workings of a single step; the recurrent operations (looping, passing states to subsequent steps, etc.) are handled in a separate ConvLSTM class and its forward function.

Here's the source for the cell and its helpers:

import torch
from torch import nn


def initialize_weights(layer):
    """Initialize a layer's weights and biases.

    Args:
        layer: A PyTorch Module's layer."""
    if isinstance(layer, (nn.BatchNorm2d, nn.BatchNorm1d)):
        pass
    else:
        try:
            nn.init.xavier_normal_(layer.weight)
        except AttributeError:
            pass
        try:
            nn.init.uniform_(layer.bias)
        except (ValueError, AttributeError):
            pass
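
This is meant to be handed to nn.Module.apply, which calls it on the module and, recursively, on every submodule, for example:

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
model.apply(initialize_weights)  # initializes the Conv2d, skips the BatchNorm2d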

class HadamardProduct(nn.Module):
    """A Hadamard (element-wise) product layer.

    Implements the peephole terms (e.g. W_ci . C_{t-1}) of Eq. 3.

    Args:
        shape: The shape of the layer's weight tensor."""

    def __init__(self, shape):
        super().__init__()
        # Peephole weights, initialized to zeros (torch.empty would leave
        # them uninitialized, and initialize_weights above only touches
        # attributes named `weight`/`bias`).
        self.weights = nn.Parameter(torch.zeros(*shape))

    def forward(self, x):
        return x * self.weights

    
class ConvLSTMCell(nn.Module):
    """A convolutional LSTM cell.

    Implementation details follow closely the following paper:

    Shi et al. -'Convolutional LSTM Network: A Machine Learning 
    Approach for Precipitation Nowcasting' (2015).
    Accessible at https://arxiv.org/abs/1506.04214

    The parameter names are drawn from the paper's Eq. 3.

    Args:
        input_bands: The number of bands in the input data.
        input_dim: The side length of the input data. Data is
            presumed to have identical width and height.
        kernels: The number of convolution kernels (hidden channels).
        dropout: Dropout probability applied to the output states.
        batch_norm: Whether to batch-normalize the input."""

    def __init__(self, input_bands, input_dim, kernels, dropout, batch_norm):
        super().__init__()

        self.input_bands = input_bands
        self.input_dim = input_dim
        self.kernels = kernels
        self.dropout = dropout
        self.batch_norm = batch_norm

        self.kernel_size = 3
        self.padding = 1  # Preserve dimensions

        self.input_conv_params = {
            'in_channels': self.input_bands,
            'out_channels': self.kernels,
            'kernel_size': self.kernel_size,
            'padding': self.padding,
            'bias': True
        }

        self.hidden_conv_params = {
            'in_channels': self.kernels,
            'out_channels': self.kernels,
            'kernel_size': self.kernel_size,
            'padding': self.padding,
            'bias': True
        }

        self.state_shape = (
            1,
            self.kernels,
            self.input_dim,
            self.input_dim
        )

        self.batch_norm_layer = None
        if self.batch_norm:
            self.batch_norm_layer = nn.BatchNorm2d(num_features=self.input_bands)

        # Input Gates
        self.W_xi = nn.Conv2d(**self.input_conv_params)
        self.W_hi = nn.Conv2d(**self.hidden_conv_params)
        self.W_ci = HadamardProduct(self.state_shape)

        # Forget Gates
        self.W_xf = nn.Conv2d(**self.input_conv_params)
        self.W_hf = nn.Conv2d(**self.hidden_conv_params)
        self.W_cf = HadamardProduct(self.state_shape)

        # Memory Gates
        self.W_xc = nn.Conv2d(**self.input_conv_params)
        self.W_hc = nn.Conv2d(**self.hidden_conv_params)

        # Output Gates
        self.W_xo = nn.Conv2d(**self.input_conv_params)
        self.W_ho = nn.Conv2d(**self.hidden_conv_params)
        self.W_co = HadamardProduct(self.state_shape)

        # Dropouts
        self.H_drop = nn.Dropout2d(p=self.dropout)
        self.C_drop = nn.Dropout2d(p=self.dropout)

        self.apply(initialize_weights)
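
    def forward(self, x, states=None):
        """Perform a single time step of the cell.

        (A sketch reconstructed from the paper's Eq. 3 and the usage
        below, with zero-initialized states when none are supplied.)

        Args:
            x: A batch of shape [Batch, Band, Dim, Dim].
            states: An optional (H, C) tuple from the previous step.

        Returns:
            The (H, C) tuple of the current step."""
        if self.batch_norm_layer is not None:
            x = self.batch_norm_layer(x)

        if states is None:
            # Zero states, matching the batch size of the input.
            H = x.new_zeros((x.shape[0],) + self.state_shape[1:])
            C = x.new_zeros((x.shape[0],) + self.state_shape[1:])
        else:
            H, C = states

        # Eq. 3: gates with Hadamard peephole connections.
        i = torch.sigmoid(self.W_xi(x) + self.W_hi(H) + self.W_ci(C))
        f = torch.sigmoid(self.W_xf(x) + self.W_hf(H) + self.W_cf(C))
        C = f * C + i * torch.tanh(self.W_xc(x) + self.W_hc(H))
        o = torch.sigmoid(self.W_xo(x) + self.W_ho(H) + self.W_co(C))
        H = o * torch.tanh(C)

        return self.H_drop(H), self.C_drop(C)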

A simple example of how the cell is used:

cell = ConvLSTMCell(
    input_bands=3,
    input_dim=16,
    kernels=32,      # matches the 32-channel outputs below
    dropout=0,
    batch_norm=False
)
batch = torch.ones((1, 3, 16, 16))
H, C = cell(batch)

print(f"H:{H.shape}")
print(f"C:{C.shape}")

>>> H:torch.Size([1, 32, 16, 16])
>>> C:torch.Size([1, 32, 16, 16])
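
And to seed the cell with your own initial states (the original question here), you can pass them in explicitly; with the forward sketched above, states is an (H, C) tuple:

h0 = torch.zeros(1, 32, 16, 16)
c0 = torch.zeros(1, 32, 16, 16)
H, C = cell(batch, states=(h0, c0))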

Then, to utilize the ConvLSTMCell in a complete LSTM module, you could for example use a model like this:

class ConvLSTM(nn.Module):

    def __init__(self, input_bands, input_dim, kernels, num_layers, bidirectional, dropout):
        super().__init__()
        self.input_bands = input_bands
        self.input_dim = input_dim
        self.kernels = kernels
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        
        self.layers_fwd = self.initialize_layers()
        self.layers_bwd = None
        if self.bidirectional:
            self.layers_bwd = self.initialize_layers()
        self.fc_output = nn.Sequential(
            nn.Flatten(),
            nn.Linear(
                in_features=self.kernels*self.input_dim**2*(1 if not self.bidirectional else 2), 
                out_features=1024
            ),
            nn.Linear(
                in_features=1024, 
                out_features=1
            )
        )
            
        self.apply(initialize_weights)
        
    def initialize_layers(self):
        """Initialize a single direction of the model's layers.
        
        This function takes care of stacking layers, allocating
        dropout and assigning correct input feature number for
        each layer in the stack."""
        
        layers = nn.ModuleList()
        
        for i in range(self.num_layers):
            layers.append(
                ConvLSTMCell(
                    input_bands=self.input_bands if i == 0 else self.kernels, 
                    input_dim=self.input_dim,
                    dropout=self.dropout if i+1 < self.num_layers else 0,
                    kernels=self.kernels,
                    batch_norm=False
                )
            )
            
        return layers
    
        
    def forward(self, x):
        """Perform forward pass with the model.
        
        For each item in the sequence, the data is propagated 
        through each layer and both directions, if possible.
        In case of a bidirectional model, the outputs are 
        concatenated from both directions. The output of the 
        last item of the sequence is further given to the FC
        layers to produce the final batch of predictions. 
        
        Args:
            x:  A batch of spatial data sequences. The data
                should be in the following format:
                [Batch, Seq, Band, Dim, Dim]
                    
        Returns:
            A batch of predictions."""
        
        seq_len = x.shape[1]

        # Per-layer (H, C) states, carried across the whole sequence.
        states_fwd = [None] * len(self.layers_fwd)
        states_bwd = [None] * len(self.layers_fwd)

        for seq_idx in range(seq_len):

            layer_in_out = x[:, seq_idx]
            for i, layer in enumerate(self.layers_fwd):
                layer_in_out, c = layer(layer_in_out, states_fwd[i])
                states_fwd[i] = (layer_in_out, c)

            if not self.bidirectional:
                continue

            # Backward direction consumes the sequence in reverse order.
            layer_in_out_bwd = x[:, seq_len - 1 - seq_idx]
            for i, layer in enumerate(self.layers_bwd):
                layer_in_out_bwd, c = layer(layer_in_out_bwd, states_bwd[i])
                states_bwd[i] = (layer_in_out_bwd, c)

        if self.bidirectional:
            layer_in_out = torch.cat((layer_in_out, layer_in_out_bwd), dim=1)

        return self.fc_output(layer_in_out)

Of course these are freely editable to better fit your purposes. The above ConvLSTM, for example, produces a many-to-one prediction: it ingests sequences and produces single values using only the last outputs of the sequential model (or models, in the case of bidirectionality).
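
As a quick smoke test of the full model (a sketch; it assumes the cell forward reconstructed above, and the hyperparameters are arbitrary):

model = ConvLSTM(
    input_bands=3,
    input_dim=16,
    kernels=32,
    num_layers=2,
    bidirectional=True,
    dropout=0.1
)
batch = torch.ones((4, 5, 3, 16, 16))  # [Batch, Seq, Band, Dim, Dim]
preds = model(batch)
print(preds.shape)

>>> torch.Size([4, 1])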

I sincerely hope this helps you and others tackling this topic!

Cheers,
Petteri


TypeError: initialize_weights() missing 1 required positional argument: 'layer'

Thanks for sharing your ConvLSTM implementation! Could you still share the BaseModule class that the ConvLSTMCell inherits? Currently, the description for the forward pass of the LSTM cell is missing. BR, mikko


Still having this issue? I tried testing and stumbled on the same error message.