# Proper way to use a recursive function in a PyTorch model's forward function

Hi,

Good day.

I tried to implement a recursive function in a PyTorch model, but eventually I run into an out-of-memory error during the training process.

This might be related to the following thread, but it looks like there is no answer there yet:

In my case, each image is split into n windows, and n may vary with the image size, so I use a recursive function to reduce the windows until a consistent size is reached.

If anyone has a better idea on how to do this, please let me know.

Part of the code:

``````python
self.in_c1 = 32 # 1st layer output channel
self.psize = (2,2) # pool size: each conv stage halves the spatial dims
self.psize2 = (1,1) # pool size for the final 1024->512 merge (no downsampling)

# --- First-layer stems --------------------------------------------------
# One conv -> BN -> ReLU -> maxpool branch per possible window count:
# the input channel dimension equals the number of stacked windows (1..4),
# and every branch emits self.in_c1 (= 32) feature channels.
self.conv_norm_pool_4c = nn.Sequential(
nn.Conv2d(4, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

self.conv_norm_pool_3c = nn.Sequential(
nn.Conv2d(3, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

self.conv_norm_pool_2c = nn.Sequential(
nn.Conv2d(2, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

self.conv_norm_pool_1c = nn.Sequential(
nn.Conv2d(1, self.in_c1, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(self.in_c1),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# --- Channel-widening layers used by forward() --------------------------
# 32 -> 256: applied when the reduction finishes with a single 32-channel map.
self.conv_norm_pool_32c256 = nn.Sequential(
nn.Conv2d(32, 256, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 64 -> 128: merges a channel-concatenated pair of 32-channel maps (32+32).
self.conv_norm_pool_64c128 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 64 -> 256: used by forward() when the final map has 64 channels.
self.conv_norm_pool_64c256 = nn.Sequential(
nn.Conv2d(64, 256, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 128 -> 512: used by forward() when the final map has 128 channels.
self.conv_norm_pool_128c512 = nn.Sequential(
nn.Conv2d(128, 512, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 256 -> 512: merges a pair of 128-channel maps (128+128) and also serves as
# the final widening step for 256-channel maps in forward().
self.conv_norm_pool_256c512 = nn.Sequential(
nn.Conv2d(256, 512, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.MaxPool2d(self.psize, self.psize))

# 1024 -> 512: merges a pair of 512-channel maps (512+512); uses psize2, so
# the spatial size is preserved at this depth.
self.conv_norm_pool_1024c512 = nn.Sequential(
nn.Conv2d(1024, 512, kernel_size=(3,3), stride=(1,1),padding = (1,1), bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.MaxPool2d(self.psize2, self.psize2))

def forward(self, input):

ys = input[0][0].shape[1]
xs = input[0][0].shape[2]

outf_array = []

for inputwindows in input:

# size of window
length_w = len(inputwindows)

# if size of window = 1
if length_w == 1:
input1 = inputwindows[0].reshape(1,1,ys,xs)
out1 = [self.conv_norm_pool_1c(input1)]

# if size of window = 2
elif length_w == 2:
input1 = inputwindows[0:2].reshape(1,2,ys,xs)
out1 = [self.conv_norm_pool_2c(input1)]

elif length_w == 3:
input1 = inputwindows[0:3].reshape(1,3,ys,xs)
out1 = [self.conv_norm_pool_3c(input1)]

else: # if size window >= 4

# number window
nw = 4

window_remain = length_w%nw
n_image = int(np.floor(length_w/nw))

# initialize
result_4_windows = []

## 1st layer ##################################################
# scan each 4 window and apply convo_norm_pool

for sind in range(0,length_w-nw+1,int(nw/2)):
# save 4 windows into single instance
img_4_window = inputwindows[sind:sind+4].reshape(1,nw,ys,xs)
out1 = self.conv_norm_pool_4c(img_4_window)
result_4_windows.append(out1)

# some leftover window, make 4 windows from last 4 index
if window_remain:
# save last instance
img_4_window = inputwindows[-4:].reshape(1,nw,ys,xs)
out1 = self.conv_norm_pool_4c(img_4_window)
result_4_windows.append(out1)

## consecutive layer ##################################################

## recursive function #########################################
def layers_process_recursive(result_4_windows):

final_out = []

if len(result_4_windows)>1:

# iniitalize
result_2_windows = []
img_windows_pair = []

# loop each window
for ind, img_window in enumerate(result_4_windows):

img_windows_pair.append(img_window)

if len(img_windows_pair) == 2:
combined_window = torch.cat((img_windows_pair[0],img_windows_pair[1]),1)

#
if combined_window.shape[1] == 64:
out1 = self.conv_norm_pool_64c128(combined_window)
elif combined_window.shape[1] == 256:
out1 = self.conv_norm_pool_256c512(combined_window)
elif combined_window.shape[1] == 1024:
out1 = self.conv_norm_pool_1024c512(combined_window)
#

result_2_windows.append(out1)
img_windows_pair = []

# last single window, pair it with 2nd last window
if img_windows_pair:
img_windows_pair.append(result_4_windows[-2])
combined_window = torch.cat((img_windows_pair[0],img_windows_pair[1]),1)
#
if combined_window.shape[1] == 64:
out1 = self.conv_norm_pool_64c128(combined_window)
elif combined_window.shape[1] == 256:
out1 = self.conv_norm_pool_256c512(combined_window)
elif combined_window.shape[1] == 1024:
out1 = self.conv_norm_pool_1024c512(combined_window)
#
result_2_windows.append(out1)
img_windows_pair = []

# recursively reduce windows into single window
final_out = layers_process_recursive(result_2_windows)
else:
# if input is single window, take it as the output
final_out = result_4_windows

return final_out

###############################################################

# recursively process 4 windows and merge them into final 1 window
out1 = layers_process_recursive(result_4_windows)

out2 = out1[0]

# consistent 512 channel output

if out2.shape[1] == 32:
out3 = self.conv_norm_pool_32c256(out2)
outf = self.conv_norm_pool_256c512(out3)
elif out2.shape[1] == 64:
out3 = self.conv_norm_pool_64c256(out2)
outf = self.conv_norm_pool_256c512(out3)
elif out2.shape[1] == 128:
outf = self.conv_norm_pool_128c512(out2)
elif out2.shape[1] == 256:
outf = self.conv_norm_pool_256c512(out2)
elif out2.shape[1] == 512:
outf = out2

outf_array.append(outf)

# 5D into 4D
out_tensor = outf_array[0]
for outf in outf_array[1:]:
out_tensor = torch.cat((out_tensor,outf),0)

return out_tensor``````
``````

I would appreciate it if anyone who has done this before could advise further. Thanks again.