Memory Leak when Training Dynamic Network

My code runs into ‘CUDA out of memory’ after several iterations. I am trying to train a network with a random width (number of channels). Take MobileNet as an example: in each iteration, I forward four MobileNets of random width (e.g. 0.25x, 0.4x, 0.7x, 1.0x). The smaller networks share weights with the bigger ones (e.g. 0.25x uses the first half of the weights of 0.5x). I accumulate the gradients and call optimizer.step() once to update the weights.
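A minimal sketch of the weight-sharing scheme described above, assuming a narrower network simply uses the leading output channels of one shared weight tensor. The names shared_weight and forward_width are hypothetical and this is not the poster's USConv2d code. Because every width slices the same parameter, gradients from all widths accumulate into a single .grad tensor before the one optimizer.step().

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# One shared parameter sized for the widest network (8 output channels, 3 input channels).
shared_weight = torch.nn.Parameter(torch.randn(8, 3, 3, 3))


def forward_width(x, width_mult):
    """Run a conv that uses only the first `width_mult` fraction of output channels."""
    out_channels = max(1, int(8 * width_mult))
    # Slicing returns a view, so autograd routes gradients back into shared_weight.
    return F.conv2d(x, shared_weight[:out_channels], padding=1)


x = torch.randn(2, 3, 8, 8)
for w in (1.0, 0.5):
    forward_width(x, w).sum().backward()  # gradients accumulate across widths

# Channels 0-3 received gradient from both widths, channels 4-7 only from the 1.0x pass.
print(shared_weight.grad.shape)  # torch.Size([8, 3, 3, 3])
```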

The largest network (1.0x) occupies only 4GB of memory during normal training. However, with random-width training, memory usage fluctuates between 700MB and 11000MB, and after some iterations I get ‘CUDA out of memory’. My GPU has 11GB of memory in total.

I am guessing some tensors or graphs are not being freed, but I don’t know how to debug it. Could someone help?
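One way to start debugging is to log memory statistics at fixed points in every iteration: if allocated memory keeps growing from iteration to iteration, some tensors are probably being kept alive; if it stays flat while the peak or reserved memory jumps around, the cause is more likely temporary allocations inside the forward/backward pass. A sketch using the real torch.cuda.memory_allocated, torch.cuda.memory_reserved, torch.cuda.max_memory_allocated and torch.cuda.reset_peak_memory_stats calls; the helper names memory_report and reset_peak are hypothetical, and they return zeros when no GPU is present.

```python
import torch


def memory_report(tag, device=None):
    """Return current CUDA memory statistics in MB (all zeros without a GPU)."""
    if not torch.cuda.is_available():
        return {"tag": tag, "allocated_mb": 0.0, "reserved_mb": 0.0, "peak_mb": 0.0}
    mb = 1024 ** 2
    return {
        "tag": tag,
        # Memory occupied by live tensors.
        "allocated_mb": torch.cuda.memory_allocated(device) / mb,
        # Memory held by PyTorch's caching allocator (closer to what nvidia-smi shows).
        "reserved_mb": torch.cuda.memory_reserved(device) / mb,
        # Highest allocation since the last reset_peak_memory_stats() call.
        "peak_mb": torch.cuda.max_memory_allocated(device) / mb,
    }


def reset_peak(device=None):
    """Start a fresh peak-memory window, e.g. at the top of each iteration."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)


# Usage inside the training loop (sketch):
# reset_peak()
# ... forward/backward for all widths ...
# print(memory_report(f"iter {batch_idx}"))
```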

Are you storing some tensors in a container without detaching them?
Could you post a (small) reproducible code snippet, so that we can have a look?
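To illustrate what this question is about: a common instance (also discussed in the PyTorch FAQ) is accumulating a loss tensor across iterations, which chains every iteration's autograd graph together, whereas accumulating loss.item() keeps only a Python float. A small CPU sketch; the variable names are illustrative.

```python
import torch

model = torch.nn.Linear(4, 2)

leaky_total = 0.0   # becomes a tensor that carries autograd history
safe_total = 0.0    # stays a Python float

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    # Accumulating the tensor links every iteration's graph together,
    # so the tensors saved for backward cannot be freed.
    leaky_total = leaky_total + loss
    # .item() copies the value out; the graph can be released once loss goes out of scope.
    safe_total += loss.item()

print(leaky_total.requires_grad)  # True
print(type(safe_total).__name__)  # float
```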

Hi, my training code is as follows:

import random
import torch
import torch.nn.functional as F

# model, loader, optimizer, criterion and FLAGS are defined elsewhere in the project
model.train()
for batch_idx, (input, target) in enumerate(loader):
    target = target.cuda(non_blocking=True)
    optimizer.zero_grad()
    # first do max_width and max_resolution
    max_width = FLAGS.width_mult_range[1]
    model.apply(lambda m: setattr(m, 'width_mult', max_width))
    max_output = model(input)
    loss = torch.mean(criterion(max_output, target))
    loss.backward()
    # the detached max-width prediction is the soft target for the smaller widths
    max_output = max_output.detach()
    # do other widths and resolutions: the min width plus two randomly sampled widths
    min_width = FLAGS.width_mult_range[0]
    width_mult_list = [min_width]
    for i in range(2):
        sampled_width = random.uniform(FLAGS.width_mult_range[0], FLAGS.width_mult_range[1])
        width_mult_list.append(sampled_width)
    for width_mult in sorted(width_mult_list, reverse=True):
        model.apply(
            lambda m: setattr(m, 'width_mult', width_mult))
        resolution = FLAGS.resolution_list[random.randint(0, 3)]
        subinput = F.interpolate(input, (resolution, resolution), mode='bilinear', align_corners=True)
        output = model(subinput)
        loss = torch.nn.KLDivLoss(reduction='batchmean')(F.log_softmax(output, dim=1), F.softmax(max_output, dim=1))
        loss.backward()  # gradients accumulate across all widths
    # a single update with the gradients of all four widths
    optimizer.step()
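The per-iteration width schedule in this loop can be isolated as a pure function, which makes it easy to see that, apart from the two endpoints, the widths come from a continuous range and therefore differ in almost every iteration. A sketch; sample_widths is a hypothetical name, and (0.25, 1.0) stands in for FLAGS.width_mult_range.

```python
import random


def sample_widths(width_range, num_random=2, rng=random):
    """Return widths in the order the loop above trains them.

    The max width comes first (its output is the distillation target), then the
    min width plus `num_random` uniformly sampled widths, largest first.
    """
    min_width, max_width = width_range
    others = [min_width] + [rng.uniform(min_width, max_width) for _ in range(num_random)]
    return [max_width] + sorted(others, reverse=True)


rng = random.Random(0)
print(sample_widths((0.25, 1.0), rng=rng))
```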

What does the model.apply(lambda m:...) method do?
Are you reinitializing some parameters or switching some behavior internally?

Hi @ptrblck, the model.apply(...) call sets the ‘width_mult’ attribute for the different widths. It is related to my custom convolution, which I define as follows:

import torch.nn as nn
# FLAGS and make_divisible come from the slimmable networks codebase (not shown here)


class USConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, depthwise=False, bias=True,
                 us=[True, True], ratio=[1, 1]):
        # allocate the weight for the largest width; us[0] / us[1] mark whether
        # the input / output channels are slimmable
        in_channels_max = in_channels
        out_channels_max = out_channels
        if us[0]:
            in_channels_max = int(make_divisible(
                in_channels
                * FLAGS.width_mult
                / ratio[0]) * ratio[0])
        if us[1]:
            out_channels_max = int(make_divisible(
                out_channels
                * FLAGS.width_mult
                / ratio[1]) * ratio[1])
        groups = in_channels_max if depthwise else 1
        super(USConv2d, self).__init__(
            in_channels_max, out_channels_max,
            kernel_size, stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias)
        self.depthwise = depthwise
        self.in_channels_basic = in_channels
        self.out_channels_basic = out_channels
        self.width_mult = None  # set per forward pass via model.apply(...)
        self.us = us
        self.ratio = ratio

    def forward(self, input):
        # number of input / output channels active at the current width_mult
        in_channels = self.in_channels_basic
        out_channels = self.out_channels_basic
        if self.us[0]:
            in_channels = int(make_divisible(
                self.in_channels_basic
                * self.width_mult
                / self.ratio[0]) * self.ratio[0])
        if self.us[1]:
            out_channels = int(make_divisible(
                self.out_channels_basic
                * self.width_mult
                / self.ratio[1]) * self.ratio[1])
        self.groups = in_channels if self.depthwise else 1
        # use only the leading slice of the shared weight (and bias)
        bias = self.bias[:out_channels] if self.bias is not None else None
        y = nn.functional.conv2d(
            input, self.weight[:out_channels, :in_channels, :, :], bias,
            self.stride, self.padding, self.dilation, self.groups)
        if getattr(FLAGS, 'conv_averaged', False):
            # rescale by (max input channels / active input channels);
            # self.in_channels holds the maximum set in __init__
            y = y * (self.in_channels / in_channels)
        return y

The width_mult attribute determines how many channels are used at the current width. I define custom BN and Linear layers in the same way. The code is modified from the slimmable networks repo.
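A stripped-down, self-contained version of the same idea, to show how model.apply(...) and weight slicing interact. This is a sketch, not the slimmable networks implementation: TinyUSConv2d is hypothetical, it ignores groups and ratio, and it uses plain rounding instead of make_divisible.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyUSConv2d(nn.Conv2d):
    """Conv whose active channels are chosen at run time via `width_mult`."""

    def __init__(self, in_channels, out_channels, kernel_size, us=(True, True), **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.us = us            # which of (input, output) channels are slimmable
        self.width_mult = 1.0

    def forward(self, x):
        in_c, out_c = self.in_channels, self.out_channels
        if self.us[0]:
            in_c = max(1, round(self.in_channels * self.width_mult))
        if self.us[1]:
            out_c = max(1, round(self.out_channels * self.width_mult))
        bias = self.bias[:out_c] if self.bias is not None else None
        # Leading slice of the full weight: all widths share the same parameter.
        return F.conv2d(x, self.weight[:out_c, :in_c], bias,
                        self.stride, self.padding, self.dilation)


# First layer keeps its 3 input channels fixed; last layer keeps its 10 outputs fixed.
model = nn.Sequential(
    TinyUSConv2d(3, 16, 3, us=(False, True), padding=1),
    nn.ReLU(),
    TinyUSConv2d(16, 10, 1, us=(True, False)),
)

x = torch.randn(2, 3, 8, 8)
for w in (1.0, 0.5, 0.25):
    # apply() visits every submodule (the same pattern as in the training loop);
    # modules that do not use width_mult simply ignore the attribute.
    model.apply(lambda m: setattr(m, 'width_mult', w))
    print(w, model[0](x).shape[1], tuple(model(x).shape))
```

The first layer produces 16, 8 and 4 channels for the three widths, while the final output shape stays (2, 10, 8, 8).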

I assume your code works for some iterations and throws an OOM error after a while, or is it running out of memory in the first iterations because of the accumulated gradients?

Do you also see the memory leak using the repo code?
Are you accumulating the gradients for each different width?
If you don’t accumulate the gradients, do you still see the memory leak?

Hi @ptrblck, the code works for some iterations and throws the OOM error after a while.

  1. The repo code is somewhat different from mine. It uses four fixed widths (0.25x, 0.5x, 0.75x, 1.0x), while I randomly sample four widths in each iteration. One weird thing is that if I use four fixed widths, I don’t get the OOM error, and the memory usage stays constant at 4000MB after the first iteration. With randomly sampled widths, however, the memory usage jumps between 1000MB and 10000MB, and I get OOM after some iterations.

  2. Yes, I accumulate the gradients for each width and update the weights at the end of the iteration. I also tried not accumulating the gradients (by calling optimizer.step() and optimizer.zero_grad() after each loss.backward(); I am not sure if this is what you meant), and I still get the OOM error.
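Why fixed widths behave differently (point 1 above) can be made concrete by counting the distinct (channels, resolution) configurations a single layer can see. With four fixed widths the set is small and finite; with widths sampled from a continuous range it is much larger. The sketch assumes a MobileNetV2-style make_divisible with divisor 8, a hypothetical 320-channel layer, and a hypothetical four-entry resolution list; the repo's exact helper and settings may differ.

```python
import random


def make_divisible(v, divisor=8, min_value=1):
    """Round v to the nearest multiple of divisor (MobileNetV2-style, assumed here)."""
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # do not round down by more than 10%
        new_v += divisor
    return new_v


def distinct_configs(widths, resolutions, base_channels=320):
    """Set of (channels, resolution) pairs one layer could be run with."""
    return {(make_divisible(base_channels * w), r) for w in widths for r in resolutions}


resolutions = [224, 192, 160, 128]  # hypothetical stand-in for FLAGS.resolution_list
fixed = distinct_configs([0.25, 0.5, 0.75, 1.0], resolutions)

rng = random.Random(0)
sampled = [rng.uniform(0.25, 1.0) for _ in range(1000)]
dynamic = distinct_configs(sampled, resolutions)

print(len(fixed))    # 16
print(len(dynamic))  # at most 31 channel counts x 4 resolutions = 124
```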

It’s interesting to see that apparently randomly sampling the width creates this issue.
Could you create an executable code snippet, so that we could debug it?

Hi @ptrblck, I’d love to, but I am not sure you can debug it with a small code snippet because it depends on other modules (like the custom convolution). If possible, I can send you the whole project for debugging.

Another thing: when I run on a single GPU (previously 4 GPUs), I don’t get the OOM error.

Could you try to narrow down the issue a bit?
E.g. remove the custom convolution for multi-GPU training and see if you still get the OOM error.
Swapping custom classes for vanilla PyTorch ones might give us more information, which module leaks the memory.
Also, are you using nn.DataParallel, DDP or a custom implementation?

Hi @ptrblck, but the custom convolution is where I implement the random-width network (the model is built by stacking these operations). If I replace them with vanilla layers, it becomes a standard network.

Yes, I use nn.DataParallel.

Hi @ptrblck, do you have any idea about this issue, or how I could debug it?

@TaojiannanYang He did. Simplify until you’ve isolated the code that causes OOM and make it easily executable so they can diagnose.

Don’t ask them to debug your project. That’s your job!

Hi @ptrblck, I think I found what went wrong. I had set torch.backends.cudnn.benchmark = True. With this setting, cuDNN searches for the fastest algorithm for each particular input configuration. However, in my code the network shape changes in every iteration, so this search keeps being repeated, which leads to this problem. Fixed widths work because cuDNN finds the best algorithm for all 4 widths after some iterations. Details can be found here.
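A sketch of the fix described above: leave cuDNN autotuning off when layer or input shapes change from iteration to iteration, and enable it only when shapes are static. torch.backends.cudnn.benchmark is the real flag; the helper configure_cudnn and its static_shapes parameter are hypothetical.

```python
import torch


def configure_cudnn(static_shapes: bool) -> None:
    """Enable cuDNN autotuning only when layer/input shapes stay fixed.

    With benchmark=True, cuDNN benchmarks several convolution algorithms for each
    new input configuration it encounters; when widths and resolutions are
    resampled every iteration, that search keeps being triggered.
    """
    torch.backends.cudnn.benchmark = static_shapes


# Random widths and resolutions -> shapes change every iteration.
configure_cudnn(static_shapes=False)
print(torch.backends.cudnn.benchmark)  # False
```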

Thanks a lot for your help!


Good to hear you’ve isolated this issue! :slight_smile: