Hello all, I was wondering if someone could explain why this snippet of code causes my GPU to run out of memory:
```python
outputs = []
slices = torch.rand(100, 3, 512, 512)
batch_size = 4
for batch_idx in range(0, 100, batch_size):
    slice_ = slices[batch_idx: batch_idx + batch_size].cuda()
    u_net_output = unet(slice_)
    outputs.append(u_net_output.data.cpu())
```
Since I've moved the output of `unet` to the CPU, doesn't that mean it no longer resides on the GPU? Why is it still causing OOM errors?
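For reference, here is a minimal self-contained sketch of the same loop, assuming it is inference only (no backward pass is needed). The small `nn.Sequential` model is a hypothetical stand-in for `unet`, and the 64x64 inputs are smaller than the original 512x512 so it runs quickly. The one substantive change is wrapping the loop in `torch.no_grad()`: without it, autograd records each forward pass and keeps the intermediate activations it would need for a backward pass, and for a U-Net at 512x512 those activations can be much larger than the output tensor you copy to the CPU.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the poster's `unet`: any nn.Module mapping
# (N, 3, H, W) -> (N, C, H, W) works for this sketch.
unet = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = unet.to(device)

# Smaller spatial size than the original 512x512 so the sketch runs quickly.
slices = torch.rand(10, 3, 64, 64)
batch_size = 4
outputs = []

# no_grad() stops autograd from recording the forward pass, so the
# intermediate activations are not kept for a backward pass.
with torch.no_grad():
    for batch_idx in range(0, slices.size(0), batch_size):
        slice_ = slices[batch_idx:batch_idx + batch_size].to(device)
        u_net_output = unet(slice_)
        # .cpu() copies the result to host memory; the GPU tensor's memory
        # can be reused once nothing references it any more.
        outputs.append(u_net_output.cpu())

result = torch.cat(outputs)
print(result.shape)
```

Note that the last batch here holds only 2 slices (10 is not a multiple of 4), which `torch.cat` handles without any special casing.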