Ok, I found a solution that works for me:
On startup I query the total and used GPU memory, create a tensor large enough to bring usage up to 80% of the total, and push it to the GPU. Directly afterwards I overwrite that variable with a small tensor.
While the process is running, that 80% of GPU memory stays blocked: PyTorch's caching allocator keeps the freed memory reserved for the process and reuses it for later allocations.
Code:
import os
import torch

def check_mem():
    # Ask nvidia-smi for total and used GPU memory (reported in MiB)
    mem = os.popen(r'"<path\to\NVSMI>\nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader').read().split(",")
    return mem

def main():
    total, used = check_mem()
    total = int(total)
    used = int(used)
    max_mem = int(total * 0.8)
    block_mem = max_mem - used  # MiB still needed to reach 80% usage
    # 256 * 1024 float32 values are exactly 1 MiB, so this tensor occupies block_mem MiB
    x = torch.rand((256, 1024, block_mem)).cuda()
    # Overwrite with a tiny tensor; the caching allocator keeps the memory reserved
    x = torch.rand((2, 2)).cuda()
    # do things here

if __name__ == '__main__':
    main()
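
If you don't want to shell out to nvidia-smi, newer PyTorch builds expose torch.cuda.mem_get_info(), which returns free and total device memory in bytes. Below is a minimal sketch of the same trick using that call; the 80% fraction and the block_gpu_memory name are just assumptions for illustration, not part of the original solution:

import torch

def block_gpu_memory(fraction=0.8):
    # free and total are reported in bytes for the current device
    free, total = torch.cuda.mem_get_info()
    used = total - free
    block_bytes = int(total * fraction) - used
    if block_bytes > 0:
        # one float32 element is 4 bytes
        x = torch.empty(block_bytes // 4, dtype=torch.float32, device='cuda')
        del x  # freed, but the caching allocator keeps the memory reserved for this process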