Out of memory with for loops in the forward function

Hi,
I’m implementing a network and I’m running out of memory.

I’m testing the net by building the model and doing a manual forward pass over a batch to check the output dimensions. However, this is a “tricky” net that has temporal pooling and uses several samples for training…

I think the problem comes from the video analysis network. It’s a ResNet fed with 3 images. The outputs for these 3 images are concatenated, and temporal max pooling is applied to feed the next subnet.
The thing is, the only way I found to do a lot of this in PyTorch is to use for loops inside the forward function.
The built-in functions expect a 4D tensor as input, and my data has more dimensions to handle all the atypical aspects of this net. So I iterate with for loops and stack the resulting outputs to handle all the dimensions, which makes it possible to process the 3 images and concatenate their features.
I suspect that the for loops are hugely increasing the memory usage.
The model occupies around 2 GB. One sample (raw data) occupies 4 MB. Processing a batch of 10 samples uses more than 32 GB, which is totally impossible.
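
On the for loops themselves: one common way to feed a 4D-only module with a 6D tensor is to fold the leading dimensions into the batch axis, call the module once, and unfold the result. The sketch below is only an illustration under assumptions: `backbone` is a hypothetical toy stand-in for the ResNet, and `run_folded` is a made-up helper name. It mainly removes the Python loops; it does not by itself shrink the activations the network has to hold.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the ResNet; any module taking [N, C, H, W] behaves the same way.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)


def run_folded(module, x):
    """Apply a 4D-only module to a [B, T, I, C, H, W] tensor in a single call."""
    b, t, i, c, h, w = x.shape
    flat = x.reshape(b * t * i, c, h, w)          # fold B, T, I into the batch axis
    out = module(flat)                            # [B*T*I, C', H', W']
    return out.reshape(b, t, i, *out.shape[1:])   # unfold back to [B, T, I, C', H', W']


# Toy input: batch_size=2, n_tower=2, n_images=3, channels=3, H=W=16
x = torch.randn(2, 2, 3, 3, 16, 16)
folded = run_folded(backbone, x)
print(folded.shape)
```

One caveat: if the module contains BatchNorm and is in training mode, folding changes which images share batch statistics compared with calling it on 3 images at a time, so the two versions are only guaranteed to agree in eval mode or without such layers.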

How can I fix this, or what is the proper way to do it?
Note that I’m not backpropagating yet, just doing output = model(batch).
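
Even without calling backward, a plain `output = model(batch)` records the autograd graph and keeps intermediate activations alive for a possible backward pass. For a pure shape check, a minimal sketch (assuming a toy `model` and `batch` that stand in for the real ones) is to run the forward pass under `torch.no_grad()`:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real model and batch (assumptions for illustration only).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)
batch = torch.randn(2, 3, 16, 16)

model.eval()               # use inference behaviour for layers such as BatchNorm/Dropout
with torch.no_grad():      # do not record the graph or keep activations for backward
    output = model(batch)

print(output.shape)
print(output.requires_grad)
```

This only helps for the dimension check. Once training starts, the activations are needed again, so the batch size or input resolution may still have to come down.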

    def forward(self, visual_input, audio_input):
        # VIDEO SUBNET -------------------------------------------------------
        dims = visual_input.size()
        if len(dims) != 6:
            raise ValueError('Visual features have wrong dimension. Required: [batch_size,n_tower,n_images,channels,H,W]')
        # visual_input dims: [batch_size, n_tower, n_images, channels, H, W]
        # Run the ResNet on the n_images frames of each (sample, tower) pair.
        self.batch_cat = []
        for i in range(dims[0]):
            self.tower_cat = []
            for j in range(dims[1]):
                self.tower_cat.append(self.drn_model(visual_input[i, j]))
            self.batch_cat.append(torch.stack(self.tower_cat))
        ASN_input = torch.stack(self.batch_cat)
        dims = ASN_input.size()
        # ASN_input dims: [batch_size, n_tower, n_images, channels, H, W]
        self.batch_cat = []
        if self.training:
            for i in range(dims[0]):
                self.tower_cat = []
                for j in range(dims[1]):
                    self.tower_cat.append(data.GlobalMaxPooling2d(data.TemporalPooling(ASN_input[i, j])))
                self.batch_cat.append(torch.stack(self.tower_cat))
            # resulting dims: [batch_size, n_tower, channels, 1, 1]
        else:
            for i in range(dims[0]):
                self.tower_cat = []
                for j in range(dims[1]):
                    self.tower_cat.append(data.TemporalPooling(ASN_input[i, j]))
                self.batch_cat.append(torch.stack(self.tower_cat))
            # resulting dims: [batch_size, n_tower, channels, H, W]

        ASN_input = torch.stack(self.batch_cat)
        # AUDIO SUBNET -------------------------------------------------------
        K_spectrograms = self.unet_model(audio_input)
        
        return ASN_input, K_spectrograms
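
The two pooling loops can probably be expressed as reductions over whole tensors. The sketch below assumes, based only on the dimension comments in the code above, that `data.TemporalPooling` takes a max over the `n_images` axis and that `data.GlobalMaxPooling2d` takes a max over `H` and `W` while keeping singleton dimensions. The function names here are hypothetical replacements, not the originals.

```python
import torch


def temporal_max_pool(feats):
    # Assumed behaviour: max over the n_images axis.
    # [B, T, I, C, H, W] -> [B, T, C, H, W]
    return feats.amax(dim=2)


def global_max_pool2d(feats):
    # Assumed behaviour: max over the spatial axes, kept as size-1 dims.
    # [B, T, C, H, W] -> [B, T, C, 1, 1]
    return feats.amax(dim=(-2, -1), keepdim=True)


# Toy features: batch_size=2, n_tower=2, n_images=3, channels=8, H=W=4
feats = torch.randn(2, 2, 3, 8, 4, 4)
pooled_eval = temporal_max_pool(feats)          # shape used in the eval branch
pooled_train = global_max_pool2d(pooled_eval)   # shape used in the training branch
print(pooled_eval.shape, pooled_train.shape)
```

If the real pooling functions do something different (for example average pooling, or a different axis order), the reduction axes would need to change accordingly.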
        

Hi, have you resolved this problem?

I think so, hahaha. I had coded it incorrectly.