AdaIN Out of Memory

I have a fairly simple network architecture with AdaIN. It has only about 11,022,436 parameters, but it raises RuntimeError: CUDA out of memory. I’ve tried GPUs with 12 GB and 24 GB of memory, and both run out of memory. I can’t figure out why.

class generator_adain(nn.Module):
    """Two-headed 3D generator whose conv stages are modulated by AdaIN3D.

    A shared backbone of five size-preserving ``Conv3d`` layers feeds two
    decoder heads (geometry and colour).  Each head applies two
    ``ConvTranspose3d(kernel=4, stride=2, padding=1)`` layers, each of which
    doubles every spatial axis, so both outputs are 4x the input resolution
    per axis (64x the voxel count).

    Memory note: activation size, not parameter count, dominates.  The last
    decoder stage holds ``g_dim`` channels at 4x resolution; for a 64^3 input
    and ``g_dim=32`` a single float32 tensor there is
    32 * 256^3 * 4 bytes = 2 GiB.

    Args:
        g_dim: base channel width; layers use multiples of it.
        prob_dim: first dimension of the learnable style/colour code tables.
        aug_dim: second dimension of the colour code table.
        z_dim: channel count of the latent codes passed to ``forward``.
    """

    def __init__(self, g_dim, prob_dim, aug_dim, z_dim):
        super(generator_adain, self).__init__()
        self.g_dim = g_dim
        self.prob_dim = prob_dim
        self.z_dim = z_dim
        self.aug_dim = aug_dim

        # Learnable code tables, zero-initialised.  They are not read by
        # forward(); presumably callers index them to build z_geometry and
        # z_color -- TODO confirm against the training loop.
        self.style_codes = nn.Parameter(torch.zeros((self.prob_dim, self.z_dim)))
        self.color_codes = nn.Parameter(torch.zeros((self.prob_dim, self.aug_dim, self.z_dim)))

        # Shared backbone: every layer preserves the spatial size
        # (kernel 5 with padding 2, or kernel 5 / dilation 2 with padding 4).
        self.conv_0 = nn.Conv3d(1,             self.g_dim*1,  5, stride=1, dilation=1, padding=2, bias=True)
        self.conv_1 = nn.Conv3d(self.g_dim*1,  self.g_dim*2,  5, stride=1, dilation=2, padding=4, bias=True)
        self.conv_2 = nn.Conv3d(self.g_dim*2,  self.g_dim*4,  5, stride=1, dilation=2, padding=4, bias=True)
        self.conv_3 = nn.Conv3d(self.g_dim*4,  self.g_dim*8,  5, stride=1, dilation=1, padding=2, bias=True)
        self.conv_4 = nn.Conv3d(self.g_dim*8,  self.g_dim*4,  5, stride=1, dilation=1, padding=2, bias=True)
        self.adain_0 = AdaIN3D(z_dim, self.g_dim * 1)
        self.adain_1 = AdaIN3D(z_dim, self.g_dim * 2)
        self.adain_2 = AdaIN3D(z_dim, self.g_dim * 4)
        self.adain_3 = AdaIN3D(z_dim, self.g_dim * 8)
        self.adain_4 = AdaIN3D(z_dim, self.g_dim * 4)

        # Geometry head: 1 output channel at 4x resolution.
        self.conv_5_g = nn.ConvTranspose3d(self.g_dim*4,   self.g_dim*2, 4, stride=2, padding=1, bias=True)
        self.conv_6_g = nn.Conv3d(self.g_dim*2,            self.g_dim*2, 3, stride=1, padding=1, bias=True)
        self.conv_7_g = nn.ConvTranspose3d(self.g_dim*2,   self.g_dim*1, 4, stride=2, padding=1, bias=True)
        self.conv_8_g = nn.Conv3d(self.g_dim*1,            1,            3, stride=1, padding=1, bias=True)
        self.adain_5_g = AdaIN3D(z_dim, self.g_dim * 2)
        self.adain_6_g = AdaIN3D(z_dim, self.g_dim * 2)
        self.adain_7_g = AdaIN3D(z_dim, self.g_dim * 1)

        # Colour head: 3 output channels at 4x resolution.
        self.conv_5_c = nn.ConvTranspose3d(self.g_dim*4,  self.g_dim*2, 4, stride=2, padding=1, bias=True)
        self.conv_6_c = nn.Conv3d(self.g_dim*2,           self.g_dim*2, 3, stride=1, padding=1, bias=True)
        self.conv_7_c = nn.ConvTranspose3d(self.g_dim*2,  self.g_dim*1, 4, stride=2, padding=1, bias=True)
        self.conv_8_c = nn.Conv3d(self.g_dim*1,           3,            3, stride=1, padding=1, bias=True)
        self.adain_5_c = AdaIN3D(z_dim, self.g_dim * 2)
        self.adain_6_c = AdaIN3D(z_dim, self.g_dim * 2)
        self.adain_7_c = AdaIN3D(z_dim, self.g_dim * 1)

    @staticmethod
    def _stage(conv, adain, x, z):
        """Run one conv -> leaky ReLU (slope 0.02) -> AdaIN stage."""
        x = conv(x)
        # In-place is safe here: the conv output is not referenced elsewhere.
        x = F.leaky_relu(x, negative_slope=0.02, inplace=True)
        return adain(x, z)

    @staticmethod
    def _soft_clamp(x):
        """Leaky clamp to [0, 1]: identity inside, slope 0.002 outside."""
        return torch.max(torch.min(x, x * 0.002 + 0.998), x * 0.002)

    def forward(self, voxels, z_geometry, z_color, mask_, is_training=False):
        """Generate geometry and colour volumes at 4x the input resolution.

        Args:
            voxels: input volume with 1 channel, laid out (N, 1, D, H, W).
            z_geometry: latent code with ``z_dim`` channels; modulates the
                backbone and the geometry head.
            z_color: latent code with ``z_dim`` channels; modulates the
                colour head only.
            mask_: mask at input resolution; upsampled 4x (nearest) and
                multiplied into the geometry output.
            is_training: accepted for interface compatibility; unused here.

        Returns:
            Tuple ``(out_g, out_c)``: the masked 1-channel geometry volume
            and the 3-channel colour volume.
        """
        mask = F.interpolate(mask_, scale_factor=4, mode='nearest')

        # backbone (shared by both heads, always styled by z_geometry)
        out = voxels
        backbone = (
            (self.conv_0, self.adain_0),
            (self.conv_1, self.adain_1),
            (self.conv_2, self.adain_2),
            (self.conv_3, self.adain_3),
            (self.conv_4, self.adain_4),
        )
        for conv, adain in backbone:
            out = self._stage(conv, adain, out, z_geometry)

        # geometry
        out_g = self._stage(self.conv_5_g, self.adain_5_g, out, z_geometry)
        out_g = self._stage(self.conv_6_g, self.adain_6_g, out_g, z_geometry)
        out_g = self._stage(self.conv_7_g, self.adain_7_g, out_g, z_geometry)
        out_g = self._soft_clamp(self.conv_8_g(out_g))
        out_g = out_g * mask

        # color
        out_c = self._stage(self.conv_5_c, self.adain_5_c, out, z_color)
        out_c = self._stage(self.conv_6_c, self.adain_6_c, out_c, z_color)
        out_c = self._stage(self.conv_7_c, self.adain_7_c, out_c, z_color)
        # conv_8_c is a plain Conv3d: it takes only the feature map.  Passing
        # z_color as a second argument raised a TypeError.
        out_c = self._soft_clamp(self.conv_8_c(out_c))

        return out_g, out_c


class AdaIN3D(nn.Module):
    """Adaptive instance normalisation for 5-D ``(N, C, D, H, W)`` tensors.

    The style input is projected to the content's channel count with a 1x1x1
    convolution; the content is then normalised per sample and channel over
    its spatial axes and re-scaled/shifted with the projected style's
    per-channel standard deviation and mean.

    Args:
        z_dim: number of channels of the style input ``y``.
        g_dim: number of channels of the content input ``x``.
    """

    def __init__(self, z_dim, g_dim):
        super(AdaIN3D, self).__init__()
        self.eps = 1e-5
        self.mapping = nn.Conv3d(z_dim, g_dim, 1, stride=1, padding=0)

    def _mean_std(self, t):
        """Return per-(sample, channel) mean and std over the spatial axes."""
        dims = (2, 3, 4)
        mean = torch.mean(t, dim=dims, keepdim=True)
        # The unbiased estimator divides by (n - 1) and is NaN for a single
        # spatial element (e.g. a 1x1x1 style code).  Use the biased
        # estimator only in that case, so results for larger inputs are
        # unchanged.
        n_spatial = t.shape[2] * t.shape[3] * t.shape[4]
        var = torch.var(t, dim=dims, unbiased=n_spatial > 1, keepdim=True)
        return mean, torch.sqrt(var + self.eps)

    def forward(self, x, y):
        """Re-style ``x`` with the statistics of the projected ``y``.

        Args:
            x: content features, ``(N, g_dim, D, H, W)``.
            y: style input, ``(N, z_dim, d, h, w)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        y = self.mapping(y)
        mean_x, std_x = self._mean_std(x)
        mean_y, std_y = self._mean_std(y)

        # NOTE(review): for a 1x1x1 style code the variance is 0, so std_y is
        # sqrt(eps) and only mean_y carries style information.
        #
        # Fold the normalisation into one per-channel scale and shift.  This
        # equals (x - mean_x) / std_x * std_y + mean_y but creates fewer
        # full-size temporaries than the four-op chain.
        scale = std_y / std_x
        shift = mean_y - mean_x * scale
        return x * scale + shift

# Smoke test: build the generator, push one 64^3 volume through it on the GPU
# and report the parameter count, wall-clock time and output shapes.
device = torch.device('cuda')
model = generator_adain(g_dim=32, prob_dim=32, aug_dim=6, z_dim=8).to(device)
# This line was indented under the previous statement, which is an
# IndentationError at module level.
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

vox = torch.rand((1, 1, 64, 64, 64), device=device)
z_g = torch.rand((1, 8, 1, 1, 1), device=device)
z_c = torch.rand((1, 8, 1, 1, 1), device=device)
mask = torch.rand((1, 1, 64, 64, 64), device=device)

# CUDA kernels are launched asynchronously; synchronise before reading the
# clock so the measurement covers the actual GPU work.
torch.cuda.synchronize()
start = time.time()
out = model(vox, z_g, z_c, mask)
torch.cuda.synchronize()
print(time.time() - start)
print(out[0].shape)
print(out[1].shape)

Error

out = (x - mean_x) / std_x * std_y + mean_y
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB

11M parameters is quite a lot. Are you trying to run it with only one voxel grid? Did you try a voxel grid smaller than 64^3?

Yes, I’m trying to run it with only one voxel grid. I guess it’s too many parameters; it works if I try a smaller voxel size.

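In my experience, running networks on volumetric data takes much more memory than running them on images. Note that the memory is dominated by the activations rather than the parameters: 11M float32 parameters occupy only about 44 MB, whereas the two stride-2 transposed convolutions upsample the 64^3 input to 256^3, and a single 32-channel float32 feature map at that resolution takes 32 × 256^3 × 4 bytes = 2 GiB, which is exactly the allocation that failed. You can reduce the precision of the network from float32 to float16, or use automatic mixed precision. Another approach would be to simplify the network architecture.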