Hello folks.
Is anyone else out there using U-Net in PyTorch and having trouble with memory usage? I’ve tried a few approaches to reduce the memory I’m using and I’ve hit a rather odd problem.
The background: I’m trying to train a multi-class U-Net to predict 5 classes from a 3D image, in a very similar vein to this paper - https://arxiv.org/pdf/1606.06650.pdf . My images are 320x150x26 voxels in size. I’m using half precision and a batch size of 2. The whole thing is taking forever and I’d like to increase the batch size.
My first U-Net architecture looks like this:
import torch
import torch.nn as nn

# mp is my own module containing the U-Net building blocks (DoubleConv, Down, Up, OutConv)

class NetULarge(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetULarge, self).__init__()
        self.n_channels = 1
        self.n_classes = 5  # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True
        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 512, dtype=dtype)
        self.down4 = mp.Down(512, 512, dtype=dtype)
        self.up1 = mp.Up(1024, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(512, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(256, 64, self.bilinear, dtype=dtype)
        self.up4 = mp.Up(128, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Decoder: upsample and concatenate with the matching encoder feature map
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        return out
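For context, the mp.* blocks are my own 3D adaptations of the standard U-Net building blocks. Roughly, they look something like this (a simplified sketch - the real versions also take the dtype argument and cast their weights to half precision):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(Conv3d -> BatchNorm3d -> ReLU) twice."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Max-pool then DoubleConv."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool_conv = nn.Sequential(nn.MaxPool3d(2), DoubleConv(in_ch, out_ch))

    def forward(self, x):
        return self.pool_conv(x)

class Up(nn.Module):
    """Upsample, concatenate the skip connection, then DoubleConv.
    in_ch is the channel count *after* concatenation."""
    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        # 'trilinear' is the 3D counterpart of bilinear upsampling
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # pad x1 so the spatial sizes match x2 when they don't divide evenly
        d = [x2.size(i) - x1.size(i) for i in range(2, 5)]
        x1 = F.pad(x1, [d[2] // 2, d[2] - d[2] // 2,
                        d[1] // 2, d[1] - d[1] // 2,
                        d[0] // 2, d[0] - d[0] // 2])
        return self.conv(torch.cat([x2, x1], dim=1))

class OutConv(nn.Module):
    """1x1x1 convolution down to n_classes output channels."""
    def __init__(self, in_ch, n_classes):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, n_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)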
I decided to see how many parameters this has, using the following code:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
I get 40158853 - that’s quite a lot, and more than the network in the paper, so I thought I’d reduce the size of the model to the following:
class NetU(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetU, self).__init__()
        self.n_channels = 1
        self.n_classes = 5  # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True
        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 256, dtype=dtype)
        self.up1 = mp.Up(512, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(384, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(192, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        out = self.outc(x)
        return out
Now I get 14496517 parameters - a significant reduction.
However, this does not reduce memory usage by any appreciable amount. I started by using nvidia-smi to watch the memory usage on my GPU. Both nets peak at around 7000 MiB. Occasionally, both dip to around 1500 MiB or so, but there is little difference between them.
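A quick back-of-envelope (just my own arithmetic, assuming fp16 at 2 bytes per element) suggests the parameters themselves are only a small slice of that 7000 MiB, and that the full-resolution activations - which both nets share - are where most of the memory goes:

bytes_per_el = 2  # fp16

# Parameter memory for the two nets
large_params_mib = 40158853 * bytes_per_el / 2**20   # ~77 MiB
small_params_mib = 14496517 * bytes_per_el / 2**20   # ~28 MiB

# One 64-channel feature map at full resolution (320x150x26), batch size 2
act_mib = 2 * 64 * 320 * 150 * 26 * bytes_per_el / 2**20   # ~305 MiB

print(f"large net parameters: {large_params_mib:.0f} MiB")
print(f"small net parameters: {small_params_mib:.0f} MiB")
print(f"single full-res activation (batch of 2): {act_mib:.0f} MiB")

So dropping ~26M parameters only frees around 50 MiB of weights, while each full-resolution feature map that has to be kept around for the backward pass costs hundreds of MiB, and those layers are identical in both models.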
I decided to look a bit deeper and found some code online that uses forward and backward hooks to log how much memory each part of the model is using. I won’t paste the full traces here as they are quite long, but I did notice something interesting.
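The hook code is roughly along these lines (a simplified sketch of the idea, not the exact code I found):

import torch

mem_log = []

def _log(hook_type, idx, module):
    # Record live tensor memory and the caching allocator's reserved pool
    mem_log.append({
        "layer_idx": idx,
        "layer_type": type(module).__name__,
        "hook_type": hook_type,
        "mem_all": torch.cuda.memory_allocated(),
        "mem_cached": torch.cuda.memory_reserved(),
    })

def add_memory_hooks(model):
    handles = []
    for idx, module in enumerate(model.modules()):
        handles.append(module.register_forward_pre_hook(
            lambda m, inp, i=idx: _log("pre", i, m)))
        handles.append(module.register_forward_hook(
            lambda m, inp, out, i=idx: _log("fwd", i, m)))
        handles.append(module.register_full_backward_hook(
            lambda m, gin, gout, i=idx: _log("bwd", i, m)))
    return handles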
In the smaller network:
layer_idx call_idx layer_type exp hook_type mem_all mem_cached
0 0 0 NetU 0 pre 87474176 106954752
1 1 1 DoubleConv 0 pre 87474176 106954752
2 2 2 Sequential 0 pre 87474176 106954752
3 3 3 Conv3d 0 pre 87474176 106954752
4 3 4 Conv3d 0 fwd 406962176 427819008
.....
670 91 670 BatchNorm3d 0 bwd 4954478592 5471469568
and in the larger network:
layer_idx call_idx layer_type exp hook_type mem_all mem_cached
0 0 0 NetU 0 pre 34027520 46137344
1 1 1 DoubleConv 0 pre 34027520 46137344
2 2 2 Sequential 0 pre 34027520 46137344
3 3 3 Conv3d 0 pre 34027520 46137344
4 3 4 Conv3d 0 fwd 353515520 367001600
....
966 62 966 Up 0 bwd 5632503808 5937037312
The larger network has a few more steps to go through, and while there is some reduction in overall memory used between the two, both peak at roughly the same level. In the smaller network, though, the cached memory stays high even when the total allocated memory is fairly low, e.g.:
1563 45 1563 DoubleConv 0 bwd 3151531008 6121586688
This cached size seems close to the 7000 MiB or so I was seeing with nvidia-smi.
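One thing I can at least check (assuming the gap between allocated and cached is just PyTorch’s caching allocator holding on to freed blocks, since nvidia-smi reports the whole reserved pool rather than just live tensors) is whether emptying the cache brings the number down:

import torch

print(torch.cuda.memory_allocated() / 2**20, "MiB allocated (live tensors)")
print(torch.cuda.memory_reserved() / 2**20, "MiB reserved (cached)")

torch.cuda.empty_cache()   # hand unused cached blocks back to the driver
print(torch.cuda.memory_reserved() / 2**20, "MiB reserved after empty_cache()")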
I’m not sure how best to attack this problem. I get that the backward pass has a lot of work to do, but why should the cache be so high and why am I not seeing a larger reduction, given the model is much smaller?
Cheers
B