Does a for loop inside a forward pass affect gradient computation?

I’m working on integrating dynamic batching into a Vision Transformer (ViT) + LSTM network. The model extracts features from a sequence of images, so the input tensor has shape (Batch x Sequence x C x H x W). Because sequence lengths vary, shorter sequences are padded with empty frames so that every sequence in the batch has the same length.
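To illustrate the padding step, here is a minimal collate function. It assumes each clip arrives as a `(T_i, C, H, W)` tensor; `collate_clips` is a hypothetical name (it could be passed as `collate_fn` to a `DataLoader`). `torch.nn.utils.rnn.pad_sequence` with `batch_first=True` stacks the clips and zero-fills the missing frames, and the real lengths are kept alongside:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_clips(clips):
    """Pad a list of (T_i, C, H, W) clips into one (B, T_max, C, H, W) tensor.

    Padded frames are all zeros ("empty frames"); `lengths` records how many
    frames of each clip are real.
    """
    lengths = torch.tensor([clip.shape[0] for clip in clips], dtype=torch.long)
    batch = pad_sequence(clips, batch_first=True, padding_value=0.0)
    return batch, lengths


# Three clips of 3x8x8 frames with 2, 5 and 3 frames respectively
clips = [torch.randn(n, 3, 8, 8) for n in (2, 5, 3)]
batch, lengths = collate_clips(clips)
print(batch.shape)  # torch.Size([3, 5, 3, 8, 8])
print(lengths)      # tensor([2, 5, 3])
```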

To avoid unnecessary computation on these padded frames, which could also negatively affect gradient computation, I’ve implemented a loop that applies the ViT feature extractor only to the valid frames of each sequence.
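For comparison, here is a sketch of selecting the valid frames without a per-sample loop: a boolean mask built from `lengths` gathers every real frame into one `(N, C, H, W)` batch, the extractor runs once, and the features are scattered back into a zero-filled `(B, T, D)` tensor. `extract_valid_frames` and the toy `nn.Sequential` extractor are hypothetical stand-ins; with the ViT from the question, the extractor argument would be something like `lambda f: self.feature_extractor(f)[0][:, 0, :]`.

```python
import torch
import torch.nn as nn


def extract_valid_frames(x, lengths, extractor):
    """Run `extractor` once on all non-padded frames and scatter results back.

    x: (B, T, C, H, W); lengths: (B,) long tensor.
    extractor: maps (N, C, H, W) -> (N, D).
    Returns (B, T, D) features with zeros at padded time steps.
    """
    B, T = x.shape[:2]
    lengths = lengths.to(x.device)
    # mask[b, t] is True for real frames, False for padding
    mask = torch.arange(T, device=x.device)[None, :] < lengths[:, None]
    valid_feats = extractor(x[mask])  # (N, D), N = lengths.sum()
    features = valid_feats.new_zeros(B, T, valid_feats.shape[-1])
    features[mask] = valid_feats      # masked assignment is differentiable
    return features


# Usage with a toy extractor standing in for the ViT (hypothetical)
extractor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
x = torch.randn(2, 4, 3, 8, 8)
lengths = torch.tensor([4, 2])
features = extract_valid_frames(x, lengths, extractor)
print(features.shape)  # torch.Size([2, 4, 16])
```

Both the boolean gather and the masked assignment are tracked by autograd, so the extractor still receives gradients from every real frame. The trade-off is memory: all N valid frames go through the extractor in a single call, whereas the loop processes one sequence at a time.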

My concern is how this approach might affect fine-tuning of the feature extractor, particularly the gradient computation.

I’m curious whether applying the feature extractor to different frames in a loop could affect the gradient calculations.

Below is the code snippet for reference:

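```python
import torch
import torch.nn as nn

# ViTLSTM is the base model defined elsewhere in the project


class ViTLSTM_DynamicBatching(ViTLSTM):

    def forward(self, x, lengths):
        # x: (batch, seq_len, C, H, W); lengths: 1-D tensor of valid frame counts
        batch_size, seq_len, channels, height, width = x.shape

        # Longest real sequence in this batch, as a Python int
        max_length = int(lengths.max())

        # Placeholder for extracted features (same dtype/device as x);
        # padded time steps stay zero
        features = x.new_zeros((batch_size, max_length, self.lstm_input_size))

        # Apply the feature extractor to the non-padded frames of each sequence
        for i in range(batch_size):
            n = int(lengths[i])
            # (n, C, H, W): already the shape the feature extractor expects
            sequence_frames = x[i, :n]
            # Keep the first-token ([CLS]) embedding of the last hidden state
            sequence_features = self.feature_extractor(sequence_frames)[0][:, 0, :]
            # Slice assignment is recorded by autograd, so gradients flow back
            # to the feature extractor for every valid frame
            features[i, :n] = sequence_features

        # features is already (batch_size, max_length, lstm_input_size), the
        # layout pack_padded_sequence expects with batch_first=True
        lengths_cpu = lengths.to("cpu").long()
        packed_features = nn.utils.rnn.pack_padded_sequence(
            features, lengths_cpu, batch_first=True, enforce_sorted=False
        )

        # Apply the LSTM (attribute spelled `ltsm`, as in the parent ViTLSTM)
        packed_output, _ = self.ltsm(packed_features)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True
        )

        # Mean-pool over valid time steps only: padded outputs are zeros, so the
        # sum is unaffected, but torch.mean would also divide by padded steps
        valid_steps = lengths_cpu.to(x.device).unsqueeze(1)
        pooled_out = lstm_out.sum(dim=1) / valid_steps

        # Apply the fully connected layer
        outputs = self.fc(pooled_out)
        return outputs
```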
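On the pooling step: `pad_packed_sequence` fills the steps beyond each sequence's length with `padding_value` (0.0 by default), so a plain `torch.mean(..., dim=1)` divides by the padded length and shrinks the pooled output of shorter sequences; their gradients are scaled down in the same way. A minimal length-aware mean, with `masked_mean` as a hypothetical helper name:

```python
import torch


def masked_mean(padded, lengths):
    """Average (B, T, D) outputs over the first lengths[b] steps of each row."""
    lengths = lengths.to(padded.device)
    steps = torch.arange(padded.shape[1], device=padded.device)
    mask = steps[None, :] < lengths[:, None]            # (B, T) valid steps
    summed = (padded * mask.unsqueeze(-1)).sum(dim=1)   # ignore padded steps
    return summed / lengths.unsqueeze(1).to(padded.dtype)


# Two sequences padded to length 4; the second has only 2 real steps
out = torch.tensor([[[1.0], [1.0], [1.0], [1.0]],
                    [[2.0], [2.0], [0.0], [0.0]]])
lengths = torch.tensor([4, 2])
print(out.mean(dim=1))            # row 2 averages to 1.0: padding is counted
print(masked_mean(out, lengths))  # row 2 averages to 2.0
```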

Note that in production, the plan is to pass the features stored in a buffer directly to the LSTM network, so that feature extraction and the LSTM can run in parallel.