I was hoping that, since I’m using significant weight sharing, I would not run into out-of-memory issues, and I would really appreciate any help from people here in solving this.
What I’m trying to do is this:
Working with single-sample minibatches (intentionally), and given an input tensor, for example of size [1, 70, 1024, 1024], with dimensions [sample, slice, row, column],
I would like to run a (shared) feature extractor on each slice, perform global max pooling on each result, and once I have all 70 feature vectors, combine them in an additional step (for example, a final average pooling over all of them).
The problem is that at the 4th slice (while still in the forward pass) I’m getting an OOM exception.
My code looks like this:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange

class LitterNet(nn.Module):
    def __init__(self):
        super().__init__()
        # here I initialize my feature extractor, currently an InceptionResnetV2
        # module, chopped before the end so it's fully convolutional
        self.feature_extractor = ...

    def forward(self, x):
        slices_num = x.shape[1]
        slices_features = []
        for i in trange(slices_num):
            curr_x = x[0, i, ...][None, None, ...]  # now it will be of shape [1, 1, 1024, 1024]
            y = self.feature_extractor(curr_x)  # <<< getting OOM here at the 4th iteration of this loop
            # doing global max pooling on it
            curr_slice_features = F.max_pool2d(y, kernel_size=y.shape[2:])
            y = None  # trying to release the slice's feature map
            slices_features.append(curr_slice_features)
        stacked_slices_features = torch.stack(slices_features)
        # here I want to perform some pooling on it, and then more layers,
        # but I never reach this region anyway :/
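For reference, the combining step I have in mind after the loop would be something along these lines (the mean over slices is just a placeholder for the final pooling, and more layers would follow):

# stacked_slices_features has shape [70, 1, C, 1, 1] after torch.stack
pooled = stacked_slices_features.mean(dim=0)  # average over the slices -> [1, C, 1, 1]
volume_features = pooled.flatten(1)           # [1, C] feature vector for the whole volume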
I’m trying to understand what I’m missing here. Since all the weights are shared, I did not expect to hit an OOM only at the 4th slice while the graph is being built (for the backward pass later); I imagined that if I could get through 1-2 slices, I’d be able to get through the rest.
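In case the numbers help, this is roughly how the growth could be watched per slice, assuming a CUDA device (memory_allocated() counts live tensors, which should include whatever autograd keeps around for the backward pass):

# same loop as above, with per-slice memory logging added
for i in trange(slices_num):
    curr_x = x[0, i, ...][None, None, ...]
    y = self.feature_extractor(curr_x)
    slices_features.append(F.max_pool2d(y, kernel_size=y.shape[2:]))
    print(f"slice {i}: {torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated, "
          f"{torch.cuda.max_memory_allocated() / 2**20:.0f} MiB peak")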
Any ideas on how to solve this?
Am I missing something about how autograd builds the graph?
Any help would be highly appreciated.