[Memory problem] Replace input with another tensor in the forward pass

Hi everyone,

I am trying to extend a Linear Function following the example at this link: torch.autograd.

In the forward pass, my goal is to discard the input and replace it with other tensors, i.e., a, b, c, obtained by performing some operations on it. After computing a, b, c, I want PyTorch not to call save_for_backward(input), but save_for_backward(a, b, c) instead.

In the backward pass, I use a, b, c to reconstruct the input, which was removed in the forward pass, and continue with the rest as in the original Function.

However, nvidia-smi shows that MyLinearFunction uses even more GPU memory than LinearFunction, although the total size of a, b and c is much smaller than that of the input.

What did I do wrong? You can see the code below.

Thank you so much for your help!

class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        ctx.save_for_backward(input)

        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias


class MyLinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        a, b, c = operation(input)

        ctx.save_for_backward(a, b, c, weight, bias)

        del input
        torch.cuda.empty_cache()

        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        a, b, c, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        input = reconstruct(a,b,c)

        del a,b,c
        torch.cuda.empty_cache()        

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias

Hi,

I guess it’s a typo, but in the first Function you forgot to save weight and bias.
a, b and c are neither inputs nor outputs, so you can’t pass them to save_for_backward() (shouldn’t it crash?). You can store them on the context instead: ctx.abc = (a, b, c), or ctx.a = a, ctx.b = b, etc.

How do you test this code? It is possible that another Function holds onto input, so input is still kept alive.
Do you have a full code sample that reproduces your issue?
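The suggestion above (save only weight and bias with save_for_backward, keep the compact tensors as plain ctx attributes) can be sketched as follows. This is a minimal sketch, not the poster's code: it assumes PyTorch 1.8+ for torch.linalg.svd and swaps in a truncated SVD for the thread's operation()/reconstruct() pair. compress_lowrank, reconstruct_lowrank, LowRankLinearFunction and the rank argument are hypothetical names.

```python
import torch
from torch.autograd import Function


def compress_lowrank(x, rank):
    # Hypothetical stand-in for operation(input): keep the top `rank`
    # singular triplets of the 2-D input (truncated SVD).
    U, S, Vh = torch.linalg.svd(x, full_matrices=False)
    return U[:, :rank] * S[:rank], Vh[:rank]


def reconstruct_lowrank(us, vh):
    # Hypothetical stand-in for reconstruct(a, b, c): rebuild an approximation.
    return us.mm(vh)


class LowRankLinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, rank=8):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # weight and bias are inputs of forward, so save_for_backward is fine.
        ctx.save_for_backward(weight, bias)
        # The factors are intermediates: keep them as plain ctx attributes.
        ctx.us, ctx.vh = compress_lowrank(input, rank)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            # Use the reconstructed (approximate) input instead of the original.
            grad_weight = grad_output.t().mm(reconstruct_lowrank(ctx.us, ctx.vh))
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        # One gradient per forward argument; the non-tensor `rank` gets None.
        return grad_input, grad_weight, grad_bias, None
```

Calling it as LowRankLinearFunction.apply(x, weight, bias, rank) with rank equal to min(x.shape) reproduces the exact linear gradients; smaller ranks trade gradient accuracy for storage.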

Hi Alban,

I created MyNet, which uses the MyLinear module. As I stated before, when I print

        print(pytorch2numpy(input).nbytes/1024/1024,
              pytorch2numpy(core).nbytes/1024/1024,
              pytorch2numpy(tucker_factors[0]).nbytes/1024/1024,
              pytorch2numpy(tucker_factors[1]).nbytes/1024/1024,
              )

the outputs are:

1.5625 0.0244140625 0.125 0.30517578125

which means that:

1.5625 > 0.0244140625 + 0.125 + 0.30517578125
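These sizes follow directly from the shapes, assuming float32 (4 bytes per element): fc1 receives a (512, 800) batch (batch size 512, 4*4*50 = 800 features), and layer.py uses ranks = [shape[0]/8, shape[1]/8] = [64, 100]. The hypothetical helper below recomputes the four printed values:

```python
def tucker2_sizes_mib(shape, ranks, bytes_per_elem=4):
    """Storage in MiB of a dense (m, n) matrix and of its 2-way Tucker factors."""
    m, n = shape
    r0, r1 = ranks

    def mib(num_elems):
        return num_elems * bytes_per_elem / 2**20

    return {
        "input": mib(m * n),     # the (m, n) input itself
        "core": mib(r0 * r1),    # (r0, r1) core tensor
        "factor0": mib(m * r0),  # (m, r0) factor matrix
        "factor1": mib(n * r1),  # (n, r1) factor matrix
    }


# fc1 input: 512 rows, 800 features, ranks = shape / 8
sizes = tucker2_sizes_mib((512, 800), (512 // 8, 800 // 8))
print(sizes)
```

The three factors together take about 0.45 MiB versus 1.5625 MiB for the input, which is the saving the poster expected to see.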

I expected nvidia-smi to show that MyNet uses less memory than Net. In fact, the two models used 659MB and 619MB, respectively.
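One caveat when comparing these figures: nvidia-smi reports everything the process holds on the GPU, including the CUDA context and blocks that PyTorch's caching allocator keeps reserved after tensors are freed, so it may not reflect changes in live tensor memory. A sketch that reads PyTorch's own allocator statistics for one training step instead (peak_allocated_mib is a hypothetical helper; torch.cuda.reset_peak_memory_stats needs PyTorch 1.4+):

```python
import torch
import torch.nn.functional as F


def peak_allocated_mib(model, data, target):
    """Peak memory occupied by live tensors during one forward/backward step, in MiB.

    Unlike nvidia-smi, this excludes the CUDA context and cached-but-free blocks.
    """
    device = data.device
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)  # peak restarts at current usage
    loss = F.nll_loss(model(data), target)
    loss.backward()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20
```

Running it once with Net().cuda() and once with MyNet().cuda() on the same batch should give numbers that are directly comparable.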

I am using:

torchvision                0.2.1
pytorch                    1.0.1 
python                     3.6.7
tensorly                   0.4.4

Please see the code below for reference:

layer.py

import gc
import torch
import math
import numpy as np
import tensorly as tl
from tensorly.decomposition import tucker
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Function
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.modules.module import Module
from torch.nn import init


tl.set_backend('pytorch')


def pytorch2numpy(a):
    return a.detach().cpu().numpy()


# Inherit from Function
class MyLinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        shape = input.size()
        ranks = [int(shape[0]/8), int(shape[1]/8)]
        tl_img = tl.tensor(input, device='cuda:0')

        core, tucker_factors = tl.decomposition.tucker(
            tl_img, ranks=ranks, init='random', tol=100e-5, n_iter_max=100)

        print(pytorch2numpy(input).nbytes/1024/1024,
              pytorch2numpy(core).nbytes/1024/1024,
              pytorch2numpy(tucker_factors[0]).nbytes/1024/1024,
              pytorch2numpy(tucker_factors[1]).nbytes/1024/1024,
              )

        del input, tl_img
        torch.cuda.empty_cache()
        gc.collect()

        ctx.core = core
        ctx.factor0 = tucker_factors[0]
        ctx.factor1 = tucker_factors[1]

        ctx.save_for_backward(weight, bias)

        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        grad_input = grad_weight = grad_bias = None

        core = ctx.core
        factor0 = ctx.factor0
        factor1 = ctx.factor1

        weight, bias = ctx.saved_tensors

        input = tl.tucker_to_tensor(
            [core, [factor0, factor1]])

        del core, factor0, factor1
        torch.cuda.empty_cache()
        gc.collect()

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias


class MyLinear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(MyLinear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(
            torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        self.weight.data.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return MyLinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

main.py

import numpy as np
from torch.nn import init
from torch.nn.modules.utils import _pair
from torch.autograd import Function
from torchvision import datasets, transforms
from layer import MyLinear
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
import math
import argparse

import multiprocessing
multiprocessing.set_start_method('spawn', True)


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = MyLinear(4*4*50, 500)
        self.fc2 = MyLinear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    # reporter = MemReporter()

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=512, metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--test-batch-size', type=int, default=512, metavar='N',
                        help='input batch size for testing (default: 512)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    # model = Net().to(device)
    model = MyNet().to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)

    if (args.save_model):
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()

Thank you so much for your help!

Hi,

Doing

del input
torch.cuda.empty_cache()
gc.collect()

in your code has very little effect, because the Tensor that input points to is still referenced by the forward call of the nn Module (the caller of your Function), so it has not gone out of scope yet. Nothing can really be cleaned up at that point.
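The scoping point can be checked with Python's weakref module: deleting a parameter name inside a function does not free the tensor while the caller still holds a reference, and gc.collect() or torch.cuda.empty_cache() cannot free a tensor that is still referenced. A minimal CPU-only sketch (forward_like is a hypothetical stand-in for MyLinearFunction.forward):

```python
import gc
import weakref

import torch


def forward_like(input):
    # Mirrors the `del input; gc.collect()` pattern from MyLinearFunction.forward.
    del input     # removes only this function's local name
    gc.collect()  # cannot free a tensor that is still referenced elsewhere


x = torch.randn(1000, 1000)
alive = weakref.ref(x)

forward_like(x)
print(alive() is not None)  # the caller's `x` still keeps the tensor alive

del x                       # drop the caller's (last) reference
print(alive() is None)      # the tensor has now been deallocated
```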

Then the issue is most likely that other things hold onto input. For example, I think the view will keep the input of fc1 alive. So for fc1 you won’t use less memory; you will only pay for the extra storage of your new Tensors.
For fc2, it seems that the out-of-place relu does not hold onto its output, but I am not 100% sure whether the C++ function called here is relu or threshold.
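To check which autograd Functions hold onto a tensor, one option in PyTorch 1.10 and later (newer than the 1.0.1 used in this thread) is torch.autograd.graph.saved_tensors_hooks: its pack hook is called for every tensor an operation saves for backward. The hypothetical helper below only measures, returning each tensor unchanged; tensors kept as plain ctx attributes bypass the hooks, and a tensor saved by two operations is counted twice.

```python
import torch
import torch.nn.functional as F


def bytes_saved_for_backward(fn, *args):
    """Run fn(*args) and total the bytes autograd saves for backward."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # measure only: keep the saved tensor unchanged

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = fn(*args)
    return out, total


x = torch.randn(512, 800)                      # fc1-sized input, no grad needed
w = torch.randn(500, 800, requires_grad=True)  # fc1-sized weight
_, n_linear = bytes_saved_for_backward(lambda a: a.mm(w.t()), x)
_, n_relu = bytes_saved_for_backward(F.relu, x.clone().requires_grad_())
print(n_linear, n_relu)
```

If an operation reports a saved size at least as large as its input, that operation is keeping a full-size tensor alive for backward.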

Hi Alban,

Here, my idea is to replace input, which can consume a lot of memory (if the mini-batch size is large), with several compact tensors that will be used to estimate the input later in the backward pass. Is there any way to free input in the forward pass and use these compact tensors to reconstruct it in the backward pass?

Thank you so much for your reply.

You will have to modify all the autograd.Functions that save it, to force them to save your compressed version instead.
You did it properly for Linear, but there might be other Functions, like ReLU, that also save it and should be modified.
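In PyTorch 1.10 and later, an alternative to rewriting every autograd.Function that saves the tensor is to compress all saved tensors at once with torch.autograd.graph.saved_tensors_hooks. The sketch below swaps in float16 storage as a simple lossy compression (an illustrative assumption, not the Tucker scheme from this thread); fp16_pack and fp16_unpack are hypothetical names. Memory is only saved if nothing outside autograd, such as the caller, still references the original tensor.

```python
import torch


def fp16_pack(t):
    # Called for every tensor saved for backward inside the `with` block.
    if t.dtype == torch.float32:
        return (torch.float32, t.to(torch.float16))  # half the bytes
    return (None, t)  # leave other dtypes untouched


def fp16_unpack(packed):
    # Called when backward needs the tensor: restore the original dtype.
    dtype, t = packed
    return t.to(dtype) if dtype is not None else t


torch.manual_seed(0)
x = torch.randn(512, 800)                      # fc1-sized input
w = torch.randn(500, 800, requires_grad=True)  # fc1-sized weight

with torch.autograd.graph.saved_tensors_hooks(fp16_pack, fp16_unpack):
    loss = torch.relu(x.mm(w.t())).sum()  # mm and relu save fp16 copies
loss.backward()  # backward may run outside the `with` block
```

The resulting w.grad is close to, but not bit-identical with, the full-precision gradient.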


Thank you so much for your help, Alban. Really appreciate it :wink:

Hi @albanD,

I am trying to implement a simple model that has only the one customized linear layer you saw earlier. Since there is a single layer, I expect that (1) no other tensors hold onto input, and (2) if I delete input in the forward pass, the GPU should release it and free some memory. However, input seems to stay alive no matter what…

Would you mind taking a look?

from torch.nn.modules.module import Module
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parameter import Parameter
from tensorly.decomposition import tucker
import tensorly as tl
import gc
import numpy as np
from torch.nn import init
from torch.autograd import Function
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
import math
import argparse
from pytorch_memlab import profile
from pytorch_memlab.mem_reporter import MemReporter


import multiprocessing
multiprocessing.set_start_method('spawn', True)


tl.set_backend('pytorch')


def numpy2pytorch(a, device=torch.device("cuda")):
    return torch.from_numpy(a.copy()).to(device)


def pytorch2numpy(a):
    return a.detach().cpu().numpy()


# Inherit from Function
class MyLinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        shape = input.size()
        ranks = [int(shape[0]/8), int(shape[1]/100)]
        tl_img = tl.tensor(input, device='cuda:0')

        core, tucker_factors = tl.decomposition.tucker(
            tl_img, ranks=ranks, init='random', tol=100e-5, n_iter_max=100)

        print(pytorch2numpy(input).nbytes/1024/1024,
              pytorch2numpy(core).nbytes/1024/1024,
              pytorch2numpy(tucker_factors[0]).nbytes/1024/1024,
              pytorch2numpy(tucker_factors[1]).nbytes/1024/1024,
              )

        del input, tl_img
        torch.cuda.empty_cache()
        gc.collect()

        ctx.core = core
        ctx.factor0 = tucker_factors[0]
        ctx.factor1 = tucker_factors[1]

        ctx.save_for_backward(weight, bias)

        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        grad_input = grad_weight = grad_bias = None

        core = ctx.core
        factor0 = ctx.factor0
        factor1 = ctx.factor1

        weight, bias = ctx.saved_tensors

        input = tl.tucker_to_tensor(
            [core, [factor0, factor1]])

        del core, factor0, factor1
        torch.cuda.empty_cache()
        gc.collect()

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias


class MyLinear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(MyLinear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(
            torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        self.weight.data.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return MyLinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc2 = MyLinear(10000, 10)

    def forward(self, x):
        x = self.fc2(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc2 = nn.Linear(10000, 10)

    def forward(self, x):
        x = self.fc2(x)
        return x


def train(args, model, device, train_loader, optimizer, epoch):
    reporter = MemReporter()

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # data, target = data.to(device), target.to(device)

        # Random stand-in batch: 512 samples with 10000 features each
        data, target = np.random.rand(512, 10000).astype(np.float32), np.random.randint(2, size=512)
        data, target = numpy2pytorch(data, device), numpy2pytorch(target, device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)

        print('========= before backward =========')
        reporter.report()
        
        loss.backward()

        print('========= after backward =========')
        reporter.report()

        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=512, metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--test-batch-size', type=int, default=512, metavar='N',
                        help='input batch size for testing (default: 512)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    # model = MyNet().to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)

    if (args.save_model):
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()

Memory consumption in both cases can be seen below:

With the default linear layer:

========= before backward =========
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
Parameter0                                       (10, 10000)   391.00K
Parameter1                                             (10,)   512.00B
Tensor2                                         (512, 10000)    19.53M
Tensor3                                               (512,)     4.00K
Tensor4                                            (512, 10)    20.00K
Tensor5                                                 (1,)   512.00B
-------------------------------------------------------------------------------
Total Tensors: 5225643  Used Memory: 19.94M
The allocated memory on cuda:0: 20.05M
Memory differs due to the matrix alignment or invisible gradient buffer tensors
-------------------------------------------------------------------------------
-------------------------------------------------------------------------------
Storage on cpu
Tensor6                                      (60000, 28, 28)    44.86M
Tensor7                                             (60000,)   469.00K
Tensor8                                      (10000, 28, 28)     7.48M
Tensor9                                             (10000,)    78.50K
Tensor10                                    (512, 1, 28, 28)     1.53M
Tensor11                                              (512,)     4.00K
Tensor12                                    (512, 1, 28, 28)     1.53M
Tensor13                                              (512,)     4.00K
Tensor14                                              (512,)     4.00K
Tensor15                                    (512, 1, 28, 28)     1.53M
Tensor16                                              (512,)     4.00K
Tensor17                                    (512, 1, 28, 28)     1.53M
-------------------------------------------------------------------------------
Total Tensors: 56557680         Used Memory: 59.01M
-------------------------------------------------------------------------------
========= after backward =========
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
Parameter0                                       (10, 10000)   391.00K
Parameter0.grad                                  (10, 10000)   391.00K
Parameter1                                             (10,)   512.00B
Parameter1.grad                                        (10,)   512.00B
Tensor2                                         (512, 10000)    19.53M
Tensor3                                               (512,)     4.00K
Tensor4                                            (512, 10)    20.00K
Tensor5                                                 (1,)   512.00B
-------------------------------------------------------------------------------
Total Tensors: 5325653  Used Memory: 20.32M
The allocated memory on cuda:0: 20.41M
Memory differs due to the matrix alignment or invisible gradient buffer tensors
-------------------------------------------------------------------------------
-------------------------------------------------------------------------------
Storage on cpu
Tensor6                                      (60000, 28, 28)    44.86M
Tensor7                                             (60000,)   469.00K
Tensor8                                      (10000, 28, 28)     7.48M
Tensor9                                             (10000,)    78.50K
Tensor10                                    (512, 1, 28, 28)     1.53M
Tensor11                                              (512,)     4.00K
Tensor12                                    (512, 1, 28, 28)     1.53M
Tensor13                                              (512,)     4.00K
Tensor14                                              (512,)     4.00K
Tensor15                                    (512, 1, 28, 28)     1.53M
Tensor16                                              (512,)     4.00K
Tensor17                                    (512, 1, 28, 28)     1.53M
-------------------------------------------------------------------------------
Total Tensors: 56557680         Used Memory: 59.01M
-------------------------------------------------------------------------------

with customized linear layer

========= before backward =========
/home/minhvu/anaconda3/envs/vqa/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py:86: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn("torch.distributed.reduce_op is deprecated, please use "
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
Tensor0                                                 (1,)   512.00B
Parameter1                                       (10, 10000)   391.00K
Parameter2                                             (10,)   512.00B
Tensor3                                         (512, 10000)    19.53M
Tensor4                                               (512,)     4.00K
Tensor5                                            (512, 10)    20.00K
Tensor6                                            (64, 100)    25.00K
Tensor7                                         (10000, 100)     3.81M
Tensor8                                            (512, 64)   128.00K
-------------------------------------------------------------------------------
Total Tensors: 6264811  Used Memory: 23.90M
The allocated memory on cuda:0: 24.06M
Memory differs due to the matrix alignment or invisible gradient buffer tensors
-------------------------------------------------------------------------------
-------------------------------------------------------------------------------
Storage on cpu
Tensor9                                      (10000, 28, 28)     7.48M
Tensor10                                            (10000,)    78.50K
Tensor11                                    (512, 1, 28, 28)     1.53M
Tensor12                                              (512,)     4.00K
Tensor13                                    (512, 1, 28, 28)     1.53M
Tensor14                                              (512,)     4.00K
Tensor15                                              (512,)     4.00K
Tensor16                                    (512, 1, 28, 28)     1.53M
Tensor17                                              (512,)     4.00K
Tensor18                                    (512, 1, 28, 28)     1.53M
Tensor19                                     (60000, 28, 28)    44.86M
Tensor20                                            (60000,)   469.00K
-------------------------------------------------------------------------------
Total Tensors: 56557680         Used Memory: 59.01M
-------------------------------------------------------------------------------
========= after backward =========
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
Parameter1                                       (10, 10000)   391.00K
Parameter1.grad                                  (10, 10000)   391.00K
Parameter2                                             (10,)   512.00B
Parameter2.grad                                        (10,)   512.00B
Tensor3                                         (512, 10000)    19.53M
Tensor4                                               (512,)     4.00K
Tensor5                                            (512, 10)    20.00K
Tensor6                                            (64, 100)    25.00K
Tensor7                                         (10000, 100)     3.81M
Tensor8                                            (512, 64)   128.00K
Tensor0                                                 (1,)   512.00B
-------------------------------------------------------------------------------
Total Tensors: 6364821  Used Memory: 24.28M
The allocated memory on cuda:0: 24.44M
Memory differs due to the matrix alignment or invisible gradient buffer tensors
-------------------------------------------------------------------------------
-------------------------------------------------------------------------------
Storage on cpu
Tensor9                                      (10000, 28, 28)     7.48M
Tensor10                                            (10000,)    78.50K
Tensor11                                    (512, 1, 28, 28)     1.53M
Tensor12                                              (512,)     4.00K
Tensor13                                    (512, 1, 28, 28)     1.53M
Tensor14                                              (512,)     4.00K
Tensor15                                              (512,)     4.00K
Tensor16                                    (512, 1, 28, 28)     1.53M
Tensor17                                              (512,)     4.00K
Tensor18                                    (512, 1, 28, 28)     1.53M
Tensor19                                     (60000, 28, 28)    44.86M
Tensor20                                            (60000,)   469.00K
-------------------------------------------------------------------------------
Total Tensors: 56557680         Used Memory: 59.01M
-------------------------------------------------------------------------------
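A rough size check on the two reports. It assumes every float tensor is float32 (4 bytes per element). It also assumes that Tensor6, Tensor7 and Tensor8 in the customized run are the saved a, b, c. That second point is inferred from their shapes and is not stated in the post. The (512, 10000) input is still listed on cuda:0 in both runs, so a, b, c are stored on top of it rather than in place of it. Their combined size should then match the gap between the two "after backward" totals (24.28M vs 20.32M).

```python
from math import prod

BYTES_PER_ELEMENT = 4  # assumption: float32 tensors


def mib(shape):
    """Size in MiB (2**20 bytes) of a dense float32 tensor with this shape."""
    return prod(shape) * BYTES_PER_ELEMENT / 2**20


# The (512, 10000) input, still alive in both runs.
input_mib = mib((512, 10000))
# Assumed to be the saved a, b, c (Tensor6, Tensor7, Tensor8 in the customized run).
saved_mib = sum(mib(s) for s in [(64, 100), (10000, 100), (512, 64)])
# Difference between the two "after backward" totals reported above.
reported_gap = 24.28 - 20.32

print(f"input still alive: {input_mib:.2f} MiB")     # 19.53
print(f"saved a, b, c:     {saved_mib:.2f} MiB")     # 3.96
print(f"reported gap:      {reported_gap:.2f} MiB")  # 3.96
```

The factors are about 5x smaller than the input, but they can only save memory once nothing else references the input.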

I think my comment above is relevant:

Here, the input to your model, data, is the same Tensor as the input you get in forward. So even if the Function doesn't store input, your training loop still holds a reference to data, and it cannot be freed.
Changing to

output = model(data)
del data

should help.
Also, if (and only if) you introduce reference cycles, you need to call gc.collect() to make sure things are freed.

Let me know if it works the way you want :slight_smile:
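Here is a minimal sketch of the suggested fix, built on several assumptions. The thread doesn't show how a, b, c are computed, so `LowRankLinearFunction` uses `torch.svd_lowrank` as a hypothetical stand-in compressor. The rank `q` and the names `LowRankLinear` and `train_epoch` are also made up for illustration. The Function saves only the factors and rebuilds an approximate input in backward. The training loop drops its own `data` reference right after the forward pass, so the full input no longer has to survive until backward.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


class LowRankLinearFunction(Function):
    """Linear op that saves a rank-q approximation of `input` instead of `input`.

    Hypothetical stand-in for MyLinearFunction: torch.svd_lowrank plays the
    role of the a, b, c compression, which is not shown in the thread.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, q=16):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # input ~= U @ diag(S) @ V.T; only these factors are kept for backward.
        U, S, V = torch.svd_lowrank(input, q=q)
        ctx.save_for_backward(U, S, V, weight, bias)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        U, S, V, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            # Rebuild the discarded input from its factors (U * S == U @ diag(S)).
            approx_input = (U * S).mm(V.t())
            grad_weight = grad_output.t().mm(approx_input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        # One gradient per forward argument; q is not a tensor, so it gets None.
        return grad_input, grad_weight, grad_bias, None


class LowRankLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, q=16):
        super().__init__()
        self.weight = torch.nn.Parameter(0.01 * torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(out_features))
        self.q = q

    def forward(self, x):
        return LowRankLinearFunction.apply(x, self.weight, self.bias, self.q)


def train_epoch(model, batches, optimizer, device):
    for data, target in batches:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # The Function no longer saves the input, but this name still refers to it.
        # If it was the last reference (e.g. the GPU copy made by .to(device)),
        # deleting it lets that memory be reused before backward runs.
        del data
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
    return loss.item()
```

The loop does not call `gc.collect()`. As noted above, that is only needed when reference cycles keep tensors alive, and nothing here creates one. If the batches come from a list that stays alive (as when testing on CPU), the list still holds each input, so the saving only shows up when `.to(device)` makes a fresh copy.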


@albanD thank you so much for pointing these issues out. It works like a charm!

Have a nice weekend ahead.

Best regards,
Minh
