Why is my LSTM model not learning?

Hi, I’m trying to implement a spatio-temporal LSTM (ST-LSTM) model for human action recognition from 3D skeleton data, based on this article: Spatio-Temporal LSTM with Trust Gates for 3D Human Action Recognition | SpringerLink.

The neural network model and the equations for a single ST-LSTM cell look like this:
[image: ST-LSTM network diagram and single-cell equations from the paper]
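
In case the image doesn’t load: my reading of the single-cell equations, written to match the implementation below (sigmoid gates; ⊙ is element-wise multiplication; j indexes joints, t indexes frames):

gates  = w_ih · x(j,t) + w_hh0 · h(j-1,t) + w_hh1 · h(j,t-1) + bias    (split into 5 chunks)
i = sigmoid(gates_1),  f_S = sigmoid(gates_2),  f_T = sigmoid(gates_3),  o = sigmoid(gates_4),  u = tanh(gates_5)
c(j,t) = i ⊙ u + f_S ⊙ c(j-1,t) + f_T ⊙ c(j,t-1)
h(j,t) = o ⊙ tanh(c(j,t))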

As input to each ST-LSTM cell I pass the hidden and cell states from the previous ST-LSTM cells in the temporal and spatial dimensions, together with the (x, y, z) coordinates of a single joint. The shapes of the input tensors are:

torch.Size([BATCH_SIZE, 128]) - hidden state
torch.Size([BATCH_SIZE, 128]) - cell state
torch.Size([BATCH_SIZE, 3]) - input (x, y, z) coordinate of single joint

I implemented the ST-LSTM cell as follows:

from collections import namedtuple
from typing import Tuple

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn.modules.rnn import RNNCellBase

# The state carries hidden and cell states from both predecessors:
# the previous joint in the same frame (spatial) and the same joint
# in the previous frame (temporal).
STLSTMState = namedtuple('STLSTMState', ['h_temp_prev', 'h_spat_prev', 'c_temp_prev', 'c_spat_prev'])


class STLSTMCell(RNNCellBase):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        # num_chunks=5: input gate, spatial forget gate, temporal forget gate,
        # output gate and candidate (u) are computed in one fused projection.
        super(STLSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=5)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.w_ih = Parameter(torch.randn(5 * hidden_size, input_size))    # input weights
        self.w_hh0 = Parameter(torch.randn(5 * hidden_size, hidden_size))  # spatial recurrent weights
        self.w_hh1 = Parameter(torch.randn(5 * hidden_size, hidden_size))  # temporal recurrent weights
        self.b_ih = Parameter(torch.randn(5 * hidden_size))
        self.b_hh0 = Parameter(torch.randn(5 * hidden_size))
        self.b_hh1 = Parameter(torch.randn(5 * hidden_size))

    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        self.check_forward_input(input)
        self.check_forward_hidden(input, state[0], '[0]')
        self.check_forward_hidden(input, state[1], '[1]')
        self.check_forward_hidden(input, state[2], '[2]')
        self.check_forward_hidden(input, state[3], '[3]')
        return self.lstm_cell(
            input, state,
            self.w_ih, self.w_hh0, self.w_hh1,
            self.b_ih, self.b_hh0, self.b_hh1,
        )

    def lstm_cell(self, input, state, w_ih, w_hh0, w_hh1, b_ih, b_hh0, b_hh1):
        # type: (Tensor, Tuple[Tensor, Tensor, Tensor, Tensor], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
        h_temp_prev, h_spat_prev, c_temp_prev, c_spat_prev = state

        # One fused projection of the input and both recurrent streams.
        gates = (torch.mm(input, w_ih.t())
                 + torch.mm(h_spat_prev, w_hh0.t())
                 + torch.mm(h_temp_prev, w_hh1.t()))
        gates += b_ih + b_hh0 + b_hh1

        # Split into the five pre-activations and apply the nonlinearities.
        in_gate, forget_gate_s, forget_gate_t, out_gate, u_gate = gates.chunk(5, 1)

        in_gate = torch.sigmoid(in_gate)
        forget_gate_s = torch.sigmoid(forget_gate_s)
        forget_gate_t = torch.sigmoid(forget_gate_t)
        out_gate = torch.sigmoid(out_gate)
        u_gate = torch.tanh(u_gate)

        # New cell state mixes the candidate with both predecessor cell states.
        cy = (in_gate * u_gate) + (forget_gate_s * c_spat_prev) + (forget_gate_t * c_temp_prev)
        hy = out_gate * torch.tanh(cy)

        return hy, cy
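
For what it’s worth, the cell itself seems to wire up correctly; a quick smoke test with the shapes listed above (batch size 5, as in my setup) runs fine:

BATCH_SIZE, INPUT_SIZE, HIDDEN_SIZE = 5, 3, 128

cell = STLSTMCell(INPUT_SIZE, HIDDEN_SIZE)
joint = torch.randn(BATCH_SIZE, INPUT_SIZE)  # (x, y, z) of one joint
zeros = torch.zeros(BATCH_SIZE, HIDDEN_SIZE)
state = STLSTMState(h_temp_prev=zeros, h_spat_prev=zeros,
                    c_temp_prev=zeros, c_spat_prev=zeros)

h, c = cell(joint, state)
print(h.shape, c.shape)  # both torch.Size([5, 128])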

My model:

import torch.nn as nn
import torch.nn.functional as F


class STLSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, batch_size, dropout, classes_count):
        super(STLSTMModel, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.st_lstm_cell_1 = STLSTMCell(input_size, hidden_size)
        self.st_lstm_cell_2 = STLSTMCell(hidden_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, classes_count)
        self.dropout_l = nn.Dropout(dropout)

    def forward(self, input, state):
        h_next, c_next = self.st_lstm_cell_1(input, state)
        h_next = self.dropout_l(h_next)
        # Note: the second cell currently receives the same state tuple as the first.
        out, c_next = self.st_lstm_cell_2(h_next, state)
        out = self.fc(out)
        return h_next, c_next, F.log_softmax(out, dim=-1)
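
Calling the model for a single joint at a single step looks like this (a sketch; classes_count=10 here is just a placeholder, my real class count comes from len(classes)):

model = STLSTMModel(input_size=3, hidden_size=128, batch_size=5,
                    dropout=0.5, classes_count=10)  # placeholder class count

joint = torch.randn(5, 3)
zeros = torch.zeros(5, 128)
state = STLSTMState(zeros, zeros, zeros, zeros)

h_next, c_next, log_probs = model(joint, state)
print(log_probs.shape)  # torch.Size([5, 10]), one row of log-probabilities per sample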

My training loop:

learning_rate = 0.002
momentum = 0.9
weight_decay = 0.95  # L2 penalty applied by SGD at every step
dropout = 0.5

st_lstm_model = STLSTMModel(input_size, hidden_size, batch_size, dropout, len(classes)).to(device)

criterion = nn.NLLLoss()
optimizer = optim.SGD(st_lstm_model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

all_losses = []

for epoch in range(epoch_nb):
    data, train_y = get_data()
    tensor_train_y = torch.from_numpy(np.array(train_y)).to(device)

    optimizer.zero_grad()

    joints_count = data.shape[1]
    spatial_dim = joints_count
    temporal_dim = data.shape[0]

    # cell1_out[t][j] holds the [hidden, cell] states produced at frame t, joint j.
    cell1_out = [[[None, None] for _ in range(spatial_dim)] for _ in range(temporal_dim)]

    losses_arr = []

    for t in range(temporal_dim):
        for j in range(spatial_dim):
            # First joint in a frame has no spatial predecessor: use zero states.
            if j == 0:
                h_spat_prev = torch.zeros(batch_size, hidden_size).to(device)
                c_spat_prev = torch.zeros(batch_size, hidden_size).to(device)
            else:
                h_spat_prev = cell1_out[t][j - 1][0]
                c_spat_prev = cell1_out[t][j - 1][1]
            # First frame has no temporal predecessor: use zero states.
            if t == 0:
                h_temp_prev = torch.zeros(batch_size, hidden_size).to(device)
                c_temp_prev = torch.zeros(batch_size, hidden_size).to(device)
            else:
                h_temp_prev = cell1_out[t - 1][j][0]
                c_temp_prev = cell1_out[t - 1][j][1]
            state = STLSTMState(h_temp_prev, h_spat_prev, c_temp_prev, c_spat_prev)
            joint_input = torch.tensor(data[t][j], dtype=torch.float, device=device)
            h_next, c_next, output = st_lstm_model(joint_input, state)
            cell1_out[t][j][0] = h_next
            cell1_out[t][j][1] = c_next
            # A classification loss is accumulated at every (t, j) step.
            losses_arr.append(criterion(output, tensor_train_y))

    # Average the per-step losses over the whole spatio-temporal grid.
    loss = torch.stack(losses_arr).mean()

    loss.backward()
    optimizer.step()
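
One generic sanity check I added after loss.backward() (not from the article) is printing gradient norms, to see whether anything through the long spatio-temporal chain vanishes or explodes:

total_norm = 0.0
for name, p in st_lstm_model.named_parameters():
    if p.grad is not None:
        norm = p.grad.norm().item()
        total_norm += norm ** 2
        print(f'{name}: grad norm {norm:.4e}')
print(f'total grad norm: {total_norm ** 0.5:.4e}')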

I implemented this exactly as described in the article, but my training loss is not falling: even after 5K iterations it stays around 2.4.

I also tried the same data with my custom implementation of a plain LSTMCell, where I passed as input a tensor with the entire action sequence (20 frames) containing all keypoints (12 keypoints × (x, y, z) → 36 inputs):

torch.Size([5, 128]) - hidden state
torch.Size([5, 128]) - cell state
torch.Size([5, 20, 36]) - input

and it works like a charm: the loss is around 0.1 after 1K iterations and around 0.0001 after 5K.
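
For reference, that baseline loop looks roughly like this (a sketch using the built-in torch.nn.LSTMCell in place of my custom cell, which I haven’t pasted here):

lstm_cell = nn.LSTMCell(36, 128).to(device)
fc = nn.Linear(128, len(classes)).to(device)

h = torch.zeros(5, 128, device=device)
c = torch.zeros(5, 128, device=device)

seq = torch.randn(5, 20, 36, device=device)  # batch of 5 sequences, 20 frames, 36 features
for t in range(seq.size(1)):
    h, c = lstm_cell(seq[:, t, :], (h, c))  # one frame (all joints flattened) per step

log_probs = F.log_softmax(fc(h), dim=-1)  # classify from the last hidden state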

Do you know why my ST-LSTM model is not learning? I have spent several days trying to debug it and I have no idea what could be wrong.