Why does the torch graph consume such a large memory footprint? I get a spike of 7 GB of GPU memory just by instantiating the following model:
import torch
import torch.nn as nn
from torch.autograd import Variable


class Model(torch.nn.Module):
    # Define model
    def __init__(self, args, mean, std):
        super(Model, self).__init__()
        self.numCells = args.numCells
        self.mean = Variable(mean, requires_grad=False)
        self.std = Variable(std, requires_grad=False)
        self.conv1 = nn.Conv2d(3, 64, 5, padding=2)
        self.conv2 = nn.Conv2d(64, 128, 5, padding=2)
        self.conv3 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv7 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv8 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv9 = nn.Conv2d(512, 512, 5, padding=2)
        self.fc = nn.Linear(32 * 32 * 512, self.numCells * self.numCells * 7)
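A minimal sketch of how the allocation can be observed, assuming CUDA is available; the `numCells` value and the `mean`/`std` tensors below are placeholder assumptions rather than the real configuration, and `torch.cuda.memory_allocated()` only reports allocator-tracked tensors (the spike seen in nvidia-smi can additionally include the CUDA context and workspaces):

    import torch
    from argparse import Namespace

    # Hypothetical args and normalization tensors, only to make instantiation runnable.
    args = Namespace(numCells=16)
    mean = torch.zeros(3, 1, 1)
    std = torch.ones(3, 1, 1)

    before = torch.cuda.memory_allocated()
    model = Model(args, mean, std).cuda()   # move parameters to the GPU
    torch.cuda.synchronize()
    after = torch.cuda.memory_allocated()
    print("Allocated by model parameters: {:.2f} GiB".format((after - before) / 1024 ** 3))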