Cannot perform backward pass on large image using ModelParallel regardless of batch size

Hi

I am trying to work with a large image (512x512) with a batch size of 5, using a very memory-hungry UNet3+ architecture, and I am running this over 4 x 16 GB GPUs. The issue is that the decoder requires a 16x upsample of a layer, which was causing OOM errors on the forward pass. To work around this I isolated that step to a single GPU, and now I can get through the forward pass, but I am getting the following error:

 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 9.00 GiB (GPU 3; 15.78 GiB total capacity; 4.13 GiB already allocated; 3.79 GiB free; 11.02 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
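
For reference, here is my back-of-the-envelope estimate of the tensor at that 16x upsample step (assuming fp32, a 512x512 input and batch 5, so hd5 is 1024 channels at 32x32 after the four poolings -- I may be off here):

# Rough size of the hd5_UT_hd1 activation (assumes fp32, 512x512 input, batch 5)
batch, channels, h, w = 5, 1024, 32, 32       # hd5 after the four 2x max-pools
scale, bytes_per_elem = 16, 4                 # hd5_UT_hd1 upsample factor, float32
size_gib = batch * channels * (h * scale) * (w * scale) * bytes_per_elem / 2**30
print(f"upsampled hd5: {batch}x{channels}x{h * scale}x{w * scale} ~= {size_gib:.1f} GiB")  # ~5 GiB
# The backward pass needs roughly the same again for this tensor's gradient, which
# seems to be in the ballpark of the 9 GiB allocation reported above.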

I am confused as to why so much is still being allocated by PyTorch. Also, is there a better way to break up the upsampling?
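
For example, would wrapping that last decoder stage (or at least the 16x upsample + conv) in torch.utils.checkpoint be a sensible way to trade compute for memory? Something like this toy sketch, not my actual model, just the idea (untested):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Checkpoint the 16x upsample + conv block so its intermediate activations are
# recomputed during backward instead of being kept alive.
# Tiny toy shapes so it runs anywhere; the real hd5 is batch x 1024 x 32 x 32.
up16 = nn.Sequential(
    nn.Upsample(scale_factor=16, mode='bilinear'),
    nn.Conv2d(1024, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

hd5 = torch.randn(1, 1024, 8, 8, requires_grad=True)
out = checkpoint(up16, hd5, use_reentrant=False)   # nothing inside up16 is stored for backward
out.sum().backward()                               # the block is re-run here to get gradients
# Caveat I'm not sure about: BatchNorm running stats may be updated twice per step this way.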

My model is the following:

class MP_UNet3Plus(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4,
                 is_deconv=True, is_batchnorm=True):
        super(MP_UNet3Plus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm
        filters = [64, 128, 256, 512, 1024]

        ## -------------Encoder--------------
        self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm).to(DEVICE_0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm).to(DEVICE_0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm).to(DEVICE_0)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2).to(DEVICE_1)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm).to(DEVICE_1)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2).to(DEVICE_1)

        self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm).to(DEVICE_1)

        ## -------------Decoder--------------
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        '''stage 4d'''
        # h1->320*320, hd4->40*40, Pooling 8 times
        self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True).to(DEVICE_1)
        self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h1_PT_hd4_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # h2->160*160, hd4->40*40, Pooling 4 times
        self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True).to(DEVICE_1)
        self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h2_PT_hd4_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # h3->80*80, hd4->40*40, Pooling 2 times
        self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True).to(DEVICE_1)
        self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h3_PT_hd4_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # h4->40*40, hd4->40*40, Concatenation
        self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # hd5->20*20, hd4->40*40, Upsample 2 times
        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear').to(DEVICE_1)  # 14*14
        self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
        self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1).to(DEVICE_1)  # 16
        self.bn4d_1 = nn.BatchNorm2d(self.UpChannels).to(DEVICE_1)
        self.relu4d_1 = nn.ReLU(inplace=True).to(DEVICE_1)

        '''stage 3d'''
        # h1->320*320, hd3->80*80, Pooling 4 times
        self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True).to(DEVICE_1)
        self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h1_PT_hd3_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # h2->160*160, hd3->80*80, Pooling 2 times
        self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True).to(DEVICE_1)
        self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h2_PT_hd3_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # h3->80*80, hd3->80*80, Concatenation
        self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # hd4->40*40, hd3->80*80, Upsample 2 times
        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear').to(DEVICE_1)  # 14*14
        self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # hd5->20*20, hd3->80*80, Upsample 4 times
        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear').to(DEVICE_1) # 14*14
        self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
        self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1).to(DEVICE_1)  # 16
        self.bn3d_1 = nn.BatchNorm2d(self.UpChannels).to(DEVICE_1)
        self.relu3d_1 = nn.ReLU(inplace=True).to(DEVICE_1)

        '''stage 2d '''
        # h1->320*320, hd2->160*160, Pooling 2 times
        self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True).to(DEVICE_1)
        self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h1_PT_hd2_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # h2->160*160, hd2->160*160, Concatenation
        self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # hd3->80*80, hd2->160*160, Upsample 2 times
        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear').to(DEVICE_1)  # 14*14
        self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # hd4->40*40, hd2->160*160, Upsample 4 times
        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear').to(DEVICE_1) # 14*14
        self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # hd5->20*20, hd2->160*160, Upsample 8 times
        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear').to(DEVICE_1)  # 14*14
        self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1).to(DEVICE_1)
        self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_1)
        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True).to(DEVICE_1)

        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
        self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1).to(DEVICE_1)  # 16
        self.bn2d_1 = nn.BatchNorm2d(self.UpChannels).to(DEVICE_1)
        self.relu2d_1 = nn.ReLU(inplace=True).to(DEVICE_1)

        '''stage 1d'''
        # h1->320*320, hd1->320*320, Concatenation
        self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1).to(DEVICE_2)
        self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_2)
        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True).to(DEVICE_2)

        # hd2->160*160, hd1->320*320, Upsample 2 times
        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear').to(DEVICE_2)  # 14*14
        self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1).to(DEVICE_2)
        self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_2)
        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True).to(DEVICE_2)

        # hd3->80*80, hd1->320*320, Upsample 4 times
        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear').to(DEVICE_2)  # 14*14
        self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1).to(DEVICE_2)
        self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_2)
        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True).to(DEVICE_2)

        # hd4->40*40, hd1->320*320, Upsample 8 times
        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear').to(DEVICE_2)  # 14*14
        self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1).to(DEVICE_2)
        self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_2)
        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True).to(DEVICE_2)

        # hd5->20*20, hd1->320*320, Upsample 16 times
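        # NOTE: this 16x upsample of the 1024-channel hd5 map is the step I isolated to its own GPU (DEVICE_3)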
        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear').to(DEVICE_3)  # 14*14
        self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1).to(DEVICE_3)
        self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels).to(DEVICE_3)
        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True).to(DEVICE_3)

        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
        self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1).to(DEVICE_3)  # 16
        self.bn1d_1 = nn.BatchNorm2d(self.UpChannels).to(DEVICE_3)
        self.relu1d_1 = nn.ReLU(inplace=True).to(DEVICE_3)

        # output
        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1).to(DEVICE_0)


    def forward(self, inputs):
        ## -------------Encoder-------------
        h1 = self.conv1(inputs.to(DEVICE_0)).to(DEVICE_0)  # h1->320*320*64

        h2 = self.maxpool1(h1)
        h2 = self.conv2(h2)  # h2->160*160*128

        h3 = self.maxpool2(h2)
        h3 = self.conv3(h3)  # h3->80*80*256

        h4 = self.maxpool3(h3.to(DEVICE_1))
        h4 = self.conv4(h4)  # h4->40*40*512

        h5 = self.maxpool4(h4)
        hd5 = self.conv5(h5)  # h5->20*20*1024

        ## -------------Decoder-------------
        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1.to(DEVICE_1)))))
        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2.to(DEVICE_1)))))
        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3.to(DEVICE_1)))))
        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels

        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1.to(DEVICE_1)))))
        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2.to(DEVICE_1)))))
        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3.to(DEVICE_1))))
        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels

        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1.to(DEVICE_1)))))
        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2.to(DEVICE_1))))
        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3.to(DEVICE_1)))))
        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels

        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1.to(DEVICE_2))))
        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2.to(DEVICE_2)))))
        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3.to(DEVICE_2)))))
        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4.to(DEVICE_2)))))
        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5.to(DEVICE_3)))))
        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1.to(DEVICE_3), hd2_UT_hd1.to(DEVICE_3), hd3_UT_hd1.to(DEVICE_3), hd4_UT_hd1.to(DEVICE_3), hd5_UT_hd1.to(DEVICE_3)), 1)))) # hd1->320*320*UpChannels

        d1 = self.outconv1(hd1.to(DEVICE_0))  # d1->320*320*n_classes
        
        #if (self.n_classes == 1):
        #    output = torch.sigmoid(d1)
        #else:
        #    output = torch.softmax(d1,dim=1)
        return d1 #output 

Can you try to get info on how GPU memory is allocated by using

torch.cuda.memory_summary(device=None, abbreviated=False)
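
For example, something like this right before the backward call (or inside the OOM exception handler) would show what each of your four GPUs is holding:

import torch

# Dump allocator state for every visible GPU
for i in range(torch.cuda.device_count()):
    print(f"--- cuda:{i} ---")
    print(f"allocated: {torch.cuda.memory_allocated(i) / 2**30:.2f} GiB")
    print(f"reserved:  {torch.cuda.memory_reserved(i) / 2**30:.2f} GiB")
    print(torch.cuda.memory_summary(device=i, abbreviated=False))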

Instead of manually moving the input to each device, have you considered using other parallelization techniques such as FSDP or PiPPy? A minimal FSDP sketch is included after the doc links below.

Doc for FSDP: FullyShardedDataParallel — PyTorch 2.0 documentation

Doc for PiPPy: GitHub - pytorch/PiPPy: Pipeline Parallelism for PyTorch
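
A minimal FSDP sketch, just to show the shape of it (assumes a single node with 4 GPUs launched via torchrun; UNet3Plus stands in for the plain single-device version of your model, and the data loader and loss are placeholders):

# launch with: torchrun --nproc_per_node=4 train_fsdp.py
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Build the plain (single-device) model without the manual .to(DEVICE_x) calls;
# FSDP shards the parameters across ranks and each rank processes its own slice
# of the batch, so per-GPU activation memory drops with the per-rank batch size.
model = FSDP(UNet3Plus(n_channels=3, n_classes=1), device_id=local_rank)  # UNet3Plus: your non-parallel class

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.BCEWithLogitsLoss()

for images, masks in loader:  # placeholder DataLoader (use a DistributedSampler)
    images, masks = images.to(local_rank), masks.to(local_rank)
    optimizer.zero_grad()
    loss = criterion(model(images), masks)
    loss.backward()
    optimizer.step()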