Kernel dies on loss.backward()

Really cool to see MPS support in PyTorch! I have been testing it and it all works great until it comes to fine-tuning a BERT model from HF. I have a simple training loop that looks like:

import torch
from transformers import BertForSequenceClassification
from tqdm import tqdm

device = torch.device('mps')

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.train()
model.to(device)

optim = torch.optim.Adam(model.parameters(), lr=5e-5)

loop = tqdm(loader, leave=True)
for batch in loop:
    # move the whole batch to the MPS device
    batch_mps = {
        'input_ids': batch['input_ids'].to(device),
        'attention_mask': batch['attention_mask'].to(device),
        'labels': batch['labels'].to(device)
    }
    optim.zero_grad()
    outputs = model(**batch_mps)
    loss = outputs[0]
    loss.backward()
    optim.step()

As soon as it hits loss.backward(), the kernel dies. I tried reducing the batch_size (although I figured that with unified memory it could handle the same batch size as on CPU?), but no luck.

It works as expected on CPU. Am I missing something?

Can you please check your GPU and CPU memory usage during execution?
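Something like this inside the training loop already tells you a lot. A rough sketch, assuming psutil is installed (log_memory is just a helper name I made up here, and torch.mps.current_allocated_memory() only exists on newer builds, so it is guarded):

import psutil
import torch

def log_memory(tag=""):
    # Host-side resident set size of this process
    rss_mib = psutil.Process().memory_info().rss / 2**20
    line = f"[{tag}] process RSS: {rss_mib:.0f} MiB"
    # torch.mps and its memory introspection are not available on every build
    mps_mod = getattr(torch, "mps", None)
    if mps_mod is not None and hasattr(mps_mod, "current_allocated_memory"):
        line += f", MPS allocated: {mps_mod.current_allocated_memory() / 2**20:.0f} MiB"
    print(line)

Calling log_memory("before backward") and log_memory("after backward") around loss.backward() should show whether you are running out of unified memory.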

Thanks! The problem was simply that the batches, despite being incredibly small, were still too large for my first-gen M1. It manages with a batch_size of 1. Thanks for your help!
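For anyone who hits the same wall: gradient accumulation keeps the effective batch size up while only ever holding one sample's activations in memory. A rough sketch of what I ended up with, assuming the model, optimizer, device and tqdm from my first post, with train_dataset as a placeholder:

from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=1, shuffle=True)  # batch_size=1 is what my M1 manages
accum_steps = 16  # effective batch size of 16

optim.zero_grad()
for step, batch in enumerate(tqdm(loader, leave=True)):
    batch_mps = {k: v.to(device) for k, v in batch.items()}
    loss = model(**batch_mps)[0] / accum_steps  # scale so accumulated gradients average out
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optim.step()
        optim.zero_grad()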

@albanD could you modify the backend so that it fails gracefully when an MTLBuffer is initialized to nil because the machine ran out of memory?

Hi,

There is already a check there for the MTLBuffer: pytorch/MPSAllocator.mm at 4428218945e797cfc71a93dbb2d165535ea5a85b · pytorch/pytorch · GitHub

But the problem with shared memory is that any CPU allocation might start failing as well. Those we don't always control, and they can lead to a hard crash.
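When that check does fire, it surfaces in Python as a RuntimeError, so on the user side you can at least catch it and back off instead of crashing. A rough sketch reusing the model, optim and batch_mps names from the snippet above (the retry policy is only illustrative, and a CPU-side allocation failure in shared memory can still hard-crash before this runs):

def train_step(model, optim, batch_mps):
    optim.zero_grad()
    loss = model(**batch_mps)[0]
    loss.backward()
    optim.step()
    return loss

try:
    train_step(model, optim, batch_mps)
except RuntimeError as err:
    # Likely an out-of-memory report from the MPS allocator check linked above
    print(f"step failed ({err}); retry with a smaller batch size")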

While tracing the code you cited, I saw something interesting. It's not really relevant to this thread; I mentioned it in MPS device appears much slower than CPU on M1 Mac Pro · Issue #77799 · pytorch/pytorch · GitHub, but let's keep the discussion on this forum thread for now.

It seems that you do indeed use heap-backed memory, something I had thought of myself to allow for zero-cost allocation: pytorch/MPSAllocator.h at 09be44de7b56495bcb5ad1d47376200cbb853097 · pytorch/pytorch · GitHub. Could you go into a little more detail about how you came up with that allocation heuristic? Is there a parallel in the CUDA backend? Did you try an exponentially increasing scheme such as the one described in Sharing ideas about our work. · Issue #1 · AnarchoSystems/DeepSwift · GitHub? It seems that you try to make a heap the same size as the buffer it backs. That would make sense if you repeatedly recreate tensors of the same size, so heaps from previous allocations are still around to be reused. But how do you know when you have enough pre-existing heaps, and when to start deleting previously allocated ones before you run out of memory?
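To make the question concrete, this is the kind of exponentially increasing heuristic I had in mind. A toy sketch in Python, not what MPSAllocator actually does: requests are rounded up to power-of-two buckets and freed blocks are cached per bucket for reuse.

from collections import defaultdict

MIN_BLOCK = 256  # bytes

def bucket_size(nbytes):
    # Exponential rounding keeps the number of distinct bucket sizes small
    size = MIN_BLOCK
    while size < nbytes:
        size *= 2
    return size

class ToyCachingAllocator:
    def __init__(self):
        self.free_blocks = defaultdict(list)  # bucket size -> cached blocks

    def alloc(self, nbytes):
        size = bucket_size(nbytes)
        if self.free_blocks[size]:
            return self.free_blocks[size].pop()  # reuse a cached block: the zero-cost path
        return bytearray(size)                   # stand-in for a fresh heap/buffer

    def free(self, block):
        self.free_blocks[len(block)].append(block)  # keep it around for later reuse

    def release_cached(self):
        # What I would do under memory pressure: drop the cached blocks entirely
        self.free_blocks.clear()

The open question for me is when release_cached should run, and whether to shrink the cache gradually instead of all at once.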

I haven't tested my theory yet. If I could reuse your algorithm and the time you spent investigating this performance problem, it would be a big help for my personal ML project. I would add a documentation comment crediting PyTorch for coming up with the idea first.

Hi,

The allocator that we use for MPS is based on the CUDA caching allocator that we already had.
It is not as fully featured yet, but you can find some high-level documentation at CUDA semantics — PyTorch master documentation and more detailed comments in the code itself: pytorch/CUDACachingAllocator.cpp at e2eb7a1edccfe9788edc50b6e8cddd62c2afff7a · pytorch/pytorch · GitHub

Note that this is a little bit copy-pasted for the MPS side right now, and it will be refactored so that the two are closer once the MPS version is stable.
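If you want to see the caching behaviour from Python, the CUDA side exposes it through the torch.cuda memory APIs; the MPS side does not expose the same set yet, so treat this as a CUDA-only illustration:

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device='cuda')
    print(torch.cuda.memory_allocated())  # bytes currently backing live tensors
    print(torch.cuda.memory_reserved())   # bytes the caching allocator holds from the driver
    del x
    # The freed block stays cached rather than going back to the driver
    print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
    torch.cuda.empty_cache()              # hand cached blocks back to the driver
    print(torch.cuda.memory_reserved())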
