Very slow inference instead of OoM

I have an ESRGAN model that takes a lot of VRAM depending on how large the input image is. I noticed that VRAM usage scales roughly linearly with the number of pixels in the image until my VRAM is maxed out. However, to my surprise, I found that PyTorch is still able to upscale larger images that (judging from the number of pixels) it should not have enough VRAM for. While the inference succeeds, it also takes longer than expected, again judging from the number of pixels in the input image.
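The scaling can be measured with a sketch like the following (this assumes a CUDA device and a 3-channel image model; `peak_vram_for` is just an illustrative helper name, not part of any API):

```python
import torch
import torch.nn as nn

def peak_vram_for(model: nn.Module, h: int, w: int, device: str = "cuda") -> int:
    """Run one forward pass on an h x w RGB input and return peak VRAM in bytes."""
    torch.cuda.reset_peak_memory_stats(device)
    x = torch.randn(1, 3, h, w, device=device)
    with torch.no_grad():
        model(x)
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device)

# Compare bytes per pixel across sizes; a roughly constant ratio
# indicates the linear scaling described above:
# for side in (256, 512, 768):
#     print(side, peak_vram_for(model, side, side) / side**2)
```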

So my question is: Does PyTorch implement clever tricks to make inference succeed with low VRAM, e.g. by using algorithms that make different time/space trade-offs?

If so, is there a way for me to turn them off? In my use case, I actually want PyTorch to raise an OOM error if the fastest inference isn’t possible, instead of slowly making its way through.

Once PyTorch runs into an OOM, it will clear the memory cache and retry the allocation. If the second allocation also fails, the OOM error will be re-raised to the user. Based on your description, it sounds as if you might indeed be hitting this clear-and-retry path in each iteration.
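You can check whether this clear-and-retry path is actually being hit by reading the caching allocator's `num_alloc_retries` counter from `torch.cuda.memory_stats()`. A small sketch (returns 0 on CPU-only setups):

```python
import torch

def alloc_retries() -> int:
    """Number of times the caching allocator freed its cache and retried
    an allocation after a failed cudaMalloc (0 if CUDA is unavailable)."""
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.memory_stats().get("num_alloc_retries", 0)

# before = alloc_retries()
# output = model(image)          # your inference call
# if alloc_retries() > before:
#     print("the allocator cleared its cache and retried during inference")
```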

Is there a way to force it to raise an OOM error when the first allocation fails? Or maybe even a way to determine beforehand how much memory the inference will need?

No, I’m not aware of a way to change this behavior in user code without modifying and rebuilding the backend. You could try to delete unneeded intermediates manually in your code and also try to avoid creating memory fragmentation (e.g. by starting with the largest inputs and decreasing them if you are working with variable input sizes). However, make sure the cache clearing is indeed what slows down the code by profiling it, e.g. via Nsight Systems, which should show the cudaFree calls.
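As a workaround for the fail-fast requirement, one option is a pre-flight check: estimate the memory need from the pixel count (using the roughly linear scaling you observed) and raise before running inference if the estimate exceeds the currently free VRAM reported by `torch.cuda.mem_get_info()`. The `BYTES_PER_PIXEL` constant below is a made-up placeholder you would calibrate on your own model and GPU:

```python
import torch

BYTES_PER_PIXEL = 4000  # hypothetical value; calibrate on your model/GPU

def estimated_bytes(h: int, w: int, bytes_per_pixel: int = BYTES_PER_PIXEL) -> int:
    """Linear VRAM estimate based on the input pixel count."""
    return h * w * bytes_per_pixel

def preflight_check(h: int, w: int) -> None:
    """Raise immediately if the estimate exceeds the free VRAM (CUDA only)."""
    free, _total = torch.cuda.mem_get_info()
    needed = estimated_bytes(h, w)
    if needed > free:
        raise torch.cuda.OutOfMemoryError(
            f"estimated {needed} B needed, only {free} B free"
        )

# preflight_check(h, w)   # call before model(image); raises instead of running slowly
```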