Multi siamese network - getting out of memory despite heavy weight sharing

I was hoping that, since I’m using heavy weight sharing, I would not run into out-of-memory issues. I would really appreciate any help solving this.

What I’m trying to do is this:

I’m working with single-sample minibatches (intentionally). The input is a tensor of size, for example, [1,70,1024,1024], with dimensions [sample, slice, row, column].
I would like to run a feature extractor on each slice, perform global max pooling on each result, and, once I have all 70 feature vectors, combine them in an additional step (for example, with a final average pooling over all of them).

The problem is that at the 4th slice (while still in the forward pass) I get an OOM exception.

My code looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange

class LitterNet(nn.Module):
    def __init__(self):
        super().__init__()
        # here I initialize it to my feature extractor, currently an InceptionResnetV2 module,
        # chopped before the end, so it's fully convolutional
        self.feature_extractor = ...

    def forward(self, x):
        slices_num = x.shape[1]
        slices_features = []
        for i in trange(slices_num):
            curr_x = x[0, i, ...][None, None, ...]  # now it will be of shape [1,1,1024,1024]
            y = self.feature_extractor(curr_x)  # <<< getting OOM here at the 4th iteration of this loop
            # doing global max pooling on it
            curr_slice_features = F.max_pool2d(y, kernel_size=y.shape[2:])
            slices_features.append(curr_slice_features)  # keep the pooled features of this slice
            y = None

        stacked_slices_features = torch.stack(slices_features)
        # here I want to perform some pooling on it, and then more layers, but I never reach this region anyway :/

I’m trying to understand whether I’m missing something here. Since all weights are shared, I did not expect the graph being built for the later backward pass to cause an OOM by the 4th slice. I imagined that if I could pass 1-2 slices, I would be able to pass the rest.
Any ideas how to solve this?
Am I missing something about how autograd builds the graph?

All help will be highly appreciated :slight_smile:

How did you manage to pass the input with dimension [1, 1, 1024, 1024] into the InceptionResnet?
Anyway, as far as I understand your workflow, you are passing each “slice” into the feature_extractor in the for-loop.
Could you try to pass it as a batch, i.e. swap the slice dimension with the batch dimension: x = x.view(70, 1, 1024, 1024), and try it without the loop?
Probably it won’t solve the OOM error, but it’s worth a try.
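
As a rough sketch of that reshaping, here is a small tensor standing in for the real [1, 70, 1024, 1024] input, with a single hypothetical conv layer in place of the modified InceptionResnetV2. The sizes and the `extractor` layer are assumptions for illustration only. Note that all activations are still kept for backward, so the memory footprint is not expected to be lower than with the loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Small stand-in for the real input of shape [1, 70, 1024, 1024]
x = torch.randn(1, 5, 16, 16)

# Move the slices into the batch dimension. view() is valid here because
# the original batch dimension has size 1; x.transpose(0, 1) is equivalent.
batched = x.view(5, 1, 16, 16)

# Hypothetical one-layer feature extractor (assumption, not the real model)
extractor = nn.Conv2d(1, 4, kernel_size=3, padding=1)

y = extractor(batched)                             # [5, 4, 16, 16]
features = F.adaptive_max_pool2d(y, 1).flatten(1)  # global max pool -> [5, 4]
combined = features.mean(dim=0)                    # average over slices -> [4]
```

Each row of `batched` is the same data the loop version extracts with `x[0, i, ...][None, None, ...]`, just processed in one call.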

First of all - thanks for the reply! :slight_smile:

Are you asking how I managed to pass such a shape to InceptionResnetV2, given that the input is not “rgb”-ish? It’s a modified version, since I’m working with single-color-channel data.
I’m training the weights from scratch, so I only had to make a simple modification to the InceptionResnetV2 code to achieve it.

Passing it as a batch, as you described, causes an OOM immediately.

I’m trying to understand: while the graph is being built, doesn’t the fact that all of the weights are shared help reduce or reuse memory?
Maybe there is something explicit I need to do in order to “help” autograd understand that it can reuse existing buffers?

Ah ok, I thought you had added another Conv layer in front of it. :wink:

The weights are shared, but each slice you pass through the network produces its own activations, which have to be stored for the backward pass.
Or what do you mean by weight sharing?
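
To put rough numbers on this, here is a back-of-the-envelope sketch. The layer is hypothetical (a 3x3 convolution from 1 to 32 channels with stride 2, float32), not the actual first layer of your network; the point is only that the shared weights are tiny compared with the per-slice activations, and the activations add up across slices.

```python
def tensor_bytes(*shape, bytes_per_element=4):
    """Size in bytes of a dense tensor with the given shape (float32 by default)."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_element

MIB = 1024 ** 2

# Hypothetical first layer: 3x3 conv, 1 -> 32 channels, stride 2 (assumption)
weight_bytes = tensor_bytes(32, 1, 3, 3)              # stored once, shared by all slices
activation_per_slice = tensor_bytes(1, 32, 512, 512)  # output for one 1024x1024 slice
activation_all_slices = 70 * activation_per_slice     # one copy kept per slice for backward

print(f"weights:               {weight_bytes} bytes")
print(f"activations, 1 slice:  {activation_per_slice / MIB:.0f} MiB")
print(f"activations, 70 slices: {activation_all_slices / MIB:.0f} MiB")
```

That is roughly 1 KiB of weights against 32 MiB of activations per slice for this single layer alone, and a deep network stores many such outputs per slice.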

I guess you can’t pass the input as a few smaller batches, since you need to stack it and then perform some other operations on it, right?

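Alternatively, you could try out model checkpointing.
It’s quite a new method, which allows you to trade extra compute for lower memory usage: intermediate activations inside a checkpointed segment are not stored during the forward pass and are recomputed during the backward pass instead.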
The article describes it pretty clearly.

You would have to wait a few days for the 0.4 release or just build PyTorch from source. It should be quite easy to build it, but let me know if you encounter any errors.
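
A minimal sketch of how checkpointing could be applied to the per-slice loop, assuming a tiny stand-in feature extractor and small inputs so it runs quickly on CPU. `SliceNet`, `_slice_features` and the layer sizes are hypothetical names and values chosen for illustration. The `use_reentrant=False` argument exists only in recent PyTorch releases; older versions call `checkpoint` without it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class SliceNet(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        # Tiny fully convolutional extractor (assumption, not InceptionResnetV2)
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.head = nn.Linear(16, 1)

    def _slice_features(self, curr_x):
        y = self.feature_extractor(curr_x)
        return F.adaptive_max_pool2d(y, 1).flatten(1)  # global max pool -> [1, 16]

    def forward(self, x):
        feats = []
        for i in range(x.shape[1]):
            curr_x = x[:, i:i + 1]  # one slice, shape [1, 1, H, W]
            # Only the small pooled output is kept; the extractor's internal
            # activations are recomputed for this slice during backward.
            feats.append(checkpoint(self._slice_features, curr_x, use_reentrant=False))
        stacked = torch.stack(feats)            # [slices, 1, 16]
        return self.head(stacked.mean(dim=0))   # average over slices -> [1, 1]

torch.manual_seed(0)
model = SliceNet()
x = torch.randn(1, 6, 32, 32)  # small stand-in for [1, 70, 1024, 1024]
out = model(x)
out.sum().backward()  # gradients still flow through all slices in one graph
```

With this layout, peak memory should be closer to the cost of a single slice plus the stacked feature vectors, at the price of running the extractor twice per slice.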

You’re right - I can’t pass a few smaller batches, because I need to stack the features and have a single graph that covers all of them (so the loss function considers all slices at once).

Thanks for the “model checkpoint” tip - will definitely give it a try!

Edit: moved suggestion below to a separate thread.
Wondering: would it make sense to have a mechanism in PyTorch that moves tensors from GPU memory to system RAM and back when needed?

I understand that there will be overhead from the memory allocation and the transfer itself, but is it really so bad that it makes such a mechanism worthless?

Came back here just to say thanks for the tip about model checkpointing!
So far, it seems to work well :slight_smile:

Nice work! I’m glad it worked out :wink:
