Summary: With a ~100 MB model and a ~400 MB batch of training data, model(x) causes an out-of-memory error despite having 16 GB of GPU memory available.
I’ve been playing around with the Recursion Pharmaceuticals competition over on Kaggle, and I’ve noticed bizarre spikes in memory usage when I call models. I’ve tried to create a minimal example here. All of the code is present at that link, but here’s a summary of what I’m doing:
The data is 512x512 images with 6 channels. I’m using a pretty standard data loader to load them; the code is contained in ImagesDS
in cell 3. (The images should be normalized, but that doesn’t seem relevant here.)
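For reference, the dataset looks roughly like this (a simplified sketch, not the exact ImagesDS from the notebook; the file layout and CSV column names here are placeholders):

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image

class ImagesDS(Dataset):
    # Simplified sketch of the dataset class; the real ImagesDS is in cell 3
    # of the notebook. Paths and column names below are placeholders.
    def __init__(self, csv_file, img_dir, channels=range(1, 7)):
        self.records = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.channels = channels

    def _load_channel(self, index, channel):
        # Each channel is stored as a separate grayscale image.
        path = f'{self.img_dir}/{self.records.id_code[index]}_w{channel}.png'
        return torch.from_numpy(np.array(Image.open(path), dtype=np.float32))

    def __getitem__(self, index):
        # Stack the 6 channels into a single 6x512x512 float32 tensor.
        x = torch.stack([self._load_channel(index, c) for c in self.channels])
        y = int(self.records.sirna[index])
        return x, y

    def __len__(self):
        return len(self.records)

Each item comes back as a 6x512x512 float32 tensor, so a batch of 64 is about 64 * 6 * 512 * 512 * 4 bytes ≈ 400 MB, which is where the figure in the summary comes from.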
The model is a ResNeXt-50 (resnext50_32x4d) with pretrained weights. However, since I have 6 channels and 1108 output classes, I replace the first and last layers (I've seen the same error with other base models like DenseNet):
model = torchvision.models.resnext50_32x4d(pretrained=True)
model.conv1 = nn.Conv2d(6, 64, 7, 2, 3)  # accept 6 input channels instead of 3
model.fc = nn.Linear(2048, classes, bias=True)  # classes = 1108
model.to(device)
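As a sanity check on the "~100 MB model" figure, the parameter memory can be counted directly (this only covers parameters, not activations, gradients, or optimizer state):

def param_megabytes(model):
    # float32 parameters are 4 bytes each; element_size() keeps this honest
    # if the dtype ever changes.
    return sum(p.numel() * p.element_size() for p in model.parameters()) / 10 ** 6

print(f'Parameter memory: {param_megabytes(model):.0f} MB')  # ~100 MB for resnext50_32x4d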
Finally, I’m getting the out of memory error on the first pass through the training loop:
epochs = 10
tlen = len(loader)
for epoch in range(epochs):
    tloss = 0
    acc = np.zeros(1)
    for x, y in loader:
        print(f'Memory allocated after tensor load to cpu: {torch.cuda.memory_allocated() / 10 ** 6} MB')  # Gives ~100 MB
        x = x.to(device)
        print(f'Memory allocated after tensor load to gpu: {torch.cuda.memory_allocated() / 10 ** 6} MB')  # Gives ~500 MB
        optimizer.zero_grad()
        # Everything explodes when we call the model on the input.
        output = model(x)
        # More training code that is never reached....
With a batch size of 64 (~400 MB of input data), the loop causes an OOM. With a batch size of 16 (~100 MB), memory usage never goes above ~500 MB.
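If it helps, this is the kind of instrumentation I can wrap around the forward pass to capture the peak rather than the steady-state number (a rough sketch; the exact torch.cuda function names vary a bit across PyTorch versions):

torch.cuda.reset_peak_memory_stats()  # reset_max_memory_allocated() on older versions
output = model(x)
print(f'Current: {torch.cuda.memory_allocated() / 10 ** 6} MB, '
      f'peak: {torch.cuda.max_memory_allocated() / 10 ** 6} MB')
# torch.cuda.memory_summary() gives a full allocator breakdown on newer versions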
I can see a few possibilities here, but I’m unsure what’s most likely:
- I’m doing something wrong when I load the data. Maybe the input tensors need requires_grad=False explicitly set on them or something (there’s a quick check for this sketched below the list).
- There’s some kind of memory allocation bug in PyTorch that I’m seeing here.
- There’s something about Kaggle’s GPU setup that causes the memory error.
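For the first bullet, the check I have in mind is comparing against a forward pass with gradients disabled; if memory stays flat under no_grad(), the spike is coming from the activations autograd saves for the backward pass rather than from how I load the data (rough sketch):

# Tensors coming out of a DataLoader already default to requires_grad=False,
# so setting it explicitly shouldn't change anything; this just isolates autograd.
with torch.no_grad():
    output = model(x)
print(f'Memory after no_grad forward: {torch.cuda.memory_allocated() / 10 ** 6} MB')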
Any ideas? The full code is at the link above, and you can easily clone it if you have a Kaggle account.