Proper way to use a recursive function in a PyTorch model's forward function

Hi,

Good day.

I tried to implement a recursive function in a PyTorch model, but I eventually run into an out-of-memory error during training.

It might be related to this topic, but it looks like there is no answer to it yet:

In my case, each image is split into n windows, and n varies with the image size, so I use a recursive function to merge the windows until a single, consistently sized feature map is reached.
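
To give a rough picture of the control flow, here is a toy sketch of the reduction (just an illustration: it merges neighbouring feature maps by averaging, whereas my real model concatenates them and runs conv/norm/pool blocks):

    import torch

    def reduce_pairwise(feats):
        # toy version of the idea: merge neighbouring feature maps pairwise
        # until only one is left
        while len(feats) > 1:
            merged = []
            for i in range(0, len(feats) - 1, 2):
                merged.append((feats[i] + feats[i + 1]) / 2)
            if len(feats) % 2:
                merged.append(feats[-1])  # carry an odd leftover forward
            feats = merged
        return feats[0]

    # e.g. 7 windows -> 4 -> 2 -> 1
    windows = [torch.randn(1, 32, 16, 16) for _ in range(7)]
    print(reduce_pairwise(windows).shape)  # torch.Size([1, 32, 16, 16])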

If anyone has a better idea of how to do this, please let me know.

Thanks in advance.

Part of the code:

    self.in_c1 = 32      # 1st layer output channels
    self.psize = (2, 2)  # pool size
    self.psize2 = (1, 1) # pool size (no spatial reduction)

    # conv + batch norm + ReLU + max pool blocks for 4/3/2/1-channel window groups
    self.conv_norm_pool_4c = nn.Sequential(
        nn.Conv2d(4, self.in_c1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(self.in_c1),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_3c = nn.Sequential(
        nn.Conv2d(3, self.in_c1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(self.in_c1),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_2c = nn.Sequential(
        nn.Conv2d(2, self.in_c1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(self.in_c1),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_1c = nn.Sequential(
        nn.Conv2d(1, self.in_c1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(self.in_c1),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    # blocks used while merging windows (named <in_channels>c<out_channels>)
    self.conv_norm_pool_32c256 = nn.Sequential(
        nn.Conv2d(32, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_64c128 = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_64c256 = nn.Sequential(
        nn.Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_128c512 = nn.Sequential(
        nn.Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(512),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_256c512 = nn.Sequential(
        nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(512),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize, self.psize))

    self.conv_norm_pool_1024c512 = nn.Sequential(
        nn.Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        nn.BatchNorm2d(512),
        nn.ReLU(True),
        nn.MaxPool2d(self.psize2, self.psize2))


def forward(self, input):

    # window height/width (assumed identical for all windows of all images)
    ys = input[0][0].shape[1]
    xs = input[0][0].shape[2]

    outf_array = []

    # each element of input holds the windows of one image
    for inputwindows in input:
        
        # number of windows in this image
        length_w = len(inputwindows)

        # 1 window
        if length_w == 1:
            input1 = inputwindows[0].reshape(1, 1, ys, xs)
            out1 = [self.conv_norm_pool_1c(input1)]

        # 2 windows
        elif length_w == 2:
            input1 = inputwindows[0:2].reshape(1, 2, ys, xs)
            out1 = [self.conv_norm_pool_2c(input1)]

        # 3 windows
        elif length_w == 3:
            input1 = inputwindows[0:3].reshape(1, 3, ys, xs)
            out1 = [self.conv_norm_pool_3c(input1)]

        # 4 or more windows
        else:

            # group size
            nw = 4

            window_remain = length_w % nw  # leftover windows after grouping by 4
            n_image = length_w // nw       # number of full 4-window groups
            
            # initialize
            result_4_windows = []
            
            ## 1st layer ##################################################
            # slide over the windows in groups of 4 (stride nw/2, so groups overlap)
            # and apply the 4-channel conv/norm/pool block to each group
            for sind in range(0, length_w - nw + 1, nw // 2):
                # stack 4 windows into one 4-channel input
                img_4_window = inputwindows[sind:sind + nw].reshape(1, nw, ys, xs)
                out1 = self.conv_norm_pool_4c(img_4_window)
                result_4_windows.append(out1)

            # leftover windows: use the last 4 windows as one extra group
            if window_remain:
                img_4_window = inputwindows[-nw:].reshape(1, nw, ys, xs)
                out1 = self.conv_norm_pool_4c(img_4_window)
                result_4_windows.append(out1)
              
            ## subsequent layers #########################################

            ## recursive function: merge the group outputs pairwise until
            ## a single feature map remains
            def layers_process_recursive(result_4_windows):
                
                final_out = []
                
                if len(result_4_windows)>1:
    
                    # initialize
                    result_2_windows = []    
                    img_windows_pair = []
                    
                    # loop each window
                    for ind, img_window in enumerate(result_4_windows):
                        
                        img_windows_pair.append(img_window)
                        
                        if len(img_windows_pair) == 2:
                            # concatenate the pair along the channel dimension
                            combined_window = torch.cat((img_windows_pair[0], img_windows_pair[1]), 1)

                            # pick the conv block that matches the combined channel count
                            if combined_window.shape[1] == 64:
                                out1 = self.conv_norm_pool_64c128(combined_window)
                            elif combined_window.shape[1] == 256:
                                out1 = self.conv_norm_pool_256c512(combined_window)
                            elif combined_window.shape[1] == 1024:
                                out1 = self.conv_norm_pool_1024c512(combined_window)

                            result_2_windows.append(out1)
                            img_windows_pair = []
                        
                    # odd leftover window: pair it with the 2nd-last window
                    if img_windows_pair:
                        img_windows_pair.append(result_4_windows[-2])
                        combined_window = torch.cat((img_windows_pair[0], img_windows_pair[1]), 1)

                        # pick the conv block that matches the combined channel count
                        if combined_window.shape[1] == 64:
                            out1 = self.conv_norm_pool_64c128(combined_window)
                        elif combined_window.shape[1] == 256:
                            out1 = self.conv_norm_pool_256c512(combined_window)
                        elif combined_window.shape[1] == 1024:
                            out1 = self.conv_norm_pool_1024c512(combined_window)

                        result_2_windows.append(out1)
                        img_windows_pair = []

                    # recursively reduce the merged windows until a single one remains
                    final_out = layers_process_recursive(result_2_windows)
                else:
                    # if input is single window, take it as the output 
                    final_out = result_4_windows

                
                return final_out
                
            ###############################################################
            
            # recursively merge the 4-window group outputs into one feature map
            out1 = layers_process_recursive(result_4_windows)

        # out1 is a one-element list in every branch above
        out2 = out1[0]
        
        # consistent 512 channel output
        
        if out2.shape[1] == 32:
            out3 = self.conv_norm_pool_32c256(out2)
            outf = self.conv_norm_pool_256c512(out3)
        elif out2.shape[1] == 64:
            out3 = self.conv_norm_pool_64c256(out2)
            outf = self.conv_norm_pool_256c512(out3)
        elif out2.shape[1] == 128:
            outf = self.conv_norm_pool_128c512(out2)
        elif out2.shape[1] == 256:
            outf = self.conv_norm_pool_256c512(out2) 
        elif out2.shape[1] == 512:
            outf = out2
    
        outf_array.append(outf)
    
    # stack the per-image 512-channel outputs along the batch dimension
    out_tensor = torch.cat(outf_array, 0)

    return out_tensor
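
For reference, each element of the batch is a tensor of shape (n_windows, 1, H, W), where n_windows varies per image, so a forward call looks roughly like this (the class name here is just a placeholder for my model):

    import torch

    model = MyWindowModel()  # placeholder name for the model defined above

    batch = [
        torch.randn(5, 1, 64, 64),   # image split into 5 windows
        torch.randn(11, 1, 64, 64),  # image split into 11 windows
    ]

    out = model(batch)  # one 512-channel feature map per image, stacked along dim 0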

I would appreciate it if anyone who has done this before could advise further. Thanks again.