U-Net model reduced in size, still takes up a lot of memory. Wondering what to do next?

Hello folks.

Is anyone else out there using U-Net in PyTorch and having trouble with memory usage? I’ve tried a few approaches to reduce the memory I’m using and I’ve hit a rather odd problem.

The background: I’m trying to train a multiclass U-Net to predict 5 classes from a 3D image, very much in the vein of this paper - https://arxiv.org/pdf/1606.06650.pdf . My images are 320x150x26 voxels in size. I’m using half precision and a batch size of 2. Training takes a very long time and I’d like to increase the batch size.

My first U-Net architecture looks like this:

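import torch
from torch import nn
# mp: a module of U-Net building blocks (DoubleConv, Down, Up, OutConv) that is not shown in this thread

class NetULarge(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetULarge, self).__init__()  # must name this class, not NetU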
        self.n_channels = 1
        self.n_classes = 5 # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True

        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 512, dtype=dtype)
        self.down4 = mp.Down(512, 512, dtype=dtype)
        self.up1 = mp.Up(1024, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(512, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(256, 64, self.bilinear, dtype=dtype)
        self.up4 = mp.Up(128, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        return out

I decided to see how many parameters this has using the following code:

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

I get 40,158,853. That’s quite a lot, and more than the paper in question, so I reduced the size of the model to the following:

class NetU(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetU, self).__init__()
        self.n_channels = 1
        self.n_classes = 5 # 5 probabilities for each of the neurons or background. TODO - might need to change this to 4?
        self.bilinear = True

        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 256, dtype=dtype)
        self.down3 = mp.Down(256, 256, dtype=dtype)
        self.up1 = mp.Up(512, 256, self.bilinear, dtype=dtype)
        self.up2 = mp.Up(384, 128, self.bilinear, dtype=dtype)
        self.up3 = mp.Up(192, 64, self.bilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
     
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        out = self.outc(x)
        return out

Now I get 14,496,517 parameters: a significant reduction.
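To see where parameters actually sit, the count can be broken down per direct submodule. A minimal sketch: the two-layer `toy` model below is a hypothetical stand-in, not the poster's `mp` blocks. A deep, wide convolution holds thousands of times more weights than a full-resolution one, yet it runs on far fewer voxels, which is why parameter count and activation memory can diverge.

```python
from torch import nn


def params_per_child(model: nn.Module) -> dict:
    """Trainable parameter count for each direct child (e.g. inc, down1, ..., outc)."""
    return {
        name: sum(p.numel() for p in child.parameters() if p.requires_grad)
        for name, child in model.named_children()
    }


# Hypothetical stand-in: one full-resolution conv and one deep, wide conv.
# It is only used for counting (its forward would fail: the channels don't chain).
toy = nn.Sequential()
toy.add_module("inc", nn.Conv3d(1, 64, kernel_size=3, padding=1))
toy.add_module("deep", nn.Conv3d(256, 512, kernel_size=3, padding=1))

for name, count in params_per_child(toy).items():
    print(f"{name}: {count:,}")
# inc: 1,792        (1*64*27 weights + 64 biases)
# deep: 3,539,456   (256*512*27 weights + 512 biases)
```

Passing `NetU()` or `NetULarge()` instead of `toy` would show which of `inc`, `down1`, ..., `up4` hold most of the weights.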

However, this does not reduce memory use by any meaningful amount. I started by using nvidia-smi to watch the memory usage on my GPU. Both nets peak at around 7000MiB. Occasionally both dip to around 1500MiB, but there is little difference between them.
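A back-of-the-envelope check is consistent with this: the weights themselves are small next to a ~7000MiB peak. This sketch assumes 2 bytes per fp16 parameter and ignores optimizer state, which depends on the optimizer used.

```python
MIB = 2**20


def param_mib(n_params: int, bytes_per_param: int = 2) -> float:
    """Memory for a set of weights in MiB (fp16 = 2 bytes per value)."""
    return n_params * bytes_per_param / MIB


large = param_mib(40_158_853)  # NetULarge
small = param_mib(14_496_517)  # NetU
saved = large - small          # weights only; a gradient per weight roughly doubles it
print(f"weights: {large:.1f} MiB -> {small:.1f} MiB, saved {saved:.1f} MiB")
print(f"with gradients: saved about {2 * saved:.0f} MiB of a ~7000 MiB peak")
```

Even with gradients included, the 25-million-parameter reduction accounts for under 100MiB, so whatever fills the rest of the GPU is not the weights.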

I decided to dig a bit deeper and found some code online that uses forward and backward hooks to log how much memory each part of the model is using. I won’t list the full output here as the traces are quite long, but I did notice something interesting.
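The hook approach described above can be sketched as follows. The code the poster found isn't shown, so this is an independent minimal version: it records `torch.cuda.memory_allocated()` (live tensors) and `torch.cuda.memory_reserved()` (the caching allocator's pool, the newer name for `memory_cached`) before and after each leaf module's forward pass and during its backward pass. The function name and the tiny test model are hypothetical. PyTorch's documentation notes that modifying a hooked module's inputs or outputs in place is not allowed with full backward hooks, so layers such as `nn.ReLU(inplace=True)` may raise an error here.

```python
import torch
from torch import nn


def attach_memory_hooks(model: nn.Module, log: list) -> list:
    """Append (module name, hook type, allocated bytes, reserved bytes) to `log`
    before/after each leaf module's forward pass and during its backward pass."""

    def snapshot(name: str, kind: str) -> None:
        if torch.cuda.is_available():
            alloc = torch.cuda.memory_allocated()    # bytes held by live tensors
            reserved = torch.cuda.memory_reserved()  # bytes held by the caching allocator
        else:
            alloc = reserved = 0  # no GPU: hooks still fire, nothing to measure
        log.append((name, kind, alloc, reserved))

    handles = []
    for name, module in model.named_modules():
        if list(module.children()):
            continue  # hook leaf modules only (Conv3d, BatchNorm3d, ReLU, ...)
        handles.append(module.register_forward_pre_hook(
            lambda m, args, name=name: snapshot(name, "pre")))
        handles.append(module.register_forward_hook(
            lambda m, args, out, name=name: snapshot(name, "fwd")))
        handles.append(module.register_full_backward_hook(
            lambda m, grad_in, grad_out, name=name: snapshot(name, "bwd")))
    return handles  # call .remove() on each handle to detach the hooks


# Usage on a tiny hypothetical model (no in-place ops):
model = nn.Sequential(nn.Conv3d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv3d(4, 2, 1))
log = []
handles = attach_memory_hooks(model, log)
x = torch.randn(1, 1, 8, 8, 8, requires_grad=True)
model(x).sum().backward()
for h in handles:
    h.remove()
```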

In the smaller network:

      layer_idx  call_idx   layer_type  exp hook_type     mem_all  mem_cached
0             0         0         NetU    0       pre    87474176   106954752
1             1         1   DoubleConv    0       pre    87474176   106954752
2             2         2   Sequential    0       pre    87474176   106954752
3             3         3       Conv3d    0       pre    87474176   106954752
4             3         4       Conv3d    0       fwd   406962176   427819008
.....
670          91       670  BatchNorm3d    0       bwd  4954478592  5471469568

And in the larger network:

          layer_idx  call_idx   layer_type  exp hook_type     mem_all  mem_cached
0             0         0         NetU    0       pre    34027520    46137344
1             1         1   DoubleConv    0       pre    34027520    46137344
2             2         2   Sequential    0       pre    34027520    46137344
3             3         3       Conv3d    0       pre    34027520    46137344
4             3         4       Conv3d    0       fwd   353515520   367001600
....
966          62       966           Up    0       bwd  5632503808  5937037312

The larger network has a few more steps to go through, and although overall memory use is lower in one network than the other, both peak at about the same level. In the smaller network, though, the cached memory stays high even when the allocated memory is fairly low, e.g.:

1563         45      1563   DoubleConv    0       bwd  3151531008  6121586688

This cached size (6,121,586,688 bytes, exactly 5838MiB) seems close to the 7000MiB or so I was seeing with nvidia-smi.
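The gap between `mem_all` and `mem_cached` is PyTorch's caching allocator: freed blocks are kept in the reserved pool for reuse instead of being returned to the driver, and nvidia-smi counts them (plus the CUDA context) as used. A minimal sketch for reading both numbers from inside the script; the helper name is hypothetical, while the `torch.cuda` calls are standard.

```python
import torch

MIB = 2**20


def cuda_memory_mib(device=None) -> dict:
    """Allocated (live tensors) vs reserved (allocator cache) memory, in MiB."""
    if not torch.cuda.is_available():
        return {"allocated": 0.0, "reserved": 0.0, "cached_but_free": 0.0}
    allocated = torch.cuda.memory_allocated(device) / MIB
    reserved = torch.cuda.memory_reserved(device) / MIB
    return {
        "allocated": allocated,
        "reserved": reserved,
        "cached_but_free": reserved - allocated,  # reusable by PyTorch, still shown by nvidia-smi
    }


print(cuda_memory_mib())
# torch.cuda.empty_cache() hands cached-but-free blocks back to the driver, so the
# nvidia-smi figure drops. PyTorch could already reuse those blocks, so it does not
# free any tensor that is still referenced.
```

A high reserved figure after the backward pass therefore mostly reflects the peak that was reached earlier, not memory that is currently in use.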

I’m not sure how best to attack this problem. I get that the backward pass has a lot of work to do, but why should the cache be so high and why am I not seeing a larger reduction, given the model is much smaller?
Cheers
B
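In the traces above, the jump at the first `Conv3d` forward hook is identical in both networks: 406,962,176 - 87,474,176 = 353,515,520 - 34,027,520 = 319,488,000 bytes. That is exactly one fp16 feature map of batch 2 × 64 channels × 320 × 150 × 26 voxels at 2 bytes per element. A quick check:

```python
import math


def feature_map_bytes(batch, channels, spatial, bytes_per_elem=2):
    """Bytes in one activation tensor of shape (batch, channels, *spatial); fp16 = 2 bytes."""
    return batch * channels * math.prod(spatial) * bytes_per_elem


first_conv_out = feature_map_bytes(batch=2, channels=64, spatial=(320, 150, 26))
print(first_conv_out)              # 319488000
print(406_962_176 - 87_474_176)    # 319488000 (first trace)
print(353_515_520 - 34_027_520)    # 319488000 (second trace)
print(first_conv_out / 2**20)      # 304.6875 (MiB)
```

Assuming `mp.DoubleConv` is the usual Conv3d → BatchNorm3d → ReLU pair applied twice (the traces do show Conv3d and BatchNorm3d layers), that one block produces several tensors of this size, and autograd keeps many of them until the backward pass. The parameter count does not enter this calculation at all.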

A couple of hints:
Even though you use fp16, gradients are still computed in fp32.
fp16 is meant to save MAC (multiply-accumulate) operations, but you won’t gain much memory. In addition, not all GPUs are optimized for fp16: it helps on the 30XX generation and Quadros, but not on the 10XX, and I don’t remember about the 20XX. Check yours.
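For comparison, PyTorch's automatic mixed precision keeps parameters and their gradients in fp32 and runs eligible ops such as convolutions in fp16 inside `torch.autocast`, with `GradScaler` scaling the loss so that small fp16 gradients don't underflow. This is a different setup from building every layer with `dtype=torch.float16` as in the code above, and it is not guaranteed to use less memory (fp32 weights and gradients cost more). Treat it as a sketch to measure; the tiny model and random data are hypothetical placeholders.

```python
import torch
from torch import nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16  # autocast is disabled on CPU below anyway

# Hypothetical stand-in for the U-Net (1 input channel, 5 classes); parameters stay fp32.
model = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv3d(8, 5, kernel_size=1),
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)  # pass-through when disabled

x = torch.randn(2, 1, 16, 16, 8, device=device)         # (batch, channel, D, H, W)
y = torch.randint(0, 5, (2, 16, 16, 8), device=device)  # class index per voxel

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_cuda):
    logits = model(x)              # convolutions run in fp16 on the GPU
    loss = criterion(logits, y)
scaler.scale(loss).backward()      # scaled loss keeps small fp16 gradients from underflowing
scaler.step(optimizer)             # unscales gradients; skips the step if any are inf/NaN
scaler.update()                    # adjusts the scale factor for the next iteration
```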

Another key question is how you coded Down and Up.

  • Reducing the number of filters is critical to reducing memory.
  • Replace up-convolution (transposed convolution) with interpolation.
  • Replace max-pooling with a strided convolution.
  • Use in-place ReLU.

IMO there is not much more you can do.
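The hints above could look roughly like this. These are hypothetical `DoubleConv`/`Down`/`Up` blocks, not the poster's `mp` module (which isn't shown). They assume 3×3×3 convolutions followed by BatchNorm3d and ReLU, downsample with a stride-2 convolution instead of max-pooling, and upsample with trilinear interpolation instead of a transposed convolution. Their constructor signatures differ from `mp.Down`/`mp.Up` (no `dtype` or `trilinear` arguments).

```python
import torch
import torch.nn.functional as F
from torch import nn


class DoubleConv(nn.Module):
    """(Conv3d -> BatchNorm3d -> ReLU) x 2; the first conv may use a stride to downsample."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),  # in place: reuses the BatchNorm output buffer
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Down(nn.Module):
    """Downsample with a stride-2 convolution instead of MaxPool3d."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch, stride=2)

    def forward(self, x):
        return self.conv(x)


class Up(nn.Module):
    """Trilinear interpolation (no transposed conv), concatenate the skip, then DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        # Interpolate straight to the skip's spatial size so odd dims (e.g. 150, 26) line up.
        x = F.interpolate(x, size=skip.shape[2:], mode="trilinear", align_corners=False)
        return self.conv(torch.cat([skip, x], dim=1))
```

Interpolating to the skip connection's exact size avoids padding or cropping when a dimension stops dividing evenly by 2, and the interpolation itself has no weights.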

Hello there,
Thanks for the tips. Switching to float16 is essential, as the memory reduction is significant: enough of a difference to make it run at all. Interesting point about the gradients, though. I’m on a 2080Ti at the moment; I’ll check what fp16 support it has.

I can change the filter size. I believe I’m already using interpolation, but I’ll check again. Max-pooling is still used, but I can replace it with a strided convolution. I haven’t tried in-place ReLU yet; I’ll give that a shot.

I’m still very surprised that the reduction in parameters has had very little effect on the memory use. Very odd indeed.
B

So I added all these features and reduced the network even more:


class NetUSmall(nn.Module):
    def __init__(self, dtype=torch.float16):
        super(NetUSmall, self).__init__()
        self.n_channels = 1
        self.n_classes = 5
        self.trilinear = True
        self.inc = mp.DoubleConv(1, 64, dtype=dtype)
        self.down1 = mp.Down(64, 128, dtype=dtype)
        self.down2 = mp.Down(128, 128, dtype=dtype)
        self.up1 = mp.Up(256, 128, trilinear=self.trilinear, dtype=dtype)
        self.up2 = mp.Up(192, 64, trilinear=self.trilinear, dtype=dtype)
        self.outc = mp.OutConv(64, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x = self.up1(x3, x2)
        x = self.up2(x, x1)
        out = self.outc(x)
        return out

This only saved 100MiB or so, which seems very odd to me. I think the cache is too large: memory usage peaks during the backward pass, but surely there should be more of a reduction?
B
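A rough per-level tally suggests why removing the deepest levels barely moves the peak. This assumes each `Down` halves every spatial dimension (an assumption: `mp.Down` isn't shown, but 2×2×2 pooling is what the linked 3D U-Net paper uses) while the channel count doubles, so each level holds roughly a quarter of the elements of the level above it. The channel counts follow NetULarge's encoder.

```python
def level_elements(shape=(320, 150, 26), channels=(64, 128, 256, 512, 512)):
    """Elements in one feature map per level, per sample.

    Assumes each level halves every spatial dimension (floor, like MaxPool3d(2)).
    """
    dims, counts = list(shape), []
    for ch in channels:
        counts.append(ch * dims[0] * dims[1] * dims[2])
        dims = [d // 2 for d in dims]
    return counts


counts = level_elements()
total = sum(counts)
for level, n in enumerate(counts):
    print(f"level {level}: {n:>11,} elements ({100 * n / total:.1f}%)")
```

Under these assumptions the full-resolution level holds about three quarters of the elements and levels 2 to 4 together hold under 6%. Removing or narrowing the deep levels, as in NetUSmall, leaves the largest tensors untouched; only the first-level filter count shrinks them.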

So I managed to increase the batch size to 4 by reducing the number of filters in the first layer:

class NetU(nn.Module):
    ''' U-Net code.'''

    def __init__(self, dtype=torch.float16):
        super(NetU, self).__init__()
        self.n_channels = 1
        self.n_classes = 5
        self.trilinear = True
        self.inc = mp.DoubleConv(1, 32, dtype=dtype)
        self.down1 = mp.Down(32, 64, dtype=dtype)
        self.down2 = mp.Down(64, 128, dtype=dtype)
        self.down3 = mp.Down(128, 256, dtype=dtype)
        self.up1 = mp.Up(384, 128, trilinear=self.trilinear, dtype=dtype)
        self.up2 = mp.Up(192, 64, trilinear=self.trilinear, dtype=dtype)
        self.up3 = mp.Up(96, 32, trilinear=self.trilinear, dtype=dtype)
        self.outc = mp.OutConv(32, self.n_classes, dtype=dtype)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        out = self.outc(x)
        return out
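Halving the filters in the first three levels roughly halves the per-sample activations, which is consistent with batch 4 now fitting. Assuming each level halves every spatial dimension (`mp.Down` is not shown), the encoder's feature maps at batch 4 come out almost exactly as large as those of the earlier 64-128-256-256 NetU at batch 2. This is a rough sketch counting one map per encoder level only.

```python
import math


def encoder_elements(batch, channels, shape=(320, 150, 26)):
    """One feature map per encoder level, summed; each level halves every spatial dim (floor)."""
    dims, total = list(shape), 0
    for ch in channels:
        total += batch * ch * math.prod(dims)
        dims = [d // 2 for d in dims]
    return total


before = encoder_elements(batch=2, channels=(64, 128, 256, 256))  # earlier NetU, batch 2
after = encoder_elements(batch=4, channels=(32, 64, 128, 256))    # 32-filter NetU, batch 4
print(before, after, round(after / before, 3))
```

The ratio comes out at about 1.005, so doubling the batch size roughly cancels the saving from halving the filters.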

Unfortunately, that’s still only a minor increase. Memory usage seems to hover around 20% of the total, which should leave a fair bit of headroom, but there appears to be a blip early in the training loop where memory spikes sharply. I’m not sure what is going on there; I’ll investigate a little more.
B
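One way to locate the early spike is to reset and read PyTorch's peak-allocation counter around each training step, which shows whether the spike belongs to a single iteration. A sketch; `train_step` and `batches` are hypothetical placeholders for the real training code and data loader.

```python
import torch


def log_peak_per_step(train_step, batches, max_steps=10):
    """Run train_step on each batch and return the peak allocated MiB for each step."""
    peaks = []
    for step, batch in enumerate(batches):
        if step >= max_steps:
            break
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        train_step(batch)
        if torch.cuda.is_available():
            peaks.append(torch.cuda.max_memory_allocated() / 2**20)
        else:
            peaks.append(0.0)  # no GPU: nothing to measure
    return peaks
```

Comparing the first few entries with later ones shows whether the spike comes from one specific early step or repeats.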