Out of memory with only a few samples

Hello there

I am experimenting with meta-learning and copied the program from https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a.

The problem is that after a certain number of iterations I get an OOM message.

The model is not really large and I am testing with the MNIST dataset, so I suppose it should be fine.
In the original code, loss.backward() is called without retain_graph=True, but if I remove that argument from my code I get an error because the graph/gradients are no longer available for the meta-backward pass. I would expect the memory to be released after losses = [], but it seems that it is not.
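
To make my point about retain_graph clearer, here is a minimal toy sketch of the double-backward pattern I mean (just dummy tensors, nothing from my actual model):

import torch

w = torch.randn(3, requires_grad=True)   # stands in for the model parameters
x = torch.randn(3)

losses = []
for _ in range(3):
    loss = (w * x).sum()
    loss.backward(retain_graph=True)  # without retain_graph=True this backward frees
                                      # the graph of this loss and the meta-backward below fails
    losses.append(loss)               # keeping the tensor keeps its graph alive

meta_loss = sum(losses)
meta_loss.backward()                  # needs the graphs of all stored losses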

Any advice on how to solve this problem?

Regards.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from sys import stdout


class MetaLearner(nn.Module):
    """ Bare Meta-learner class
        Should be added: initialization, hidden states, more control over everything
    """
    def __init__(self, model):
        super(MetaLearner, self).__init__()
        self.weights = nn.Parameter(torch.Tensor(1, 2))

    def forward(self, forward_model, backward_model):
        """ Forward optimizer with a simple linear neural net
        Inputs:
            forward_model: PyTorch module with parameters gradient populated
            backward_model: PyTorch module identical to forward_model (but without gradients),
              updated at the Parameter level to keep track of the computation graph for the meta-backward pass
        """
        f_model_iter = get_params(forward_model)
        b_model_iter = get_params(backward_model)
        for f_param_tuple, b_param_tuple in zip(f_model_iter, b_model_iter): # loop over parameters
            # Prepare the inputs, we detach the inputs to avoid computing 2nd derivatives (re-pack in new Variable)
            (module_f, name_f, param_f) = f_param_tuple
            (module_b, name_b, param_b) = b_param_tuple
            inputs = torch.autograd.Variable(torch.stack([param_f.grad.data, param_f.data], dim=-1))
            dims = len(inputs.data.shape)
            # Optimization step: compute new model parameters, here we apply a simple linear function
            dW = F.linear(inputs, self.weights).squeeze(dims-1)
            param_b = param_b + dW
            # Update backward_model (meta-gradients can flow) and forward_model (no need for meta-gradients).
            module_b._parameters[name_b] = param_b
            param_f.data = param_b.data


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def get_params(module, memo=None, pointers=None):
    """ Returns an iterator over PyTorch module parameters that allows to update parameters
        (and not only the data).
    ! Side effect: update shared parameters to point to the first yield instance
        (i.e. you can update shared parameters and keep them shared)
    Yields:
        (Module, string, Parameter): Tuple containing the parameter's module, name and pointer
    """
    if memo is None:
        memo = set()
        pointers = {}
    for name, p in module._parameters.items():
        if p not in memo:
            memo.add(p)
            pointers[p] = (module, name)
            yield module, name, p
        elif p is not None:
            prev_module, prev_name = pointers[p]
            module._parameters[name] = prev_module._parameters[prev_name] # update shared parameter pointer
    for child_module in module.children():
        for m, n, p in get_params(child_module, memo, pointers):
            yield m, n, p


def train(forward_model, backward_model, optimizer, meta_optimizer, train_data, meta_epochs, device):
  """ Train a meta-learner
  Inputs:
    forward_model, backward_model: Two identical PyTorch modules (can have shared Tensors)
    optimizer: a neural net to be used as optimizer (an instance of the MetaLearner class)
    meta_optimizer: an optimizer for the optimizer neural net, e.g. ADAM
    train_data: an iterator over an epoch of training data
    meta_epochs: meta-training steps
  To be added: initialization, early stopping, checkpointing, more control over everything
  """
  forward_model.train()
  backward_model.train()
  optimizer.train()
  for meta_epoch in range(meta_epochs): # Meta-training loop (train the optimizer)
    optimizer.zero_grad()
    losses = []
    print('#### Epoch: {}'.format(meta_epoch))
    for batch_idx, (inputs, labels) in enumerate(train_data):
        #for inputs, inputs in train_data:   # Meta-forward pass (train the model)
        stdout.write('Batch :{}\r'.format(batch_idx))
        stdout.flush()
        forward_model.zero_grad()         # Forward pass
        inputs, labels = inputs.to(device), labels.to(device)
        output = forward_model(inputs)
        loss = F.nll_loss(output, labels)
        losses.append(loss)
        loss.backward(retain_graph=True)                   # Backward pass to add gradients to the forward_model
        optimizer(forward_model,          # Optimizer step (update the models)
                  backward_model)
        if batch_idx % 50:
            meta_loss = sum(losses)             # Compute a simple meta-loss
            meta_loss.backward()                # Meta-backward pass
            meta_optimizer.step()               # Meta-optimizer step
            print(float(meta_loss.data))
            optimizer.zero_grad()
            losses = []
        if batch_idx > 1000:
            exit()
    # print(loss.data)
    


def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    model_fwd = Net().to(device)
    model_bwd = Net().to(device)
    optimizer = MetaLearner(model_fwd).to(device)
    meta_opt = optim.SGD(optimizer.parameters(), lr=0.1, momentum=0.5)

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=20, shuffle=True, **kwargs)
    train(model_fwd, model_bwd, optimizer, meta_opt, train_loader, 10, device)

if __name__ == '__main__':
    main()

Do you get the OOM error inside your DataLoader loop, or were you able to complete a whole meta_epoch?
Currently, as you said, you are storing the loss tensors in a list, which keeps their whole computation graphs alive and thus uses more and more memory.
Do you need the losses somewhere else in your training?
If not and you just want to store them for debugging / printing, you should append them as losses.append(loss.item()).
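
Something like this (a toy sketch with dummy tensors, just to illustrate the difference):

import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

losses = []
for _ in range(3):
    loss = (w * x).sum()
    loss.backward()                  # the graph can be freed right after this call
    losses.append(loss.item())       # .item() gives a plain Python float, no graph attached

print(sum(losses) / len(losses))     # fine for printing / debugging
# sum(losses).backward()             # would fail: a Python float has no .backward()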

Dear @ptrblck
Thank you for your prompt answer.

I get the error at the line:
loss.backward(retain_graph=True). The program breaks after some iterations (roughly 2000), and during those iterations the meta_opt updates and the losses = [] resets do happen.

If I try to use losses.append(loss.item()), the program then gives the error:

File "meta_train.py", line 116, in train
    meta_loss.backward()                # Meta-backward pass
AttributeError: 'float' object has no attribute 'backward'

so I suppose I still need the graph, at least until meta_optimizer.step() has completed.
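
If I understand it correctly, the lifetime of the graphs would then have to look roughly like this (again just a toy sketch with dummy tensors, not my actual model):

import torch

w = torch.randn(3, requires_grad=True)       # stands in for the meta-learner weights
meta_opt = torch.optim.SGD([w], lr=0.1)

losses = []
for step in range(200):
    x = torch.randn(3)
    loss = (w * x).sum()
    losses.append(loss)                      # tensors: their graphs must stay alive ...
    if (step + 1) % 50 == 0:
        meta_loss = sum(losses)
        meta_loss.backward()                 # ... until this meta-backward pass
        meta_opt.step()
        meta_opt.zero_grad()
        losses = []                          # then the references are dropped and the
                                             # graphs can be garbage-collected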