GPU runs out of memory even with model parallelism?

I was testing the U-GAT-IT model and found that it is too heavy for a single GPU even with batch size 1, so I tried model parallelism.
I’m using Titan X GPUs, CUDA 10.0 and PyTorch 1.2.0.

import torch
import torch.nn as nn

# ResnetBlock, ResnetAdaILNBlock and ILN are defined in the U-GAT-IT networks.py;
# the .to('cuda:x') / .cuda(x) calls are what I added for model parallelism.

class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, light=False, split_gpus=True):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.n_blocks = n_blocks
        self.img_size = img_size
        self.light = light
        self.split_gpus = split_gpus

        DownBlock = []
        DownBlock += [nn.ReflectionPad2d(3),
                      nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                      nn.InstanceNorm2d(ngf),
                      nn.ReLU(True)]

        # Down-Sampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            DownBlock += [nn.ReflectionPad2d(1),
                          nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),
                          nn.InstanceNorm2d(ngf * mult * 2),
                          nn.ReLU(True)]

        # Down-Sampling Bottleneck
        mult = 2**n_downsampling
        for i in range(n_blocks):
            DownBlock += [ResnetBlock(ngf * mult, use_bias=False)]

        # Class Activation Map
        self.gap_fc = nn.Linear(ngf * mult, 1, bias=False).to('cuda:1')
        self.gmp_fc = nn.Linear(ngf * mult, 1, bias=False).to('cuda:1')
        self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True).to('cuda:1')
        self.relu = nn.ReLU(True).to('cuda:1')


        # Gamma, Beta block
        if self.light:
            FC = [nn.Linear(ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True),
                  nn.Linear(ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True)]
        else:
            FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True),
                  nn.Linear(ngf * mult, ngf * mult, bias=False),
                  nn.ReLU(True)]
        self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=False).to('cuda:1')
        self.beta = nn.Linear(ngf * mult, ngf * mult, bias=False).to('cuda:1')

        # Up-Sampling Bottleneck
        for i in range(n_blocks):
            setattr(self, 'UpBlock1_' + str(i+1), ResnetAdaILNBlock(ngf * mult, use_bias=False).to('cuda:1'))

        # Up-Sampling
        UpBlock2 = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),
                         nn.ReflectionPad2d(1),
                         nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=False),
                         ILN(int(ngf * mult / 2)),
                         nn.ReLU(True)]

        UpBlock2 += [nn.ReflectionPad2d(3),
                     nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=False),
                     nn.Tanh()]
            
        self.DownBlock = nn.Sequential(*DownBlock).to('cuda:0')  # encoder stays on cuda:0
        self.FC = nn.Sequential(*FC).to('cuda:1')                # everything after the encoder lives on cuda:1
        self.UpBlock2 = nn.Sequential(*UpBlock2).to('cuda:1')
        

    def forward(self, input):
        
        x = self.DownBlock(input)  # encoder runs on cuda:0
        x = x.cuda(1)              # move the activations to cuda:1 for the rest of the forward pass

        print('A')  # debug print
        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)
        
        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = self.FC(x_.view(x_.shape[0], -1))
        else:
            x_ = self.FC(x.view(x.shape[0], -1))
        
        gamma, beta = self.gamma(x_), self.beta(x_)
        gamma = gamma.cuda(1)
        beta = beta.cuda(1)
        for i in range(self.n_blocks):
            x = getattr(self, 'UpBlock1_' + str(i+1))(x, gamma, beta)
        
        
        out = self.UpBlock2(x)
        
        out = out.cuda(0)              # move the outputs back to cuda:0 for the caller
        cam_logit = cam_logit.cuda(0)
        heatmap = heatmap.cuda(0)
        
        return out, cam_logit, heatmap
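
For reference, a minimal way to drive this generator (batch size 1, 256x256 RGB images, the defaults from the class above; the variable names are just placeholders) looks roughly like this:

import torch

# hypothetical driver: a single 256x256 RGB image on cuda:0
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=6, img_size=256)
real_A = torch.randn(1, 3, 256, 256, device='cuda:0')

fake_B, cam_logit, heatmap = netG(real_A)  # forward() moves all three outputs back to cuda:0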

but it still gives the following error:

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 1; 11.91 GiB total capacity; 11.12 GiB already allocated; 8.56 MiB free; 160.91 MiB cached)

The funny thing is that with this model parallelism the model gets through one iteration before hitting the error, so the split isn’t completely useless.
So I think there are two possibilities:

  1. Something is wrong with my model parallelism or GPU settings (e.g. dynamic memory allocation is not working); see the memory-logging sketch after this list.
  2. U-GAT-IT is simply so heavy that it needs even more memory.
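
To check possibility 1, a quick logging helper like the one below (just a sketch; torch.cuda.memory_cached is the PyTorch 1.2 name for what newer releases call memory_reserved) could show how the memory is actually distributed across the two GPUs:

import torch

def log_gpu_memory(tag):
    # print allocated and cached memory for every visible GPU, in MiB
    for dev in range(torch.cuda.device_count()):
        alloc = torch.cuda.memory_allocated(dev) / 1024**2
        cached = torch.cuda.memory_cached(dev) / 1024**2
        print('{} | cuda:{}: {:.1f} MiB allocated, {:.1f} MiB cached'.format(tag, dev, alloc, cached))

# e.g. call log_gpu_memory('after iter 1') right after optimizer.step()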

Could anyone answer which possibility is true? Any help is appreciated.

If the device runs out of memory in the second iteration, it might point towards some intermediate or output tensors that are still on the device but could be freed.
E.g. if you are using a training loop like this:

for data, target in loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

the output tensor would still be on the device (as well as the loss tensor, which can be ignored if it’s just a scalar). You could remove it via del output and rerun the code.
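
Applied to the loop above, that could look like this:

for data, target in loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    # drop the references so the tensors can be freed before the next forward pass
    del output, loss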

Also, the gradients are accumulated and only zeroed out in the next iteration, so they keep using memory in the meantime.
If you are close to an OOM, you could free them by setting all gradients to None via:

for param in model.parameters():
    param.grad = None
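
Inside the training loop this could replace optimizer.zero_grad() at the end of each iteration (newer PyTorch versions offer optimizer.zero_grad(set_to_none=True) for the same effect, but that option isn’t available in 1.2.0):

for data, target in loader:
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    # free the gradient buffers instead of just zeroing them
    for param in model.parameters():
        param.grad = None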