Using the following code:
import torch
import numpy as np
from vgg_unet_aspp_detection import UNetVgg
import os
import psutil
import gc
# Reproduction script: run CPU inference repeatedly on randomly sized inputs
# and print the process's memory share after every iteration.
#
# NOTE(review): the step-wise growth in the output is most likely memory kept
# by the native allocator / backend caches for the many distinct input shapes,
# rather than leaked Python references. Commonly tried mitigations are
# starting Python with the LRU_CACHE_CAPACITY environment variable set (e.g.
# LRU_CACHE_CAPACITY=1) or using a different malloc (e.g. jemalloc).
# TODO confirm against the PyTorch version in use.
process = psutil.Process(os.getpid())
device_str = "cpu"
device = torch.device(device_str)
model = UNetVgg(4, device)
model = model.eval()
model = model.to(device)
for i in range(1000):
    # Random spatial size per iteration: x in [500, 900), y in [500, 1000).
    x = np.random.randint(500, 900)
    y = np.random.randint(500, 1000)
    img = np.random.rand(x, y, 3)
    # (x, y, 3) float64 -> (3, x, y) float32 -> add batch axis -> (1, 3, x, y).
    # Values are already in [0, 1) from rand(), so the /255 only makes them
    # tiny; that is irrelevant for a memory test.
    img_pt = img.astype(np.float32) / 255.0
    img_pt = img_pt.transpose(2, 0, 1)
    img_pt = torch.from_numpy(img_pt[None, ...]).to(device)
    with torch.no_grad():
        output, _ = model(img_pt)
    print(output[0, 0, 0, 0])
    # Drop this iteration's arrays/tensors before collecting so the reading
    # below reflects memory retained across iterations, not buffers that are
    # still referenced by these names.
    del img, img_pt, output, _
    gc.collect()
    print('Loop %d - Memory: %f' % (i, process.memory_percent()))
