How to handle a memory expensive step in the model to support higher batch size?

I have a step in my model which is very expensive as it does operation over the entire vocabulary. Let’s say the step is as follows

output = self.vocab_process(input)

input has shape bsz x seq_len x d
output has shape bsz x seq_len x k

I found that a batch size of 128 works well for this step, but anything larger OOMs. I am trying to make model evaluation faster, so I am wondering if I can set the eval batch size to 1024 and, in each iteration, call vocab_process 8 times on slices of the batch instead of once.

Something like

outputs = []
for i in range(8):
    # process one 128-sample slice of the 1024-sample batch
    output = self.vocab_process(input[i * 128:(i + 1) * 128])
    outputs.append(output)

# concatenate along the batch dimension -> bsz x seq_len x k
model_out = torch.cat(outputs, dim=0)

One more thing I'm considering: inside the implementation of vocab_process, calling del tensor_name to free memory once I no longer need a tensor. How does calling del compare to making everything in-place?

Are these two techniques reasonable? Are there better/cleaner ways to do this?

Thank you.

If the intermediate tensors inside vocab_process are what exhaust device memory and cause the OOM, your loop approach could work (assuming that vocab_process returns the right output when called on sub-tensors).
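For reference, a minimal sketch of that chunked loop using torch.split; the chunked_forward wrapper and the chunk_size value are just illustrative, not part of your model:

import torch

def chunked_forward(model, inputs, chunk_size=128):
    # Hypothetical wrapper: run the memory-heavy vocab_process on slices of
    # the batch so only chunk_size samples are materialized inside it at a
    # time, then concatenate the per-chunk outputs along the batch dimension.
    outputs = [model.vocab_process(chunk) for chunk in inputs.split(chunk_size, dim=0)]
    return torch.cat(outputs, dim=0)  # bsz x seq_len x k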

Since Python uses function scoping, you wouldn't have to delete the intermediate tensors explicitly, as they will be freed once vocab_process returns. If you are running into the OOM before the method exits, del statements might still be useful.
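As an illustration (the names proj, vocab_weight, and k below are made up, not from your code), del drops the local reference so the memory can be reused before the next large allocation, provided nothing else, e.g. the autograd graph, still holds the tensor:

import torch

def vocab_process_sketch(x, proj, vocab_weight, k=10):
    # Sketch only, assuming a large vocab-sized intermediate is built here.
    hidden = x @ proj                    # bsz x seq_len x d
    logits = hidden @ vocab_weight.t()   # bsz x seq_len x |V|, the large intermediate
    del hidden                           # drop the reference before the next allocation
    scores = logits.topk(k, dim=-1).values
    del logits                           # free the vocab-sized tensor early
    return scores                        # bsz x seq_len x k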

In-place operations won’t allocate new memory, but will instead manipulate the tensor directly.
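For example (fine in eval code; during training, in-place ops can fail in backward if autograd needs the original values):

import torch

x = torch.randn(128, 128, 1024)

y = x * 2      # out-of-place: allocates a second tensor of the same size
x.mul_(2)      # in-place: modifies x's storage directly, no new allocation

z = torch.relu(x)  # allocates a new tensor
x.relu_()          # in-place variant, reuses x's storage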

Thanks for the reply! Are you aware of any other places where a similar loop approach is used?

F.bilinear used a matrix multiplication, which created large intermediate tensors and introduced a performance regression (compared with the previous pure Python implementation). @tom fixed this with a loop approach some time ago in this PR.
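This is not the code from that PR, just a sketch of the general idea: computing the bilinear form one output feature at a time keeps the intermediates small instead of building one tensor that spans all output features at once.

import torch
import torch.nn.functional as F

def bilinear_loop(x1, x2, weight, bias=None):
    # Illustrative sketch: out[n, o] = bias[o] + sum_ij x1[n, i] * weight[o, i, j] * x2[n, j],
    # computed one output feature at a time so each intermediate is only (N, in2).
    out = x1.new_empty(x1.size(0), weight.size(0))
    for o in range(weight.size(0)):
        out[:, o] = ((x1 @ weight[o]) * x2).sum(dim=1)
    if bias is not None:
        out = out + bias
    return out

# quick check against the built-in
x1, x2 = torch.randn(4, 5), torch.randn(4, 7)
w, b = torch.randn(3, 5, 7), torch.randn(3)
print(torch.allclose(bilinear_loop(x1, x2, w, b), F.bilinear(x1, x2, w, b), atol=1e-5))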
