requires_grad=False does not save memory

I expected that layers whose gradients are not needed would require much less memory, but somehow this is not the case. To show this, I took the official MNIST example and added a couple of big conv layers with 9000 channels, just to make the effect significant. Then I measured memory in nvidia-smi in two modes: freeze_2_conv_layers=True and freeze_2_conv_layers=False. Both modes take exactly 4957MiB.

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Sequential(nn.Conv2d(10, 20, kernel_size=5), nn.Conv2d(20, 9000, kernel_size=5, padding=2), nn.Conv2d(9000, 20, kernel_size=5, padding=2))
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))

        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

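def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()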
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def totest(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()


    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)

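    freeze_2_conv_layers = True  # set to False for the second measurement

    if freeze_2_conv_layers:
        # model.modules() yields: 0 Net, 1 conv1, 2 conv2 (Sequential),
        # 3-5 the three Conv2d inside conv2, ... so 4 and 5 are the 9000-channel convs.
        for i, m in enumerate(model.modules()):
            if i == 4 or i == 5:
                m.weight.requires_grad = False
                m.bias.requires_grad = False
        optimizer = optim.SGD([par for par in model.parameters() if par.requires_grad],
                              lr=args.lr, momentum=args.momentum)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)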

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        totest(args, model, device, test_loader)

if __name__ == '__main__':
    main()


The thing is that in your net, gradients are required for conv1, which means that PyTorch needs to backprop through every module in conv2. So the intermediary results of every operation in conv2 need to be kept around.
If conv1 were frozen (or not there) and if you don't require gradients for the input, then you wouldn't need to backprop through conv2, and its intermediary results wouldn't be saved.
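One way to check this is to count the bytes autograd stores for backward. The sketch below assumes PyTorch 1.10 or later (for `torch.autograd.graph.saved_tensors_hooks`); `stem`, `frozen`, `saved_bytes` and `frozen_block_cost` are hypothetical names for illustration. It measures only the frozen block: with a frozen stem and an input that does not require grad, the frozen convs save nothing, but as soon as something upstream is trainable they must keep tensors so gradients can flow through them.

```python
import torch
import torch.nn as nn


def saved_bytes(fn, x):
    """Run fn(x) and count the bytes autograd stores for backward inside fn."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # keep the tensor as-is; we only measure it

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        fn(x)
    return total


torch.manual_seed(0)
# Hypothetical toy model: a small "stem" conv followed by a wide frozen block.
stem = nn.Conv2d(8, 8, kernel_size=1)
frozen = nn.Sequential(
    nn.Conv2d(8, 256, kernel_size=3, padding=1),
    nn.Conv2d(256, 8, kernel_size=3, padding=1),
)
frozen.requires_grad_(False)

x = torch.randn(4, 8, 16, 16)  # the data itself never requires grad


def frozen_block_cost(train_stem):
    stem.requires_grad_(train_stem)
    h = stem(x)  # run outside the hooks so only the frozen block is measured
    return saved_bytes(frozen, h)


cost_stem_frozen = frozen_block_cost(train_stem=False)
cost_stem_trainable = frozen_block_cost(train_stem=True)
print(f"frozen block saves {cost_stem_frozen} bytes with a frozen stem, "
      f"{cost_stem_trainable} bytes with a trainable stem")
```

In the original Net, conv1 plays the role of the trainable stem, which is why freezing the two big convs did not change the footprint.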

Thanks, I see your point. Still, consider this nice example from PyTorch's tutorials:

import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

Imagine in this case that you didn't want to learn w2. Then grad_w2 would not need to be created, and the program would occupy less memory if we skipped it.
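Spelling the frozen-w2 case out with the tutorial's own formulas shows which arrays backward actually reads. In the hypothetical forward/backward split below (the function names, the `saved` dict and the `learn_w2` flag are illustrative, not part of the tutorial), skipping grad_w2 avoids the small H x D_out gradient, and also means h_relu no longer has to be kept, since only the grad_w2 formula reads it. x and a ReLU mask must still be kept for grad_w1.

```python
import numpy as np


def forward(x, w1, w2, learn_w2):
    """Forward pass of the two-layer net; returns y_pred and what backward needs."""
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)
    # The ReLU backward only needs a boolean mask (h > 0; the tutorial uses h < 0,
    # which differs only where h is exactly 0), and grad_w1 needs x.
    saved = {"x": x, "mask": h > 0}
    # h_relu (the input of the second layer) is only read to form grad_w2.
    if learn_w2:
        saved["h_relu"] = h_relu
    return y_pred, saved


def backward(y_pred, y, w2, saved):
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = saved["h_relu"].T.dot(grad_y_pred) if "h_relu" in saved else None
    grad_h = grad_y_pred.dot(w2.T)  # uses the weight w2, not the activation h_relu
    grad_h[~saved["mask"]] = 0
    grad_w1 = saved["x"].T.dot(grad_h)
    return grad_w1, grad_w2


rng = np.random.default_rng(0)
N, D_in, H, D_out = 64, 1000, 100, 10
x = rng.standard_normal((N, D_in))
y = rng.standard_normal((N, D_out))
w1 = rng.standard_normal((D_in, H))
w2 = rng.standard_normal((H, D_out))

y_pred, saved_frozen = forward(x, w1, w2, learn_w2=False)
g1_frozen, g2_frozen = backward(y_pred, y, w2, saved_frozen)

y_pred, saved_full = forward(x, w1, w2, learn_w2=True)
g1_full, g2_full = backward(y_pred, y, w2, saved_full)
```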

Well yes, the gradient tensor won't be created. But that is actually not what uses a lot of memory. For example, check the memory usage after transferring a model to the GPU (so only the weights are there) and after running a forward pass (so weights + temporary buffers). You will see that the weights are a small fraction.
In your test, it is possible that these tensors are small enough that you don't notice they're missing.
Keep in mind that the CUDA caching allocator does not return memory to the OS right after a tensor is freed (for speed reasons), so nvidia-smi does not reflect what tensors actually use; you might want to use torch.cuda.memory_allocated() to get the exact memory used by tensors.
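A minimal sketch of that measurement, assuming PyTorch 1.4 or later (for `torch.cuda.memory_reserved`); `cuda_memory_report` and the toy model are hypothetical. It reads `torch.cuda.memory_allocated()` after moving the weights and again after a forward pass, and also reports what the caching allocator has reserved, which is closer to the number nvidia-smi shows.

```python
import torch
import torch.nn as nn


def cuda_memory_report(model, batch):
    """Bytes held by tensors on the GPU at each stage (None when CUDA is unavailable)."""
    if not torch.cuda.is_available():
        return None
    device = torch.device("cuda")
    start = torch.cuda.memory_allocated(device)
    model.to(device)
    after_weights = torch.cuda.memory_allocated(device)
    out = model(batch.to(device))  # keep `out` alive so its graph (saved tensors) stays alive
    after_forward = torch.cuda.memory_allocated(device)
    report = {
        "weights": after_weights - start,
        "forward (input, output, saved tensors)": after_forward - after_weights,
        # Memory held by the caching allocator; nvidia-smi shows roughly this
        # plus the CUDA context.
        "reserved by caching allocator": torch.cuda.memory_reserved(device),
    }
    del out
    return report


model = nn.Sequential(nn.Conv2d(1, 64, 5, padding=2), nn.ReLU(), nn.Conv2d(64, 64, 5, padding=2))
print(cuda_memory_report(model, torch.randn(64, 1, 28, 28)))
```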


But we don't need the conv layer's input to compute grad_input, and these inputs are what consume most of the memory. All we need to save is input.sizes().

But you do need these inputs to compute the gradient w.r.t. the weights (which are the Tensors that you are actually learning).
Also, in the model above, the conv inputs are the outputs of a relu, and the relu also needs this output to compute its gradient. So that Tensor is saved once and used by both ops: the conv (for its weight gradient) and the relu.
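`torch.nn.grad` exposes the two halves of a conv backward as separate functions, which makes both points of this exchange concrete: the gradient w.r.t. the input needs only the input's shape plus the weight, while the gradient w.r.t. the weight needs the actual input values. The shapes below are arbitrary; the sketch compares both results against autograd.

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_input, conv2d_weight

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
out = F.conv2d(x, w, padding=1)
grad_out = torch.randn_like(out)

# Reference gradients computed by autograd.
gx_ref, gw_ref = torch.autograd.grad(out, (x, w), grad_out)

# Gradient w.r.t. the input: needs only the input's *shape*, the weight and grad_out.
gx = conv2d_input(x.shape, w.detach(), grad_out, padding=1)

# Gradient w.r.t. the weight: needs the actual input values.
gw = conv2d_weight(x.detach(), w.shape, grad_out, padding=1)

print(torch.allclose(gx, gx_ref, atol=1e-4), torch.allclose(gw, gw_ref, atol=1e-4))
```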

Surely you do need inputs to compute weights gradients. But in this case, weights have flag requires_grad=False so these gradients will not be computed.

I'm not sure I understand the ReLU argument. First of all, there are no activation layers in the given code for the conv2 sequential block (which is strange). And second, in the general case we do have activation functions, and we do need to know their inputs to calculate their gradients, but it is their job to save the input. I can see some complications when in-place activations are used, though.

First of all, there are no activation layers in the given code for conv2 sequential block (which is strange).

Oh, I was looking at the net in the first post. But indeed, if there is no activation, this argument is moot.

it is their job to save the input.

Actually, it is better to save the output when they can (which relu can do), because in chains like
conv1 -> relu -> conv2 -> relu, etc., the conv outputs don't need to be saved, only the conv inputs (which are the relu outputs)! So you actually save only half as many intermediary results as you would if the relu saved its input.

And yes, an in-place relu makes this point less important: the two intermediary results become a single tensor, so only one is saved.
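Both relu points can be checked in a few lines (toy shapes, arbitrary): the relu output alone determines the gradient mask, since it is positive exactly where the input is; and with `inplace=True` the conv output and the relu output are literally one buffer. The in-place version is only legal because the conv backward reads its input and weight, not its output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# 1) The relu output alone determines its gradient: out > 0 exactly where in > 0.
pre = torch.randn(4, 5)
post = F.relu(pre)
g = torch.randn(4, 5)
mask_from_output_ok = torch.equal(g * (pre > 0), g * (post > 0))

# 2) With inplace=True, the conv output and the relu output share one buffer.
conv = nn.Conv2d(3, 8, kernel_size=3)
a = conv(torch.randn(1, 3, 8, 8))
ptr = a.data_ptr()
b = F.relu(a, inplace=True)
shares_storage = b.data_ptr() == ptr

# Backward still works: the conv saved its input and weight, not its output.
b.sum().backward()
```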

Surely you do need inputs to compute weights gradients. But in this case, weights have flag requires_grad=False so these gradients will not be computed.

One issue that arises here is that if you use optimized libraries like cuDNN, they provide a single backward function for both input and weight, and we have no control over it in PyTorch. So in some cases, even though the weights don't require gradients, we need to save extra information to be able to run these third-party library functions properly.
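In recent PyTorch releases (assumption: 1.11 or later), the fused conv backward is visible as the ATen op `convolution_backward`, and as far as I know the autograd formula for convolution is written in terms of it. Its signature takes the full input tensor even when `output_mask` requests only the input gradient, so the input has to be alive at backward time. A sketch with a "frozen" weight:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3)  # "frozen" weight: no requires_grad
out = F.conv2d(x, w, padding=1)
grad_out = torch.randn_like(out)

# One fused backward op: it takes the input tensor even though
# output_mask asks only for grad_input (input, weight, bias).
grad_input, _, _ = torch.ops.aten.convolution_backward(
    grad_out, x.detach(), w,
    None,                   # bias_sizes: no bias here
    [1, 1],                 # stride
    [1, 1],                 # padding
    [1, 1],                 # dilation
    False,                  # transposed
    [0, 0],                 # output_padding
    1,                      # groups
    [True, False, False],   # output_mask
)

(grad_ref,) = torch.autograd.grad(out, x, grad_out)
print(torch.allclose(grad_input, grad_ref, atol=1e-4))
```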

Also, in some cases we could improve our codegen to be smarter about what needs to be saved; you can follow the corresponding issue here:

Thank you for such an elaborate response. I have learned something today.

I have one last question. I am not good at C++, but I don't see the behavior you describe in the code:

One issue that arises here is that if you use optimized libraries like cuDNN, they provide a single backward function for both input and weight.

Did I search for it in the wrong place?

You can find the backward definitions in this file:
What is saved for backward is anything used in these formulas that corresponds to an argument or output of the forward.

In particular, you can see that all the convs there are defined in a special way to leverage these libraries (and to re-use buffers, I think?).
I have to admit I haven't dived into these in a long time, so there might be details I'm missing.
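To make the "saved = whatever the backward formula reads" rule concrete, here is a sketch of the conditional saving discussed in this thread, written as a custom `torch.autograd.Function`. `FrozenAwareLinear` is a hypothetical name, and this is an illustration of the idea, not how PyTorch's built-in ops are implemented: x is only read by the weight-gradient formula, so it is not saved when the weight is frozen.

```python
import torch


class FrozenAwareLinear(torch.autograd.Function):
    """y = x @ weight.T that skips saving x when weight needs no gradient."""

    @staticmethod
    def forward(ctx, x, weight):
        # x is only read by the weight-gradient formula; weight is only read by the
        # input-gradient formula. Save each one only if its formula will run.
        ctx.save_for_backward(
            x if ctx.needs_input_grad[1] else None,
            weight if ctx.needs_input_grad[0] else None,
        )
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_w = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_out @ weight
        if ctx.needs_input_grad[1]:
            grad_w = grad_out.t() @ x
        return grad_x, grad_w


torch.manual_seed(0)

# Frozen weight: only weight is saved, x is not kept for backward.
x = torch.randn(3, 4, requires_grad=True)
w_frozen = torch.randn(5, 4)
FrozenAwareLinear.apply(x, w_frozen).sum().backward()

# Trainable weight, input without grad: only x is saved.
x_data = torch.randn(3, 4)
w = torch.randn(5, 4, requires_grad=True)
FrozenAwareLinear.apply(x_data, w).sum().backward()
```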