I’ve tried running it, but it’s hitting the assertion on the line assert out.size(2) == prev_x.size(2). What’s the shape of the input you are using?
The input size is 572x572, exactly as in the paper: https://arxiv.org/pdf/1505.04597.pdf…
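A possible reason the assertion fires with a 572x572 input can be seen by tracing the spatial sizes by hand. The sketch below does not read the linked Unet.py. It assumes the layout from the paper: two 3x3 convolutions per level, 2x2 max pooling with floor division, and a 2x2 stride-2 up-convolution. The helper name unet_sizes and its padded parameter are hypothetical, introduced only for illustration.

```python
def unet_sizes(size, depth=4, padded=False):
    """Trace the side length of a square input through a U-Net (hypothetical helper).

    Assumes two 3x3 convs per level, 2x2 max pooling (floor division) in the
    encoder, and a 2x2 stride-2 up-convolution in the decoder.
    padded=False models the paper's unpadded ("valid") convs, which lose
    2 pixels each; padded=True models padding=1, which keeps the size.
    Returns (steps, final_size), where each step is
    (level, upsampled_size, skip_size) at a decoder concatenation.
    """
    shrink = 0 if padded else 4  # two unpadded 3x3 convs lose 4 pixels in total
    skips = []
    for _ in range(depth):
        size -= shrink           # convs at this encoder level
        skips.append(size)       # feature map kept for the skip connection
        size //= 2               # 2x2 max pool (odd sizes round down)
    size -= shrink               # bottleneck convs
    steps = []
    for level in reversed(range(depth)):
        size *= 2                # up-convolution doubles the size
        steps.append((level, size, skips[level]))
        size -= shrink           # convs after concatenating the (cropped) skip
    return steps, size


if __name__ == "__main__":
    for padded in (False, True):
        steps, out_size = unet_sizes(572, padded=padded)
        print("padded" if padded else "valid", "-> output", out_size)
        for level, up, skip in steps:
            note = "" if up == skip else f"  mismatch, skip needs cropping by {skip - up} px"
            print(f"  level {level}: upsampled {up} vs skip {skip}{note}")
```

Under these assumptions, with unpadded convolutions the first decoder step compares 56 against 64, so an equality check like out.size(2) == prev_x.size(2) fails unless the skip feature map is center-cropped first (the paper crops and ends with a 388x388 output). With padding=1 the first mismatch is 70 vs 71, because 572 is not divisible by 2**4 = 16. In that case, an input such as 576 keeps every pair equal.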
I pushed my entire code to GitHub: https://github.com/devansh20la/unet/blob/master/Unet.py. It’s been running all night on our lab’s devbox, but I only got to 100 epochs with a training set of 3000 images.
Thank you so much
At a batch size of 7, it’s taking over 5 GB for me:
5675MiB / 16273MiB
You wrote the network pretty efficiently; I don’t think any memory improvements are possible.
Try adding this flag for a speed improvement: