Large memory footprint of torch model

Why does the torch graph consume such a large memory footprint? I see a spike of 7 GB of GPU memory just by instantiating the following model:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Model(nn.Module):
    # Define model
    def __init__(self, args, mean, std):
        super(Model, self).__init__()

        self.numCells = args.numCells
        self.mean = Variable(mean, requires_grad=False)
        self.std = Variable(std, requires_grad=False)

        self.conv1 = nn.Conv2d(3, 64, 5, padding=2)
        self.conv2 = nn.Conv2d(64, 128, 5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv7 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv8 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv9 = nn.Conv2d(512, 512, 5, padding=2)
        # Fully connected layer: 32 * 32 * 512 inputs -> numCells^2 * 7 outputs
        self.fc = nn.Linear(32 * 32 * 512, self.numCells * self.numCells * 7)

It depends on the value of numCells. For numCells = 21, I'd expect about 6.5 GB of parameters just in self.fc.
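
A quick back-of-the-envelope check (assuming 32-bit floats at 4 bytes each, and the layer sizes from the snippet above):

    in_features = 32 * 32 * 512            # 524,288 inputs to the linear layer
    out_features = 21 * 21 * 7             # 3,087 outputs for numCells = 21
    num_weights = in_features * out_features  # ~1.62 billion weights
    total_bytes = 4 * (num_weights + out_features)  # weights + bias, fp32
    print(total_bytes / 1e9)               # ~6.5 GB in self.fc alone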

Ah, actually I see it now.

By the way, does defining a loss on this net and calling loss.backward() create a copy of the parameters? I run out of memory on loss.backward(). My GPU has 12 GB of memory, and before the call to loss.backward() I have already used 10 GB.

The gradients are the same size as the model parameters. So if you have 10 GB of parameters, you need another 10 GB for gradients (and possibly extra for intermediate calculations).
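
You can verify this yourself by comparing parameter memory with the .grad tensors that backward() allocates. A minimal sketch, assuming you already have a model instance named model plus inputs, targets, and criterion defined (those names are placeholders):

    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

    loss = criterion(model(inputs), targets)
    loss.backward()  # allocates a .grad tensor for every parameter

    grad_bytes = sum(p.grad.numel() * p.grad.element_size()
                     for p in model.parameters() if p.grad is not None)
    # grad_bytes matches param_bytes, so gradients roughly double parameter memory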

That makes sense. Thanks!