Multi-Siamese network: running out of memory despite heavy weight sharing

I was hoping that, since I’m using heavy weight sharing, I would not run into out-of-memory issues. I would really appreciate it if people here could help me solve this.

What I’m trying to do is this:

I am working with single-sample minibatches (intentionally). Given an input tensor of, for example, size [1, 70, 1024, 1024], with dimensions [sample, slice, row, column],
I would like to apply a feature extractor to each slice, perform global max pooling on each output, and, once I have all 70 feature vectors, do an additional step that combines them (for example, a final average pooling over all of them).

The problem is that at the 4th slice (while still in the forward pass) I get an OOM exception.

My code looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange

class LitterNet(nn.Module):
    def __init__(self, feature_extractor):
        super().__init__()
        # currently an InceptionResnetV2 module, chopped before the end, so it's fully convolutional
        self.feature_extractor = feature_extractor

    def forward(self, x):
        slices_num = x.shape[1]
        slices_features = []
        for i in trange(slices_num):
            curr_x = x[0, i, ...][None, None, ...]  # now it will be of shape [1, 1, 1024, 1024]
            y = self.feature_extractor(curr_x)  # <<< getting OOM here at the 4th iteration of this loop
            # global max pooling over the spatial dimensions
            curr_slice_features = F.max_pool2d(y, kernel_size=y.shape[2:])
            slices_features.append(curr_slice_features)
            y = None  # does not free memory: the autograd graph still references the activations

        stacked_slices_features = torch.stack(slices_features)
        # here I want to perform some pooling on it, and then more layers, but I never reach this region anyway :/
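The combine step mentioned in the last comment of the code could look roughly like this. It is a minimal sketch, assuming each per-slice feature has shape [1, C, 1, 1] after global max pooling and that averaging over slices is the desired combination (the post gives it only as an example). The helper name `combine_slice_features` and the feature size 1536 are hypothetical.

```python
import torch

def combine_slice_features(slices_features):
    # Each entry has shape [1, C, 1, 1] after global max pooling
    stacked = torch.stack(slices_features)    # [S, 1, C, 1, 1]
    stacked = stacked.flatten(start_dim=1)    # [S, C]
    # Average pooling over the slice dimension -> one vector for the sample
    return stacked.mean(dim=0, keepdim=True)  # [1, C]

# Usage with 70 hypothetical 1536-dim slice features
features = [torch.randn(1, 1536, 1, 1) for _ in range(70)]
pooled = combine_slice_features(features)
print(pooled.shape)
```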

I’m trying to understand if I’m missing something here. The graph is being built for the backward pass later, but since all weights are shared, I did not expect to get an OOM by the 4th slice. I imagined that if I could pass 1-2 slices, I would be able to pass the rest.
Any ideas how to solve this?
Am I missing something about how autograd builds the graph?

All help will be highly appreciated :slight_smile:

How did you manage to pass the input with dimension [1, 1, 1024, 1024] into the InceptionResnet?
Anyway, as far as I understand your workflow, you are passing each “slice” into the feature_extractor in the for-loop.
Could you try passing it as a batch, i.e. moving the slice dimension into the batch dimension with x = x.view(70, 1, 1024, 1024), and running it without the loop?
It probably won’t solve the OOM error, but it’s worth a try.
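A minimal sketch of the batched variant suggested above, assuming a contiguous input of shape [1, S, H, W] and a feature extractor that takes single-channel images. The small `nn.Sequential` is a hypothetical stand-in for the modified InceptionResnetV2, and the spatial size is reduced so it runs anywhere.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in feature extractor: fully convolutional, 1 input channel
feature_extractor = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2), nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2), nn.ReLU(),
)

def extract_batched(x):
    # x: [1, S, H, W] -> treat the S slices as a batch of S single-channel images
    _, s, h, w = x.shape
    batch = x.view(s, 1, h, w)                         # [S, 1, H, W]
    y = feature_extractor(batch)                       # [S, C, h', w']
    feats = F.max_pool2d(y, kernel_size=y.shape[2:])   # global max pooling -> [S, C, 1, 1]
    return feats.flatten(start_dim=1)                  # [S, C]

x = torch.randn(1, 70, 64, 64)
print(extract_batched(x).shape)
```

This removes the Python loop but not the memory cost: autograd still has to keep the activations of all S slices for backward, so peak memory is about the same as in the loop (or higher, since everything is computed at once).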

First of all - thanks for the reply! :slight_smile:

Are you asking how I managed to pass that shape to InceptionResnetV2 even though the input isn’t RGB? It’s a modified version, since I’m working with single-channel data.
I’m training the weights from scratch, so I only had to make a simple modification to the InceptionResnetV2 code to achieve it.
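The post doesn’t say exactly what was changed. One generic way to adapt a model trained from scratch to single-channel input, assuming its first `Conv2d` is the input layer, is to swap that layer for a freshly initialized 1-channel copy. The helper `make_single_channel` and the toy model below are hypothetical, not the author’s modification.

```python
import torch
import torch.nn as nn

def make_single_channel(model: nn.Module) -> nn.Module:
    # named_modules() walks the model in registration order, so the first
    # Conv2d it yields is normally the network's input convolution.
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            parent_name, _, attr = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            # Same hyperparameters, new random weights, 1 input channel
            new_conv = nn.Conv2d(
                1, module.out_channels,
                kernel_size=module.kernel_size, stride=module.stride,
                padding=module.padding, dilation=module.dilation,
                bias=module.bias is not None,
            )
            setattr(parent, attr, new_conv)
            return model
    raise ValueError("model contains no Conv2d layer")

# Usage on a hypothetical RGB stem nested inside a block
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=2, bias=False), nn.ReLU()),
    nn.Conv2d(32, 64, kernel_size=3),
)
make_single_channel(model)
print(model(torch.randn(1, 1, 64, 64)).shape)
```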

Passing it in a batch as you described causes an OOM immediately.

I’m trying to understand: while the graph is being built, doesn’t the fact that all of the weights are shared help reduce or reuse memory?
Maybe there is something explicit I need to do to “help” autograd understand that it can reuse existing buffers?

Ah ok, I thought you had added another Conv layer in front of it. :wink:

The weights are shared, but each slice still produces its own activations in every layer, and all of them have to be stored for the backward pass.
Or what do you mean by weight sharing?
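A rough way to see this, assuming a toy stand-in for the feature extractor: register forward hooks that add up the size of every layer’s output. Parameter memory stays fixed, while the activation bytes grow linearly with the number of slices passed through the shared module. Output sizes only approximate what autograd actually saves (some layers save their inputs instead), and `activation_bytes` is a hypothetical helper.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the shared feature extractor
feature_extractor = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2), nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2), nn.ReLU(),
)

def activation_bytes(model, slice_input, num_slices):
    total = 0

    def count_output(module, inputs, output):
        nonlocal total
        total += output.numel() * output.element_size()

    handles = [layer.register_forward_hook(count_output) for layer in model]
    try:
        # Keep every output alive, like slices_features in the original loop
        _outputs = [model(slice_input) for _ in range(num_slices)]
    finally:
        for h in handles:
            h.remove()
    return total

param_bytes = sum(p.numel() * p.element_size() for p in feature_extractor.parameters())
one_slice = torch.randn(1, 1, 256, 256)
for n in (1, 2, 4):
    # parameters are counted once; activations scale with n
    print(n, param_bytes, activation_bytes(feature_extractor, one_slice, n))
```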

I guess you can’t pass the input as a few smaller batches, since you need to stack it and then perform some other operations on it, right?

Alternatively, you could try out model checkpointing.
It’s quite a new method that lets you trade compute for memory: instead of storing the intermediate activations, it recomputes them during the backward pass.
The article describes it pretty clearly.
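A minimal sketch of checkpointing applied to the per-slice loop, assuming a recent PyTorch where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant=False` (that argument did not exist in 0.4). The activations inside the feature extractor are dropped after each slice’s forward and recomputed when backward reaches that slice; only each slice’s input and output are kept. `CheckpointedSliceNet` and the stand-in extractor are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class CheckpointedSliceNet(nn.Module):
    def __init__(self, feature_extractor):
        super().__init__()
        self.feature_extractor = feature_extractor

    def forward(self, x):
        slices_features = []
        for i in range(x.shape[1]):
            curr_x = x[0, i][None, None]  # [1, 1, H, W]
            # Inner activations are not stored; they are recomputed from
            # curr_x during backward. The output y itself is kept.
            y = checkpoint(self.feature_extractor, curr_x, use_reentrant=False)
            slices_features.append(F.max_pool2d(y, kernel_size=y.shape[2:]))
        stacked = torch.stack(slices_features).flatten(start_dim=1)  # [S, C]
        return stacked.mean(dim=0, keepdim=True)                     # [1, C]

# Usage with a hypothetical stand-in extractor
fe = nn.Sequential(nn.Conv2d(1, 8, 3, stride=2), nn.ReLU(), nn.Conv2d(8, 16, 3, stride=2))
net = CheckpointedSliceNet(fe)
out = net(torch.randn(1, 70, 128, 128))
out.sum().backward()
```

With the non-reentrant variant, gradients reach the shared weights even though the input slices themselves do not require grad.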

You would have to wait a few days for the 0.4 release or just build PyTorch from source. It should be quite easy to build it, but let me know if you encounter any errors.

You’re right: I can’t pass a few smaller batches, because I need to stack the features into a single graph that covers all of them (so the loss function considers them all at once).

Thanks for the “model checkpoint” tip - will definitely give it a try!

I’m wondering: would it make sense to have a mechanism in PyTorch that moves tensors from GPU memory to system RAM and back when needed?

I understand there will be overhead from the memory allocation and the transfer itself, but is it really so bad that it makes such a mechanism worthless?
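Later PyTorch releases (from roughly 1.10 on, well after this thread) added `torch.autograd.graph.save_on_cpu`, a context manager that does something close to this for the tensors autograd saves for backward: they are moved to CPU memory during the forward pass and copied back to their original device during backward. A minimal sketch, assuming a CUDA device when one is available (on a CPU-only machine it runs but offloads nothing useful); the stand-in extractor and sizes are hypothetical.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the shared feature extractor
feature_extractor = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2), nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2), nn.ReLU(),
).to(device)

x = torch.randn(1, 70, 128, 128, device=device)

# Tensors saved for backward inside this block are kept in CPU memory
# (pinned when a GPU is present, for faster copies) and moved back to
# `device` when backward needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=torch.cuda.is_available()):
    feats = [feature_extractor(x[0, i][None, None]).amax(dim=(2, 3))
             for i in range(x.shape[1])]
    loss = torch.stack(feats).mean()

loss.backward()  # backward may run outside the context manager
```

Each layer’s working set still has to fit on the GPU, and whether the extra host-device copies are an acceptable price depends on the model and the transfer bandwidth.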

Came back here just to say thanks for the tip about model checkpointing!
So far, it seems to work well :slight_smile:

Nice work! I’m glad it worked out :wink:

