Out of memory error with nn.ModuleList consisting of dozens of small layers

I’m trying to combine “Faster Discovery of Neural Architectures by Searching for Paths in a Large Model” with “Learning Transferable Architectures for Scalable Image Recognition.” My approach requires 20 “Layer” objects, each consisting of an nn.ModuleList of ~30 1x1 conv layers, an nn.ModuleList of ~15 3x3 conv layers, and ~15 F.max_pool2d/F.avg_pool2d calls. Within each Layer object the sub-layers are arranged roughly in parallel, and at each iteration an external controller decides which of them are active, so the FLOPs and parameters actually involved in the computation are much smaller (~4M of the 14M parameters). Inactive sub-layers take part in neither the forward nor the backward pass, and only ~12 of the ~75 sub-layers are active per Layer object. So although 20 Layer objects are used, the actual number of layers in use is about 40.
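Schematically, each Layer object works like the toy sketch below: the controller’s code selects which of the parallel ops run, and only those enter the autograd graph. (PathLayer and all sizes here are placeholders for illustration, not my actual code, which follows further down.)

import torch.nn as nn

class PathLayer(nn.Module):
    """Toy version of the idea: many parallel ops, few active per step."""
    def __init__(self, ch, n_ops=75):
        super(PathLayer, self).__init__()
        self.ops = nn.ModuleList([nn.Conv2d(ch, ch, 1) for _ in range(n_ops)])

    def forward(self, x, active):
        # `active` (a non-empty list of indices) comes from the controller;
        # ops not listed are never called, so they cost no FLOPs and
        # receive no gradients.
        out = self.ops[active[0]](x)
        for i in active[1:]:
            out = out + self.ops[i](x)
        return out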

The problem is that not even a single iteration completes: the backward pass hits an out-of-memory error (both CPU-only and on a K80 GPU). Inference works without problems. Since there are hundreds of 1x1/3x3 conv layers in total despite only 14M parameters overall, each individual layer has very few parameters.
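My current understanding of why inference is fine but backward is not: autograd keeps every intermediate activation alive until backward runs, so hundreds of small convs still mean hundreds of retained feature maps, while volatile inference frees them as it goes. A self-contained toy example of the effect (all names and sizes are made up, not from my model):

import torch
import torch.nn as nn
from torch.autograd import Variable as V

convs = nn.ModuleList([nn.Conv2d(64, 64, 1) for _ in range(200)])

# Inference: volatile frees each intermediate as soon as it is consumed.
h = V(torch.randn(32, 64, 32, 32), volatile=True)
for conv in convs:
    h = conv(h)

# Training: all 200 conv outputs (~8 MB each) stay alive until backward.
h = V(torch.randn(32, 64, 32, 32))
for conv in convs:
    h = conv(h)
h.sum().backward()

The entire code of my Layer object follows, in case it helps.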

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V

# Conv, Factorized_adjustment, and args are defined elsewhere in my code.
class Layer(nn.Module):
    def __init__(self, in_ch1, in_ch2, out_ch):
        super(Layer, self).__init__()
        self.in_ch = [in_ch1, in_ch2]
        self.out_ch = out_ch
        if self.in_ch[0] != self.in_ch[1]:
            self.adjust = Factorized_adjustment(in_ch1, in_ch2)

        # conv1 holds all 1x1 convs: indices 0-7 act on the two inputs,
        # 8-13 on the intermediate sums, 14-37 follow the pooling branches.
        self.conv1 = nn.ModuleList()
        for i in range(2):
            for j in range(4):
                self.conv1 += self._make_layer(self.in_ch[0], kernel=1, padding=0,
                                               stride=1 if self.in_ch[0] >= self.out_ch else 2)
        for i in range(6):
            self.conv1 += self._make_layer(out_ch // 4, kernel=1, padding=0)
        for k in range(3):
            for i in range(2):
                for j in range(4):
                    self.conv1 += self._make_layer(self.in_ch[0], kernel=1, padding=0,
                                                   stride=1 if self.in_ch[0] >= self.out_ch else 2)

        # conv3 holds the 3x3 convs: indices 0-7 on the inputs, 8-13 on the sums.
        self.conv3 = nn.ModuleList()
        for i in range(2):
            for j in range(4):
                self.conv3 += self._make_layer(self.in_ch[0], kernel=3, padding=1,
                                               stride=2 if self.in_ch[0] < self.out_ch else 1)
        for i in range(6):
            self.conv3 += self._make_layer(out_ch // 4, kernel=3, padding=1)

    def _make_layer(self, in_ch, kernel, padding, stride=1, deconv=False):
        return [Conv(in_ch, self.out_ch // 4, kernel, padding=padding, stride=stride, deconv=deconv)]

    def forward(self, h1, h2, code):
        # Decode the controller's `code`: for each of the 4 blocks, two
        # (op, source) picks; c[source][block] lists the ops to apply.
        c = [[[] for j in range(4)] for i in range(5)]
        for i in range(4):
            tmp = [code[4 * i] % 5, code[4 * i + 1] % (i + 2),
                   code[4 * i + 2] % 5, code[4 * i + 3] % (i + 2)]
            if tmp[1] == tmp[3] and tmp[0] == tmp[2]:
                # Avoid applying the identical op to the identical source twice.
                tmp[3] = (code[4 * i + 3] + 1) % (i + 2)
            c[tmp[1]][i].append(tmp[0])
            c[tmp[3]][i].append(tmp[2])

        params = 0
        scale = h1.size(2)
        if self.in_ch[0] > self.out_ch:
            scale = scale * 2
        elif self.in_ch[0] < self.out_ch:
            scale = scale // 2
        if args.cuda:
            zeros_out = V(torch.zeros(h1.size(0), self.out_ch // 4, scale, scale).cuda(),
                          volatile=not self.training, requires_grad=False)
        else:
            zeros_out = V(torch.zeros(h1.size(0), self.out_ch // 4, scale, scale),
                          volatile=not self.training, requires_grad=False)
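        # If I understand volatile correctly: in eval mode the volatile flag
        # propagates through everything built below, so intermediates are
        # freed eagerly; in training mode every op output is retained for
        # backward -- which is where I suspect the memory goes.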
        h = [h1, h2]
        del h1, h2
        if self.in_ch[0] != self.in_ch[1]:
            h[1], p = self.adjust(h[1], params)
            params += p
        sum = [zeros_out] * 4
        # k selects the op: 0 and 4 -> 1x1 conv, 1 -> 3x3 conv,
        # 2 -> avg-pool + 1x1 conv, 3 -> max-pool + 1x1 conv.
        for i in range(2):
            for j in range(4):
                for k in c[i][j]:
                    if k == 0:
                        if self.in_ch[0] > self.out_ch:
                            sum[j] = sum[j] + self.conv1[4 * i + j](F.upsample(h[i], scale_factor=2))
                        else:
                            sum[j] = sum[j] + self.conv1[4 * i + j](h[i])
                        params += self.in_ch[0] * self.out_ch // 4
                    if k == 1:
                        if self.in_ch[0] > self.out_ch:
                            sum[j] = sum[j] + self.conv3[4 * i + j](F.upsample(h[i], scale_factor=2))
                        else:
                            sum[j] = sum[j] + self.conv3[4 * i + j](h[i])
                        params += 3 * 3 * self.in_ch[0] * self.out_ch // 4
                    if k == 2:
                        if self.out_ch >= self.in_ch[0]:
                            sum[j] = sum[j] + self.conv1[14 + 4 * i + j](F.avg_pool2d(h[i], 3, padding=1, stride=1))
                        else:
                            sum[j] = sum[j] + self.conv1[14 + 4 * i + j](
                                F.upsample(F.avg_pool2d(h[i], 3, 1, 1), scale_factor=2))
                        params += self.in_ch[0] * self.out_ch // 4
                    if k == 3:
                        if self.out_ch >= self.in_ch[0]:
                            sum[j] = sum[j] + self.conv1[22 + 4 * i + j](F.max_pool2d(h[i], 3, 1, 1))
                        else:
                            sum[j] = sum[j] + self.conv1[22 + 4 * i + j](
                                F.upsample(F.max_pool2d(h[i], 3, 1, 1), scale_factor=2))
                        params += self.in_ch[0] * self.out_ch // 4
                    if k == 4:
                        if self.in_ch[0] > self.out_ch:
                            sum[j] = sum[j] + self.conv1[30 + 4 * i + j](F.upsample(h[i], scale_factor=2))
                        else:
                            sum[j] = sum[j] + self.conv1[30 + 4 * i + j](h[i])
                        params += self.in_ch[0] * self.out_ch // 4
        del h
        # Second stage: combine the partial sums pairwise; the 6 iterations
        # correspond to the ordered block pairs below.
        pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
        for i, (j, k) in enumerate(pairs):
            for l in c[j][k]:
                if l == 0:
                    sum[k] = sum[k] + self.conv1[8 + i](sum[j])
                    params += self.out_ch * self.out_ch // 16
                if l == 1:
                    sum[k] = sum[k] + self.conv3[8 + i](sum[j])
                    params += 3 * 3 * self.out_ch * self.out_ch // 16
                if l == 2:
                    sum[k] = sum[k] + F.leaky_relu(F.avg_pool2d(sum[j], 3, 1, 1),
                                                   negative_slope=0.2)
                if l == 3:
                    sum[k] = sum[k] + F.leaky_relu(F.max_pool2d(sum[j], 3, 1, 1),
                                                   negative_slope=0.2)
                if l == 4:  # identity: nothing to add
                    pass
        output = torch.cat(sum, 1)
        del sum
        return output, params
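For scale, here is my back-of-the-envelope estimate of the activations backward has to keep. Every number below is an assumption for illustration, not something I measured:

batch, out_ch, s = 32, 256, 32              # assumed sizes, not my real config
per_op = batch * (out_ch // 4) * s * s * 4  # one op output in float32, bytes
kept = per_op * 12 * 20                     # ~12 active ops x 20 Layer objects
print('%.1f GiB' % (kept / 1024.0 ** 3))    # ~1.9 GiB, before counting conv
                                            # inputs, pooled maps, and the
                                            # torch.cat copies

Even at these modest sizes the retained activations dwarf the 14M parameters (~56 MB as float32), so I suspect the activations rather than the weights are what blow up.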