PyTorch: GRU memory leak

I am struggling with a memory leak when using a GRU. Here is the demo code:

import tqdm
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam

class MyModel(nn.Module):
    """Linear encoder -> GRU cell -> masked sum over sorted per-head logits.

    Args:
        input_shape: Number of input features consumed by ``fc1``.
        hidden_dim: Width of the ``fc1`` output and of the GRU hidden state
            (previously hard-coded to 32).
        n_heads: Number of groups the ``fc2`` logits are split into
            (previously hard-coded to 10).
        n_quantiles: Number of logits per group (previously hard-coded to 55).
    """

    def __init__(self, input_shape=348, hidden_dim=32, n_heads=10, n_quantiles=55):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.n_quantiles = n_quantiles
        # A registered buffer (rather than a plain attribute with a hard-coded
        # .cuda()) follows model.to(device) / .cuda() / .cpu().
        # persistent=False keeps state_dict() keys unchanged, so existing
        # checkpoints still load.
        self.register_buffer(
            "arrange", th.arange(n_quantiles).int(), persistent=False
        )

        self.fc1 = nn.Linear(input_shape, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_heads * n_quantiles)

    def forward(self, inputs, rnn_state):
        """Run one recurrent step.

        Args:
            inputs: Features for ``fc1``; presumably (batch, input_shape)
                -- TODO confirm for other ranks.
            rnn_state: Previous hidden state; any tensor that reshapes to
                (batch, hidden_dim).

        Returns:
            ``(values, h)``.
            ``values`` has shape (batch, n_heads): for each head, it is the
            sum of the ``k + 1`` smallest logits, where ``k`` is derived
            from the position of the overall argmax.
            ``h`` is the new hidden state, of shape (batch, hidden_dim).
        """
        x = F.relu(self.fc1(inputs))
        h = self.rnn(x, rnn_state.reshape(-1, self.hidden_dim))
        logits = self.fc2(h)

        # Map the position of the maximum (over all n_heads * n_quantiles
        # logits) to a fraction alpha in (0, 1], then to a cut-off index
        # k = ceil(alpha * n_quantiles - 1), which lies in [0, n_quantiles - 1].
        # Softmax is monotonic, so argmax(w) is argmax(logits); it is kept to
        # preserve the original numerics exactly.
        w = th.softmax(logits, dim=-1)
        alpha = th.div((th.argmax(w, dim=-1) + 1).float(), float(w.shape[-1]))
        alpha_new = th.ceil(alpha * self.n_quantiles - 1).int()

        # Sort each head's logits ascending and keep positions 0..k.
        values, _ = th.sort(
            logits.view(-1, self.n_heads, self.n_quantiles), dim=2
        )
        masks = (self.arrange <= alpha_new[..., None]).float()
        masks = masks[:, None, :]  # broadcast the same cut-off over all heads
        values = th.sum(values * masks, dim=2)
        return values, h

def main(n_iters=10000):
    """Train MyModel on random data, carrying the GRU state across iterations.

    The hidden state is detached at the end of every iteration (truncated
    backpropagation through time). Without that, each new ``hidden_in`` holds
    a reference to the previous iteration's autograd graph. The chained graph,
    and the GPU memory it pins, then grows without bound, and a plain
    ``loss.backward()`` fails on the second iteration because it would
    traverse already-freed parts of the graph.

    Args:
        n_iters: Number of training iterations (defaults to the original
            10000).
    """
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    # The model must live on the same device as its inputs.
    model = MyModel(input_shape=348).to(device)

    hidden_in = th.rand(1, 32, device=device)
    hidden_in = hidden_in.unsqueeze(0).expand(32, 27, -1)
    optimiser = Adam(params=model.parameters())
    for _ in range(n_iters):
        inputs = th.rand(27 * 32, 348, device=device)
        inputs2 = th.rand(27 * 32, 348, device=device)
        outputs, hidden_in = model(inputs, hidden_in)
        # The target branch was detached anyway; skip building its graph.
        with th.no_grad():
            outputs2, _ = model(inputs2, hidden_in)

        loss = th.mean((outputs - outputs2) ** 2)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        # Cut the graph so the next iteration starts from a leaf hidden state.
        hidden_in = hidden_in.detach()

if __name__ == '__main__':
    # Script entry point; the guard previously had an empty body (SyntaxError).
    main()

While training, GPU memory usage keeps increasing. However, after removing the RNN, GPU memory usage is stable.

If you use

    loss.backward(retain_graph=True)

in every loop iteration, it should not be surprising that nothing gets freed. That follows from the literal meaning of "retain" — not to say you should have known that before, but it is a good thing to remember.
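If you don't pass `retain_graph=True`, autograd raises an error along the lines of "Trying to backward through the graph a second time", and that error is correct. Because `hidden_in` is never detached, each iteration's graph is chained to all the previous ones, so every `backward()` tries to traverse the earlier, already-freed parts again.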
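In the end, you need to detach `hidden_in` at the end of each loop iteration, before starting the next one (`hidden_in = hidden_in.detach()`), so that autograd "forgets what has happened" (truncated backpropagation through time). A plain `loss.backward()` then works, and memory usage stays flat.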

Best regards