Is there a way to precompute the size of the computation graph a model will need?

Hi,

I am training a UNet3D model on MRI data, and both the images and the model itself are quite large. To fit into the 24 GB of VRAM available to me, I need to decrease the input image size significantly. I would like to know how much VRAM the original image size would require, but I obviously can't just try to run it, as I would very quickly hit an OOM error.
So, as a solution, I was wondering: is there a way to compute how much VRAM the model's computation graph and the loaded data will need, taking the batch size into account?

This util might be helpful: it uses TorchDispatchMode to measure the memory usage without running the actual model.
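To illustrate the general idea (not the linked utility itself), here is a minimal sketch of a custom `TorchDispatchMode` that intercepts every dispatched op and sums the bytes of the tensors it produces. `MemoryEstimateMode` is a hypothetical name I made up for this example, and the count is only a rough upper bound: it ignores views, in-place ops, and allocator reuse.

```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only

class MemoryEstimateMode(TorchDispatchMode):
    """Hypothetical sketch: accumulate the size of every tensor any op
    returns while the mode is active. Over-counts (views and allocator
    reuse are ignored), so treat the result as an upper bound."""
    def __init__(self):
        super().__init__()
        self.total_bytes = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))

        def account(t: torch.Tensor):
            self.total_bytes += t.numel() * t.element_size()
            return t

        # Walk whatever structure the op returned and count its tensors.
        tree_map_only(torch.Tensor, account, out)
        return out

# Usage with a tiny stand-in model on CPU; no GPU needed for the estimate.
model = nn.Sequential(
    nn.Conv3d(1, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv3d(8, 1, 3, padding=1),
)
x = torch.randn(1, 1, 16, 16, 16)

with MemoryEstimateMode() as mode:
    y = model(x)

print(f"rough activation upper bound: {mode.total_bytes / 1e6:.2f} MB")
```

For very large inputs you may be able to combine the same idea with tensors on the `"meta"` device (or `FakeTensorMode`) so that no real data is ever allocated, though the exact behavior depends on your PyTorch version.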