# Proper way to use a recursive function in a PyTorch model's forward function

Hi,

Good day.

I tried to implement a recursive function in a PyTorch model, but eventually I run into an out-of-memory error during the training process.

This might be related to the following thread, but it looks like there is no answer there yet:

In my case, each image is split into n windows, and n may vary with the image size, so I use a recursive function to reduce the windows until a consistent size is reached.

If anyone has a better idea on how to do this, please let me know.

Part of the code:

``````python
self.in_c1 = 32 # 1st layer output channel
self.psize = (2,2) # pool size: each conv stage halves the spatial dims
self.psize2 = (1,1) # pool size for the final 1024->512 merge (no downsampling)

# --- First-layer stems --------------------------------------------------
# One conv -> BN -> ReLU -> maxpool branch per possible window count:
# the input channel dimension equals the number of stacked windows (1..4),
# and every branch emits self.in_c1 (= 32) feature channels.
self.conv_norm_pool_4c = nn.Sequential(
nn.Conv2d(4, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

self.conv_norm_pool_3c = nn.Sequential(
nn.Conv2d(3, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

self.conv_norm_pool_2c = nn.Sequential(
nn.Conv2d(2, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

self.conv_norm_pool_1c = nn.Sequential(
nn.Conv2d(1, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# --- Channel-widening layers used by forward() --------------------------
# 32 -> 256: applied when the reduction finishes with a single 32-channel map.
self.conv_norm_pool_32c256 = nn.Sequential(
nn.Conv2d(32, 256, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 64 -> 128: merges a channel-concatenated pair of 32-channel maps (32+32).
self.conv_norm_pool_64c128 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 64 -> 256: used by forward() when the final map has 64 channels.
self.conv_norm_pool_64c256 = nn.Sequential(
nn.Conv2d(64, 256, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 128 -> 512: used by forward() when the final map has 128 channels.
self.conv_norm_pool_128c512 = nn.Sequential(
nn.Conv2d(128, 512, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 256 -> 512: merges a pair of 128-channel maps (128+128) and also serves as
# the final widening step for 256-channel maps in forward().
self.conv_norm_pool_256c512 = nn.Sequential(
nn.Conv2d(256, 512, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 1024 -> 512: merges a pair of 512-channel maps (512+512); uses psize2, so
# the spatial size is preserved at this depth.
self.conv_norm_pool_1024c512 = nn.Sequential(
nn.Conv2d(1024, 512, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.MaxPool2d(self.psize2, self.psize2))

def forward(self, input):

ys = input[0][0].shape[1]
xs = input[0][0].shape[2]

outf_array = []

for inputwindows in input:

# size of window
length_w = len(inputwindows)

# if size of window = 1
if length_w == 1:
input1 = inputwindows[0].reshape(1,1,ys,xs)
out1 = [self.conv_norm_pool_1c(input1)]

# if size of window = 2
elif length_w == 2:
input1 = inputwindows[0:2].reshape(1,2,ys,xs)
out1 = [self.conv_norm_pool_2c(input1)]

elif length_w == 3:
input1 = inputwindows[0:3].reshape(1,3,ys,xs)
out1 = [self.conv_norm_pool_3c(input1)]

else: # if size window >= 4

# number window
nw = 4

window_remain = length_w%nw
n_image = int(np.floor(length_w/nw))

# initialize
result_4_windows = []

## 1st layer ##################################################
# scan each 4 window and apply convo_norm_pool

for sind in range(0,length_w-nw+1,int(nw/2)):
# save 4 windows into single instance
img_4_window = inputwindows[sind:sind+4].reshape(1,nw,ys,xs)
out1 = self.conv_norm_pool_4c(img_4_window)
result_4_windows.append(out1)

# some leftover window, make 4 windows from last 4 index
if window_remain:
# save last instance
img_4_window = inputwindows[-4:].reshape(1,nw,ys,xs)
out1 = self.conv_norm_pool_4c(img_4_window)
result_4_windows.append(out1)

## consecutive layer ##################################################

## recursive function #########################################
def layers_process_recursive(result_4_windows):

final_out = []

if len(result_4_windows)>1:

# iniitalize
result_2_windows = []
img_windows_pair = []

# loop each window
for ind, img_window in enumerate(result_4_windows):

img_windows_pair.append(img_window)

if len(img_windows_pair) == 2:
combined_window = torch.cat((img_windows_pair[0],img_windows_pair[1]),1)

#
if combined_window.shape[1] == 64:
out1 = self.conv_norm_pool_64c128(combined_window)
elif combined_window.shape[1] == 256:
out1 = self.conv_norm_pool_256c512(combined_window)
elif combined_window.shape[1] == 1024:
out1 = self.conv_norm_pool_1024c512(combined_window)
#

result_2_windows.append(out1)
img_windows_pair = []

# last single window, pair it with 2nd last window
if img_windows_pair:
img_windows_pair.append(result_4_windows[-2])
combined_window = torch.cat((img_windows_pair[0],img_windows_pair[1]),1)
#
if combined_window.shape[1] == 64:
out1 = self.conv_norm_pool_64c128(combined_window)
elif combined_window.shape[1] == 256:
out1 = self.conv_norm_pool_256c512(combined_window)
elif combined_window.shape[1] == 1024:
out1 = self.conv_norm_pool_1024c512(combined_window)
#
result_2_windows.append(out1)
img_windows_pair = []

# recursively reduce windows into single window
final_out = layers_process_recursive(result_2_windows)
else:
# if input is single window, take it as the output
final_out = result_4_windows

return final_out

###############################################################

# recursively process 4 windows and merge them into final 1 window
out1 = layers_process_recursive(result_4_windows)

out2 = out1[0]

# consistent 512 channel output

if out2.shape[1] == 32:
out3 = self.conv_norm_pool_32c256(out2)
outf = self.conv_norm_pool_256c512(out3)
elif out2.shape[1] == 64:
out3 = self.conv_norm_pool_64c256(out2)
outf = self.conv_norm_pool_256c512(out3)
elif out2.shape[1] == 128:
outf = self.conv_norm_pool_128c512(out2)
elif out2.shape[1] == 256:
outf = self.conv_norm_pool_256c512(out2)
elif out2.shape[1] == 512:
outf = out2

outf_array.append(outf)

# 5D into 4D
out_tensor = outf_array[0]
for outf in outf_array[1:]:
out_tensor = torch.cat((out_tensor,outf),0)

return out_tensor``````
``````

I would appreciate it if anyone who has done this before could advise further. Thanks again.