CNN-RNN - how to manage gradients?


I’m building an action recognition system that combines a CNN with an RNN. The CNN is ShuffleNet v2, and I replaced the model’s fc layer with a custom nn.Sequential:

    rcnn_layer = nn.Sequential(
        nn.Linear(1024, 128),
        View((1, 128)),
        nn.RNN(input_size=128, hidden_size=64),
        nn.Linear(64, 3),
    )

The custom classes are:
View - reshapes the output of the first linear layer so that it fits the RNN
InputModifier - retains previous outputs of the model
GetLastHidden - returns last layer of the RNN
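
Since the definitions of View and GetLastHidden are not shown, here is a minimal sketch of what such modules might look like. The bodies are assumptions based only on the descriptions above: View is assumed to call Tensor.view, and GetLastHidden is assumed to unpack the (output, h_n) tuple that nn.RNN returns and keep the hidden state of the final layer. InputModifier is left out, and the (1, 1, 128) shape, read as (seq_len, batch, features), is chosen purely for illustration.

```python
import torch
import torch.nn as nn


class View(nn.Module):
    """Hypothetical reshape module: views the input as a fixed shape."""

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class GetLastHidden(nn.Module):
    """Hypothetical adapter: nn.RNN returns (output, h_n); keep h_n of the last layer."""

    def forward(self, rnn_out):
        _, h_n = rnn_out
        return h_n[-1]  # shape: (batch, hidden_size)


# Assumed wiring for a single time step; InputModifier is omitted here.
head = nn.Sequential(
    nn.Linear(1024, 128),
    View((1, 1, 128)),  # (seq_len=1, batch=1, features=128)
    nn.RNN(input_size=128, hidden_size=64),
    GetLastHidden(),
    nn.Linear(64, 3),
)

out = head(torch.randn(1, 1024))
print(out.shape)
```

An adapter like GetLastHidden is needed inside nn.Sequential because nn.RNN returns a tuple, which the following nn.Linear cannot consume directly.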

Eventually I’d like to run the model on a phone, so I’m trying to make it as efficient as I can. To make it run faster, my InputModifier class retains the previous outputs of the CNN. Here’s the code:

    class InputModifier(nn.Module):
        prev = []

        def __init__(self, max_seq_len):
            assert max_seq_len != 0, '`max_seq_len` cannot be 0.'
            super(InputModifier, self).__init__()
            self.max_seq_len = max_seq_len

        def forward(self, x):
            if self.max_seq_len > 0:
                self.prev = self.prev[-self.max_seq_len:]

            inp =
            return inp

The problem is that I can’t get training to work. When I run loss.backward(), I get this error: RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed.
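
The first error can be reproduced in isolation with a sketch like the one below. It assumes the cause is that tensors kept from an earlier iteration still point at that iteration's autograd graph, which is freed after the first backward call. The truncated forward above does not confirm this. The names cnn, cache and run_two_steps are hypothetical stand-ins, not the real model. The detach_cache flag shows one possible workaround: detaching cached features so that only the current frame carries gradients.

```python
import torch
import torch.nn as nn


def run_two_steps(detach_cache):
    """Run two iterations; return the error message from backward, or None."""
    torch.manual_seed(0)
    cnn = nn.Sequential(nn.Linear(8, 4), nn.Tanh())  # stand-in for the CNN features
    cache = []  # stand-in for InputModifier.prev

    for _ in range(2):
        feat = cnn(torch.randn(1, 8))
        # Optionally cut cached entries off from the graph that produced them.
        history = [f.detach() for f in cache] if detach_cache else list(cache)
        seq = torch.stack(history + [feat])
        cache.append(feat)
        try:
            seq.sum().backward()  # frees the saved tensors of this graph
        except RuntimeError as err:
            return str(err)
    return None


print(run_two_steps(detach_cache=False))  # second backward re-enters a freed graph
print(run_two_steps(detach_cache=True))   # cached features no longer carry a graph
```

The trade-off with detaching is that gradients reach the CNN only through the most recent frame; the cached frames act as constants.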

If I set loss.backward(retain_graph=True), I get a different error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1024, 128]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
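
A plausible explanation for this second error is that the optimizer step updates the weights in place while a retained graph still holds a reference to their old version. The sketch below illustrates that mechanism under assumptions: it uses two hypothetical nn.Linear layers instead of the real model, and SGD instead of Adam, to keep it small.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in model: the second layer's weight is saved by autograd for backward.
net = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
optim = torch.optim.SGD(net.parameters(), lr=0.1)

old_out = net(torch.randn(1, 8))            # graph built with the current weights
old_out.sum().backward(retain_graph=True)   # graph and saved tensors are kept
optim.step()                                # weights are updated in place

error_message = None
try:
    # A later loss that still depends on the old graph, as a cached output would.
    old_out.sum().backward()
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

If this is what happens in the real model, retain_graph=True only moves the failure later, because the retained graph becomes stale as soon as the parameters change.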

PS: The training code:

    model =
    optim = torch.optim.Adam(lr=0.01, params=model.parameters())
    criterion = nn.CrossEntropyLoss()
    for i, (img, c) in enumerate(dataloader):

        out = model(img)
        loss = criterion(out, c.view(1))


Any advice is welcome. Thanks.