Detaching parameters and ballooning memory

EDIT: don’t .detach_() parameters without also resetting param.requires_grad = True afterwards.
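
In other words (a sketch of the fix, not how the script below was actually run): detach_() also sets requires_grad to False on the parameter, so if the in-place detach is kept, reset_hidden needs one extra line:

    def reset_hidden(self):
        self.hidden = None
        if self.detach:
            self.weight_ih.detach_()
            # detach_() has just set requires_grad to False; without this line
            # weight_ih silently stops being trained
            self.weight_ih.requires_grad = True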

I have a very weird case of ballooning memory with a custom-built RNN unit. The memory usage balloons during each batch unless I .detach_() the RNN unit’s parameters between epochs. The same thing happens if I put a Linear layer before the RNN unit.

I run my models on the CPU with PyTorch 0.3.1.post2, and I “measure” the memory use by eyeballing my memory monitor widget.
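
(For a less hand-wavy number than the widget, peak RSS can be read from Python’s standard library; this snippet is illustrative only and is not used by the timing runs below. On Linux ru_maxrss is reported in kilobytes, on macOS in bytes.)

import resource

def peak_rss_mb():
    # peak resident set size of this process so far (kilobytes on Linux)
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0

print("peak RSS so far: %.1f MB" % peak_rss_mb())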

During a normal run my full script uses ~500 MB of memory, but under the conditions above the usage climbs by roughly 4 GB during the first batch; it drops at the end of each batch, only to climb back up during the next one.

I tried to code a minimal version, but I can’t get its memory usage to balloon. However, the training time doubles or triples under the same conditions that cause the memory to balloon with my full code.

Here it is…

import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim

class TRNN(nn.Module):
    """Strongly typed RNN from https://arxiv.org/abs/1602.02218 with the bias removed"""
    def __init__(self, input_size, hidden_size, detach=True):
        super(TRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.detach = detach
        self.weight_ih = nn.Parameter(torch.Tensor(2 * hidden_size, input_size))
        self.reset_parameters()
        
    def __repr__(self):
        s = '{name}({input_size}, {hidden_size}, detach={detach})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        self.weight_ih.data.uniform_(-stdv, stdv)
        self.hidden = None  # hidden state is created lazily in forward()
    
    def reset_hidden(self):
        # called once per epoch: clear the hidden state and, optionally,
        # detach the weight in place (this is what the EDIT above is about)
        self.hidden = None
        if self.detach:
            self.weight_ih.detach_()
    
    def detach_hidden(self):
        # cut the autograd graph hanging off the hidden state so it does
        # not keep growing across batches
        self.hidden.detach_()

    def forward(self, input_data):
        timesteps, batch_size, features = input_data.size()
        outputs = Variable(torch.zeros(timesteps, batch_size, self.hidden_size))
        
        # the hidden state is created lazily after each reset_hidden()
        if self.hidden is None:
            self.hidden = Variable(torch.zeros(batch_size, self.hidden_size))
        
        for i, input_t in enumerate(input_data.split(1)):
            
            # one input-to-hidden projection, split into new-value and forget-gate halves
            gi = F.linear(input_t.view(batch_size, features), self.weight_ih, None)
            i_n, i_f = gi.chunk(2, 1)

            forgetgate = F.sigmoid(i_f)
            newgate = i_n
            # T-RNN update: interpolate between the new value and the previous hidden state
            self.hidden = newgate + forgetgate * (self.hidden - newgate)
            outputs[i] = self.hidden
        
        return outputs


def reset_hidden(layer):
    # used with model.apply() so the call reaches the TRNN inside nn.Sequential
    if hasattr(layer, "reset_hidden"):
        layer.reset_hidden()

def detach_hidden(layer):
    if hasattr(layer, "detach_hidden"):
        layer.detach_hidden()


def train(model, optimizer, batches):
    start = time.time()
    for epoch in range(5):
        model.apply(reset_hidden)  # reset (and optionally detach) once per epoch
        for inputs, targets in batches:
            output = model(Variable(inputs))
            optimizer.zero_grad()
            loss = F.mse_loss(output, Variable(targets))
            loss.backward()
            optimizer.step()
            model.apply(detach_hidden)  # truncate the hidden state's graph after every batch
    
    print("Training done in", time.time()-start, "seconds")
    print()


batch_size = 32
seq_len = 500
features = 100
targets = 10

batches = []
for i in range(10):
    batches.append((
        torch.randn(batch_size, seq_len, features),
        torch.randn(batch_size, seq_len, targets)
        ))

print("Don't detach parameters")
model = nn.Sequential(
    TRNN(features, 200, detach=False), 
    nn.Linear(200, targets))
optimizer = optim.Adam(model.parameters(), lr=0.01)
train(model, optimizer, batches)

print("Do detach parameters")
model = nn.Sequential(
    TRNN(features, 200, detach=True), 
    nn.Linear(200, targets))
optimizer = optim.Adam(model.parameters(), lr=0.01)
train(model, optimizer, batches)

print("Linear first and do detach")
model = nn.Sequential(
    nn.Linear(features, 100),
    TRNN(features, 200, detach=True), 
    nn.Linear(200, targets))
optimizer = optim.Adam(model.parameters(), lr=0.01)
train(model, optimizer, batches)

Output on my lowly old laptop…

Don't detach parameters
Training done in 33.29167413711548 seconds

Do detach parameters
Training done in 13.432651281356812 seconds

Linear first and do detach
Training done in 40.65597105026245 seconds

Any ideas?

If no one can tell me what mistake I have made, I shall assume this is a bug and cross-post it as a GitHub issue.