Hi!
For a masters course in deep learning we had to pick a paper and re-implement it and try some other experiments. The paper i choose is the zhang paper from 2016 regarding colorization.
I just built the network the way they have described it, without a loss function for now, and just tested it with some random numbers to see if everything was behaving as expected. But when I run the code below, and look at my activity monitor i see that the memory usage of the program just keeps increasing up to 50gb!
So my question is: is there anything super stupid in my implemententation of the network ? And if not is there any tools i can use to look where the memory leak is happening ?
edit: So after running some measurements with memory_profiler package I can see that the ram usage just keeps increasing to max but instead of crashing the program starts using swap space which goes up to over 30gb for the code provided below, if this helps anyone with some hints regarding this problem!
Because I dont really think that this is supposed to be how it works right?
Thanks in advance!
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import sys
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.BN_momentum = 1e-3 # Batch normalization momentum
self.color_bins = 313 # Number of color bins in ab colorspace
# nn.Conv2d(a,b,c); a input image channel, b output channels, cxc square convolution kernel
### Conv 1 ###
self.bw_conv1_1 = nn.Conv2d(1, 64, 3, stride=1, padding=1, dilation=1)
self.conv1_2 = nn.Conv2d(64, 64, 3, stride=2, padding=1, dilation=1)
self.conv1_2norm = nn.BatchNorm2d(64, momentum=self.BN_momentum)
### Conv 2 ###
self.conv2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1, dilation=1)
self.conv2_2 = nn.Conv2d(128, 128, 3, stride=2, padding=1, dilation=1)
self.conv2_2norm = nn.BatchNorm2d(128, momentum=self.BN_momentum)
### Conv 3 ###
self.conv3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1, dilation=1)
self.conv3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1, dilation=1)
self.conv3_3 = nn.Conv2d(256, 256, 3, stride=2, padding=1, dilation=1)
self.conv3_3norm = nn.BatchNorm2d(256, momentum=self.BN_momentum)
### Conv 4 ###
self.conv4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1, dilation=1)
self.conv4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1, dilation=1)
self.conv4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1, dilation=1)
self.conv4_3norm = nn.BatchNorm2d(512, momentum=self.BN_momentum)
### Conv 5 ###
self.conv5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=2, dilation=2)
self.conv5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=2, dilation=2)
self.conv5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=2, dilation=2)
self.conv5_3norm = nn.BatchNorm2d(512, momentum=self.BN_momentum)
### Conv 6 ###
self.conv6_1 = nn.Conv2d(512, 512, 3, stride=1, padding=2, dilation=2)
self.conv6_2 = nn.Conv2d(512, 512, 3, stride=1, padding=2, dilation=2)
self.conv6_3 = nn.Conv2d(512, 512, 3, stride=1, padding=2, dilation=2)
self.conv6_3norm = nn.BatchNorm2d(512, momentum=self.BN_momentum)
### Conv 7 ###
self.conv7_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1, dilation=1)
self.conv7_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1, dilation=1)
self.conv7_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1, dilation=1)
self.conv7_3norm = nn.BatchNorm2d(512, momentum=self.BN_momentum)
### Conv 8 ###
self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, dilation=1) # The dilation should be on the input
self.conv8_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1, dilation=1)
self.conv8_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1, dilation=1)
self.conv8_color_bins = nn.Conv2d(256, self.color_bins, 1, stride=1, padding=0, dilation=1)
### Softmax ###
self.softmax8 = nn.Softmax2d()
### Decoding ###
self.conv8_ab = nn.Conv2d(self.color_bins, 2, 1, stride=1, padding=0, dilation=1)
def forward(self, in_data):
### Conv 1 ###
print('in_data')
print(in_data.shape)
x = F.relu(self.bw_conv1_1(in_data))
x = F.relu(self.conv1_2(x))
x = self.conv1_2norm(x)
print('Conv 1')
print(x.shape)
### Conv 2 ###
x = F.relu(self.conv2_1(x))
x = F.relu(self.conv2_2(x))
x = self.conv2_2norm(x)
print('Conv 2')
print(x.shape)
### Conv 3 ###
x = F.relu(self.conv3_1(x))
x = F.relu(self.conv3_2(x))
x = F.relu(self.conv3_3(x))
x = self.conv3_3norm(x)
print('Conv 3')
print(x.shape)
### Conv 4 ###
x = F.relu(self.conv4_1(x))
x = F.relu(self.conv4_2(x))
x = F.relu(self.conv4_3(x))
x = self.conv4_3norm(x)
print('Conv 4')
print(x.shape)
### Conv 5 ###
x = F.relu(self.conv5_1(x))
x = F.relu(self.conv5_2(x))
x = F.relu(self.conv5_3(x))
x = self.conv5_3norm(x)
print('Conv 5')
print(x.shape)
### Conv 6 ###
x = F.relu(self.conv6_1(x))
x = F.relu(self.conv6_2(x))
x = F.relu(self.conv6_3(x))
x = self.conv6_3norm(x)
print('Conv 6')
print(x.shape)
### Conv 7 ###
x = F.relu(self.conv7_1(x))
x = F.relu(self.conv7_2(x))
x = F.relu(self.conv7_3(x))
x = self.conv7_3norm(x)
print('Conv 7')
print(x.shape)
### Conv 8 ###
x = F.relu(self.conv8_1(x))
print('8.1')
print(x.shape)
x = F.relu(self.conv8_2(x))
print(x.shape)
x = F.relu(self.conv8_3(x))
print(x.shape)
x = F.relu(self.conv8_color_bins(x))
print('Conv 8')
print(x.shape)
### Softmax ###
x = self.softmax8(x)
### Decoding ###
x = self.conv8_ab(x)
return x
'''
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
'''
if __name__ == '__main__':
net = Net()
print(len(list(net.parameters())))
# in_data = bilderna
in_data = torch.rand(100,1,256,256)
out_data = net(in_data)
print(out_data.shape)