I get the following output:
tensor(-0.0305)
Loop 0 - Memory: 3.526018
tensor(-0.0306)
Loop 1 - Memory: 4.146981
tensor(-0.0306)
Loop 2 - Memory: 4.146662
tensor(-0.0305)
Loop 3 - Memory: 4.390995
tensor(-0.0306)
Loop 4 - Memory: 4.518435
tensor(-0.0304)
Loop 5 - Memory: 4.517454
tensor(-0.0306)
Loop 6 - Memory: 4.534013
tensor(-0.0306)
Loop 7 - Memory: 4.279794
tensor(-0.0305)
Loop 8 - Memory: 4.967409
tensor(-0.0305)
Loop 9 - Memory: 4.966697
tensor(-0.0305)
Loop 10 - Memory: 5.320857
tensor(-0.0306)
Loop 11 - Memory: 5.320440
tensor(-0.0305)
Loop 12 - Memory: 5.326524
tensor(-0.0304)
Loop 13 - Memory: 5.353901
Why does the memory usage grow cumulatively like this? I’ve tried with PyTorch 1.0.1 and PyTorch 1.1, always on the CPU. There seems to be an upper bound, but I still find this behaviour puzzling, and it takes up a considerable amount of memory.
The model is defined as follows:
import torch
import torchvision
import numpy as np
from torch import nn
import torch.nn.init as init
class ASPPModule(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) block.

    Five parallel branches are applied to the input and concatenated along
    the channel axis (dim 1):

    * ``conv1``: adaptive average pooling to 1x1, a 1x1 conv, then bilinear
      resize back to the input's (h, w);
    * ``conv2``: a 3x3 conv with dilation 1 and padding 1;
    * ``conv3``..``conv5``: 3x3 convs whose padding equals their dilation
      (``dilations[0]``, ``dilations[1]``, ``dilations[2]``), so the spatial
      size is preserved.

    The ``5 * inner_features`` concatenated channels go through
    ``bottleneck`` (1x1 conv -> ReLU -> 1x1 conv -> ReLU), producing
    ``out_features`` channels at the input's spatial size.

    Args:
        features: number of input channels.
        inner_features: number of channels produced by each branch.
        out_features: number of output channels.
        dilations: exactly three dilation rates for ``conv3``..``conv5``.

    Raises:
        ValueError: if ``dilations`` does not contain exactly three values.
            The bottleneck is sized for exactly five branches, so other
            lengths would either fail obscurely or be silently truncated.
    """

    def __init__(self, features, inner_features=256, out_features=512, dilations=(3, 5, 8)):
        super(ASPPModule, self).__init__()
        dilations = tuple(dilations)
        if len(dilations) != 3:
            raise ValueError(
                'ASPPModule expects exactly 3 dilations, got %d: %r' % (len(dilations), dilations))
        d3, d4, d5 = dilations
        # Attribute names are kept as-is so existing state_dicts still load.
        self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                   nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False))
        self.conv2 = nn.Conv2d(features, inner_features, kernel_size=3, padding=1, dilation=1, bias=False)
        self.conv3 = nn.Conv2d(features, inner_features, kernel_size=3, padding=d3, dilation=d3, bias=False)
        self.conv4 = nn.Conv2d(features, inner_features, kernel_size=3, padding=d4, dilation=d4, bias=False)
        self.conv5 = nn.Conv2d(features, inner_features, kernel_size=3, padding=d5, dilation=d5, bias=False)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
            torch.nn.ReLU(True),
            nn.Conv2d(out_features, out_features, kernel_size=1, padding=0, dilation=1),
            torch.nn.ReLU(True)
        )

    def forward(self, x):
        """Apply the five branches to ``x``, concatenate them, and fuse with ``bottleneck``."""
        _, _, h, w = x.size()
        # The pooled branch is 1x1; resize it back so it can be concatenated.
        feat1 = torch.nn.functional.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
        bottle = self.bottleneck(out)
        return bottle
class UNetVgg(torch.nn.Module):
    """
    Combines UNet (VGG based) with the ASPP module for segmentation.

    Encoder: the ``features`` layers of a pretrained torchvision VGG16, sliced
    into five stages (``vgg0`` .. ``vgg4``) whose outputs feed the skip
    connections (``vgg0``/``vgg1`` through the 1x1 ``pass0``/``pass1``
    projections). Bottom: a 2x2 max-pool followed by ``ASPPModule(512)``; its
    output is also projected by ``bottom_up`` and injected at decoder level 1.
    Decoder: at each level the deeper map is bilinearly resized to the skip
    connection's spatial size, concatenated after it on the channel axis, and
    refined by the matching ``smooth*`` block.

    Args:
        nClasses: number of output channels of both heads.
        device: accepted for backward compatibility but not used by this
            class; move the model with ``.to(device)`` instead.

    ``forward(x)`` returns ``(main, aux)``, both with ``nClasses`` channels at
    the spatial size of the ``vgg0`` output. The first VGG layers are assumed
    to be size-preserving padded convs, which would make this the input size;
    TODO confirm for the torchvision version in use.
    """

    def __init__(self, nClasses, device):
        super(UNetVgg, self).__init__()
        vgg16pre = torchvision.models.vgg16(pretrained=True)
        # Build the encoder layer list once and slice it into stages.
        vgg_layers = list(vgg16pre.features.children())
        self.vgg0 = torch.nn.Sequential(*vgg_layers[:4])
        self.vgg1 = torch.nn.Sequential(*vgg_layers[4:9])
        self.vgg2 = torch.nn.Sequential(*vgg_layers[9:16])
        self.vgg3 = torch.nn.Sequential(*vgg_layers[16:23])
        self.vgg4 = torch.nn.Sequential(*vgg_layers[23:30])
        self.bottom = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2),
            ASPPModule(512)
        )
        # Auxiliary head on the level-2 decoder output (128 channels).
        self.aux_path = torch.nn.Sequential(
            torch.nn.Conv2d(128, 64, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(64, nClasses, kernel_size=1, stride=1, padding=0),
        )
        # Input channel counts of smooth* = skip channels + upsampled channels.
        self.smooth0 = torch.nn.Sequential(  # 64 (pass0) + 64 (smooth1)
            torch.nn.Conv2d(128, 64, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True)
        )
        self.smooth1 = torch.nn.Sequential(  # 128 (pass1) + 128 (smooth2) + 128 (bottom_up)
            torch.nn.Conv2d(384, 64, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True)
        )
        self.smooth2 = torch.nn.Sequential(  # 256 (vgg2) + 256 (smooth3)
            torch.nn.Conv2d(512, 128, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(128, 128, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True)
        )
        self.smooth3 = torch.nn.Sequential(  # 512 (vgg3) + 256 (smooth4)
            torch.nn.Conv2d(768, 256, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(256, 256, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True)
        )
        self.smooth4 = torch.nn.Sequential(  # 512 (vgg4) + 512 (ASPP)
            torch.nn.Conv2d(1024, 256, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(256, 256, kernel_size=(3, 3), stride=1, padding=(1, 1)),
            torch.nn.ReLU(True)
        )
        self.pass0 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=(1, 1), stride=1, padding=(0, 0)),
            torch.nn.ReLU(True)
        )
        self.pass1 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 128, kernel_size=(1, 1), stride=1, padding=(0, 0)),
            torch.nn.ReLU(True)
        )
        self.bottom_up = torch.nn.Sequential(
            torch.nn.Conv2d(512, 128, kernel_size=(1, 1), stride=1, padding=(0, 0)),
            torch.nn.ReLU(True)
        )
        self.final = torch.nn.Conv2d(64, nClasses, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _resize_like(t, ref):
        """Bilinearly resize ``t`` to the spatial size (H, W) of ``ref``."""
        _, _, H, W = ref.size()
        return torch.nn.functional.interpolate(t, size=(H, W), mode='bilinear', align_corners=True)

    def _fuse(self, smooth, skip, *deeper):
        """Resize each ``deeper`` map to ``skip``'s size, concat after ``skip`` on dim 1, apply ``smooth``."""
        # The resized/concatenated temporaries are local to this call, so they
        # become unreferenced as soon as the level is done instead of staying
        # alive until forward() returns (lower peak memory under no_grad).
        resized = [self._resize_like(t, skip) for t in deeper]
        return smooth(torch.cat([skip] + resized, 1))

    def forward(self, x):
        """Run encoder, ASPP bottom and decoder; return ``(main_logits, aux_logits)``."""
        # Encoder. Each later VGG stage starts with a max-pool, so the in-place
        # ReLUs inside the stages never overwrite a stored skip tensor.
        x = self.vgg0(x)
        feat0 = self.pass0(x)
        x = self.vgg1(x)
        feat1 = self.pass1(x)
        feat2 = self.vgg2(x)
        feat3 = self.vgg3(feat2)
        feat4 = self.vgg4(feat3)
        feat5 = self.bottom(feat4)
        btp = self.bottom_up(feat5)

        # Decoder, deepest level first.
        end4 = self._fuse(self.smooth4, feat4, feat5)
        end3 = self._fuse(self.smooth3, feat3, end4)
        end2 = self._fuse(self.smooth2, feat2, end3)
        aux_out = self.aux_path(end2)
        # Level 1 also receives the projected bottom features.
        end1 = self._fuse(self.smooth1, feat1, end2, btp)
        end0 = self._fuse(self.smooth0, feat0, end1)

        aux_out = self._resize_like(aux_out, feat0)
        return self.final(end0), aux_out