Misunderstanding how detach() works with LSTM training

Hi all, I have a quick question about using detach() with LSTM layers on a long video. My problem is this: my video is around a minute long, so I need to perform backpropagation every 7 seconds or so. I’ve searched the forums, and most of the implementations I’ve seen clear the hidden states after performing backpropagation. What would happen if I didn’t clear those states and instead kept passing the hidden states along until the end of the video? Here’s an example:

My model is:

import torch
import torch.nn as nn
import torch.nn.functional as TF


class Classifier(nn.Module):

    def __init__(self, frame_length=70, dropout=0.5):
        super(Classifier, self).__init__()
        self.drop_rate = dropout

        self.conv = nn.Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1))
        self.bn = nn.BatchNorm2d(128)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # note: nn.LSTM's dropout only applies between layers, so with a
        # single layer this is a no-op (PyTorch warns about it)
        self.rnn = nn.LSTM(128, 128, 1, batch_first=False, dropout=dropout)

        self.fc = nn.Linear(128 * frame_length, 128)  # hidden size * number of frames
        self.output_layer = nn.Linear(128, 2)
        self.activation = nn.ReLU()


    def reset_states(self, num_layers=1, batch_size=1, hidden_size=128):
        # fresh zero states; no requires_grad_() needed, since these
        # initial states aren't being trained
        self.hidden1 = [
            torch.zeros(num_layers, batch_size, hidden_size),
            torch.zeros(num_layers, batch_size, hidden_size)
        ]

    def forward(self, input, detach=True):
        batch_size, timesteps, C, H, W = input.size()

        # fold time into the batch dimension for the convolutional layers
        output = input.view(batch_size * timesteps, C, H, W)

        output = self.conv(output)
        output = self.bn(output)
        output = self.activation(output)
        output = self.pool(output)

        # unfold back to (timesteps, batch, features) for the recurrent layer;
        # the permute keeps frames grouped by video when batch_size > 1
        output = output.view(batch_size, timesteps, -1).permute(1, 0, 2)
        output, hidden1 = self.rnn(output, self.hidden1)

        # Pass the hidden states on to the next forward pass
        if detach:
            self.hidden1 = [hidden1[0].detach(), hidden1[1].detach()]
        else:
            self.hidden1 = list(hidden1)

        # (timesteps, batch, hidden) -> (batch, timesteps * hidden)
        output = output.permute(1, 0, 2).reshape(batch_size, -1)
        output = self.activation(self.fc(output))
        output = TF.dropout(output, self.drop_rate, training=self.training)

        output = self.output_layer(output)

        return output

And then in my training loop:


for i in range(10):
    for c in range(len(dataset)):
        video, label = dataset[c]

        net.reset_states()  # only reset the hidden states here, before a new video is loaded

        # For brevity's sake, let's pretend this loop gets 70 frames out of the
        # video every iteration and applies any transforms. Let's also pretend
        # the video is 10 fps, making each chunk 7 seconds of video.
        for frames in video:
            batch = frames.unsqueeze(0)  # add batch dimension; batch size of 1 for simplicity
            output = net(batch)
            loss = criterion(output, label)
            loss.backward()
            optim.step()
            optim.zero_grad()

Sorry if there are any errors in the code; I’m just quickly making an example up out of thin air. What would happen in this case: would each new batch still have the context of the previous batch? I understand that gradients would only be tracked for the last 7 seconds of video, but wouldn’t the model also have the hidden states from the rest of the video to consider?
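To make sure I understand detach() itself, here’s a tiny toy check I put together (separate from the model above) of what I think it does: the hidden state keeps its values, but loses its autograd history:

import torch
import torch.nn as nn

rnn = nn.LSTM(4, 4)
h = (torch.zeros(1, 1, 4), torch.zeros(1, 1, 4))

# first "chunk"
out1, h = rnn(torch.randn(5, 1, 4), h)

# detach: same numbers, but the graph into chunk 1 is gone
h = (h[0].detach(), h[1].detach())
print(h[0].grad_fn)  # None

# second "chunk": backward() only reaches through these 5 steps,
# even though h still carries context from the first chunk
out2, h = rnn(torch.randn(5, 1, 4), h)
out2.sum().backward()

Is that intuition right, and does it carry over to the full model?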

I’ve also seen some implementations of TBPTT, which would work, but I’m not entirely certain how to apply them with convolutional layers and linear layers mixed in.
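From what I’ve read, detaching at every chunk boundary is itself the simplest form of TBPTT (the k1 = k2 case), so my guess at the loop is something like this rough sketch (reusing net, criterion, and optim from above; chunk_len is made up and has to match frame_length because of the fc layer):

chunk_len = 70  # frames per chunk, must equal frame_length

for c in range(len(dataset)):
    video, label = dataset[c]  # video: (total_frames, C, H, W)
    net.reset_states()

    for start in range(0, video.size(0) - chunk_len + 1, chunk_len):
        chunk = video[start:start + chunk_len].unsqueeze(0)  # add batch dim

        # detach=True cuts the graph at the chunk boundary, so backward()
        # only reaches back through this 7-second window; the conv and
        # linear layers just get gradients from the frames inside it
        output = net(chunk, detach=True)
        loss = criterion(output, label)

        optim.zero_grad()
        loss.backward()
        optim.step()

Does mixing in the convolutional and linear layers change anything here, or is that really all there is to it?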