How to reduce GPU memory usage?

Hello there. Recently I have been trying to build a network that processes high-resolution images, so I decided to crop the input images into smaller patches and then concatenate the corresponding feature maps following the spatial order of the crops. Intuitively, such an operation should reduce GPU memory usage, but in my experiments it does not seem to work… Here is my code:
Original Network:

import torch
import torch.nn as nn


class DarkNet53(nn.Module):
    def __init__(self):
        super().__init__()
        # build_backbone_modules() is defined at the end of this post
        self.module_list = build_backbone_modules()

    def forward(self, x):
        route_layers = []
        for i, module in enumerate(self.module_list):

            # forward through the current backbone/neck module
            x = module(x)

            # save feature maps that are routed to later layers
            if i in [6, 8, 17, 24, 32]:
                route_layers.append(x)
            if i == 19:
                x = torch.cat((x, route_layers[1]), 1)
            if i == 26:
                x = torch.cat((x, route_layers[0]), 1)
        return route_layers, x

Refined Network:

class DarkNet53_LargeScale(nn.Module):
    def __init__(self, crop_idx=4):
        super().__init__()
        # a single shared DarkNet53 instance processes every crop
        self.darknet = DarkNet53()
        self.crop_idx = crop_idx

    def get_local_input(self, x):
        # split the input into crop_idx x crop_idx non-overlapping patches
        # crop order: left to right, top to bottom
        # (any remainder pixels on the right/bottom edges are dropped)
        size = x.shape[2:]
        crop_tensor_list = []
        h_index_list = [size[0] // self.crop_idx * i for i in range(self.crop_idx + 1)]
        w_index_list = [size[1] // self.crop_idx * i for i in range(self.crop_idx + 1)]
        for i in range(self.crop_idx):
            for j in range(self.crop_idx):
                crop_tensor = x[..., h_index_list[i]:h_index_list[i + 1], w_index_list[j]:w_index_list[j + 1]]
                crop_tensor_list.append(crop_tensor)
        return crop_tensor_list, h_index_list, w_index_list

    def forward_backbone(self, crop_tensor_list):
        # forward each cropped tensor and get the output
        route_layers_list = []
        x_list = []
        for tensor in crop_tensor_list:
            route_layers, x = self.darknet(tensor)
            route_layers_list.append(route_layers)
            x_list.append(x)
        return route_layers_list, x_list

    def forward_post_process(self, route_layers_list, x_list):
        # stitch each crop's outputs back together following their spatial order
        # to reconstruct the full-resolution feature maps
        n = x_list[0].shape[0]  # batch size
        c_x, h_x, w_x = x_list[0].shape[1:]
        c_0, h_0, w_0 = route_layers_list[0][0].shape[1:]  # num of channels, height, width
        c_1, h_1, w_1 = route_layers_list[0][1].shape[1:]
        c_2, h_2, w_2 = route_layers_list[0][2].shape[1:]
        c_3, h_3, w_3 = route_layers_list[0][3].shape[1:]
        c_4, h_4, w_4 = route_layers_list[0][4].shape[1:]
        # route_layers_list contains (crop_idx * crop_idx) route_layers, each holding 5 tensors;
        # x_list contains (crop_idx * crop_idx) tensors
        # allocate the output buffers on the same device/dtype as the crop outputs,
        # otherwise the slice assignments below fail for inputs on the GPU
        device, dtype = x_list[0].device, x_list[0].dtype
        x_out = torch.zeros((n, c_x, h_x * self.crop_idx, w_x * self.crop_idx), device=device, dtype=dtype)
        route_layers_out_0 = torch.zeros((n, c_0, h_0 * self.crop_idx, w_0 * self.crop_idx), device=device, dtype=dtype)
        route_layers_out_1 = torch.zeros((n, c_1, h_1 * self.crop_idx, w_1 * self.crop_idx), device=device, dtype=dtype)
        route_layers_out_2 = torch.zeros((n, c_2, h_2 * self.crop_idx, w_2 * self.crop_idx), device=device, dtype=dtype)
        route_layers_out_3 = torch.zeros((n, c_3, h_3 * self.crop_idx, w_3 * self.crop_idx), device=device, dtype=dtype)
        route_layers_out_4 = torch.zeros((n, c_4, h_4 * self.crop_idx, w_4 * self.crop_idx), device=device, dtype=dtype)

        for i, x in enumerate(x_list):
            # recover the (row, column) position of this crop in the original image
            h_index = i // self.crop_idx
            w_index = i % self.crop_idx
            x_out[..., (h_index * h_x):((h_index + 1) * h_x), (w_index * w_x):((w_index + 1) * w_x)] = x
            route_layers_out_0[..., (h_index * h_0):((h_index + 1) * h_0), (w_index * w_0):((w_index + 1) * w_0)] = \
                route_layers_list[i][0]
            route_layers_out_1[..., (h_index * h_1):((h_index + 1) * h_1), (w_index * w_1):((w_index + 1) * w_1)] = \
                route_layers_list[i][1]
            route_layers_out_2[..., (h_index * h_2):((h_index + 1) * h_2), (w_index * w_2):((w_index + 1) * w_2)] = \
                route_layers_list[i][2]
            route_layers_out_3[..., (h_index * h_3):((h_index + 1) * h_3), (w_index * w_3):((w_index + 1) * w_3)] = \
                route_layers_list[i][3]
            route_layers_out_4[..., (h_index * h_4):((h_index + 1) * h_4), (w_index * w_4):((w_index + 1) * w_4)] = \
                route_layers_list[i][4]
        return [route_layers_out_0, route_layers_out_1, route_layers_out_2, route_layers_out_3, route_layers_out_4], x_out

    def forward_original(self, x):
        # reference path: forward the whole image at once (same as DarkNet53.forward)
        route_layers = []
        for i, module in enumerate(self.darknet.module_list):
            x = module(x)

            # save feature maps that are routed to later layers
            if i in [6, 8, 17, 24, 32]:
                route_layers.append(x)
            if i == 19:
                x = torch.cat((x, route_layers[1]), 1)
            if i == 26:
                x = torch.cat((x, route_layers[0]), 1)

        return route_layers, x

    def forward(self, x):
        crop_tensor_list, h_index_list, w_index_list = self.get_local_input(x)
        route_layers_list, x_list = self.forward_backbone(crop_tensor_list)
        route_layers, x = self.forward_post_process(route_layers_list, x_list)
        return route_layers, x
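
For context, a hypothetical usage example (the 1024×1024 input size is only a placeholder; add_conv, resblock, SPPLayer, DropBlock and upsample are helpers from the rest of the codebase):

model = DarkNet53_LargeScale(crop_idx=4).cuda()
x = torch.randn(1, 3, 1024, 1024, device='cuda')  # split into 16 patches of 256x256
route_layers, out = model(x)
print(out.shape, [r.shape for r in route_layers])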

And here is the code for building the network, in case you need it:

def build_backbone_modules():
    """
    Build the DarkNet53 backbone and YOLOv3 neck modules.
    Returns:
        mlist (nn.ModuleList): YOLOv3 module list.
    """
    # DarkNet53
    mlist = nn.ModuleList()
    mlist.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1))  # 0
    mlist.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2))  # 1
    mlist.append(resblock(ch=64))  # 2
    mlist.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2))  # 3
    mlist.append(resblock(ch=128, nblocks=2))  # 4
    mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=2))  # 5
    mlist.append(resblock(ch=256, nblocks=8))  # shortcut 1 from here     #6
    mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=2))  # 7
    mlist.append(resblock(ch=512, nblocks=8))  # shortcut 2 from here     #8
    mlist.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=2))  # 9
    mlist.append(resblock(ch=1024, nblocks=4))  # 10

    # YOLOv3
    mlist.append(resblock(ch=1024, nblocks=1, shortcut=False))  # 11
    mlist.append(add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1))  # 12
    # SPP Layer
    mlist.append(SPPLayer())  # 13

    mlist.append(add_conv(in_ch=2048, out_ch=512, ksize=1, stride=1))  # 14
    mlist.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=1))  # 15
    mlist.append(DropBlock(block_size=1, keep_prob=1))  # 16
    mlist.append(add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1))  # 17

    # 1st yolo branch
    mlist.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1))  # 18
    mlist.append(upsample(scale_factor=2, mode='nearest'))  # 19
    mlist.append(add_conv(in_ch=768, out_ch=256, ksize=1, stride=1))  # 20
    mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1))  # 21
    mlist.append(DropBlock(block_size=1, keep_prob=1))  # 22
    mlist.append(resblock(ch=512, nblocks=1, shortcut=False))  # 23
    mlist.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1))  # 24
    # 2nd yolo branch

    mlist.append(add_conv(in_ch=256, out_ch=128, ksize=1, stride=1))  # 25
    mlist.append(upsample(scale_factor=2, mode='nearest'))  # 26
    mlist.append(add_conv(in_ch=384, out_ch=128, ksize=1, stride=1))  # 27
    mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=1))  # 28
    mlist.append(DropBlock(block_size=1, keep_prob=1))  # 29
    mlist.append(resblock(ch=256, nblocks=1, shortcut=False))  # 30
    mlist.append(add_conv(in_ch=256, out_ch=128, ksize=1, stride=1))  # 31
    mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=1))  # 32

    return mlist
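
For reference, a minimal sketch of how the peak-memory comparison could be done (assuming a CUDA device; it is best to run each measurement in a fresh process so that leftover allocations do not skew the comparison):

def peak_memory_mb(model, input_size=(1, 3, 1024, 1024)):
    # peak GPU memory of one forward + backward pass through the model
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(*input_size, device='cuda')
    route_layers, out = model(x)
    # backprop through every output so that all intermediate activations are needed
    (out.sum() + sum(r.sum() for r in route_layers)).backward()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

# e.g. peak_memory_mb(DarkNet53().cuda()) vs. peak_memory_mb(DarkNet53_LargeScale(crop_idx=4).cuda())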

Could anyone help me figure it out?

I’m unsure why this approach would save memory, since it seems you are appending all outputs, and I guess you are then stitching them together into the original-sized output and trying to calculate the gradients?
If so, then I think the output tensors as well as all intermediate activation tensors would still be computed, and there should not be any memory saving.
If you want to lower the memory usage, you could try to calculate the gradients for each crop and accumulate the gradients over all patches of the original input.
After each backward call, the intermediate activations would be freed, so you might be able to reduce the memory usage with this approach.
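
A minimal sketch of the suggested per-crop backward (the names `model`, `criterion`, and `crop_targets` are hypothetical stand-ins for a model whose per-crop output a loss can be computed on, the loss function, and the per-crop targets; none of them are defined above):

def train_step_per_crop(model, optimizer, criterion, crop_tensor_list, crop_targets):
    # one backward per crop, so each crop's intermediate activations are freed
    # right away; gradients accumulate in .grad before a single optimizer step
    optimizer.zero_grad()
    total_loss = 0.0
    for crop, target in zip(crop_tensor_list, crop_targets):
        out = model(crop)  # assumed to return something `criterion` can score
        loss = criterion(out, target)
        loss.backward()
        total_loss += loss.item()
    optimizer.step()
    return total_loss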

Well, the problem is that the network I built above is the backbone of a detection model, so the concatenated feature maps will be used for downstream tasks (detection, segmentation, etc.). The loss cannot be calculated before those downstream heads run, so I don’t think I can get the gradients per crop and accumulate them. The desired behavior of the backbone is that, when forwarding the cropped patches of one original input, the backbone shares its parameters (that is to say, no matter which patch the network is processing, the backbone uses the same weights, a bit like an RNN). With such behavior, would the memory usage be reduced? And what kind of changes should I make to the code to achieve this behavior?
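
For reference, a small sketch showing that the refined model above already holds only one set of backbone weights, which is reused for every crop (so the parameter sharing itself needs no extra changes):

model = DarkNet53_LargeScale(crop_idx=4)
total_params = sum(p.numel() for p in model.parameters())
backbone_params = sum(p.numel() for p in model.darknet.parameters())
# the counts match: one shared set of backbone weights is applied to every crop
assert total_params == backbone_params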

Sorry to bother you again, but I wonder whether there is any method to reduce the intermediate variants without affecting backward propagation?

I’m not sure what “reduce intermediate variants” means, but you could generally save memory by lowering the batch size or by e.g. using torch.utils.checkpoint.
Based on your previous description, I don’t think you can save memory by splitting the input into different patches.
I might still be misunderstanding the explanation, so you could double-check this claim:
check which intermediate tensors and gradients would be needed with the “patch approach” and compare that to the use case where the complete input is passed.
If I’m not mistaken, both should yield the same memory usage, since you need the entire output for the downstream task.
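
A minimal sketch of how torch.utils.checkpoint could be applied to the backbone loop above (the control flow mirrors DarkNet53.forward; not tested against your exact model):

import torch
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(module_list, x):
    # checkpointed modules do not keep their intermediate activations in memory;
    # they recompute them during the backward pass instead
    route_layers = []
    for i, module in enumerate(module_list):
        if x.requires_grad:
            x = checkpoint(module, x)
        else:
            # run the first module(s) normally: with the default (reentrant)
            # implementation, checkpointing a segment whose inputs do not
            # require grad can prevent gradients from reaching its parameters
            x = module(x)
        if i in [6, 8, 17, 24, 32]:
            route_layers.append(x)
        if i == 19:
            x = torch.cat((x, route_layers[1]), 1)
        if i == 26:
            x = torch.cat((x, route_layers[0]), 1)
    return route_layers, x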

Perhaps I am not making myself clear; I will try my best to explain. The whole point is: by cropping the input image into patches, is there any possibility of reducing memory usage during training?
As you explained above, there is NO memory reduction, because I have to concatenate the feature maps for the downstream tasks, so passing a whole image or passing patches produces similar intermediate variables (intermediate activation tensors, etc.) for backward propagation, and thus no memory saving.
But (if I understand it correctly), if I could run the backward pass each time (calculate the loss and get the gradients), the intermediate activation tensors would be freed, reducing memory usage. This approach might not work in my case, though, as there is no suitable loss function for my backbone’s output feature maps.
So another idea of mine is: what if I modify the backbone into an RNN? Since the backbone takes the patches one by one as a serial input, it is really quite like an RNN. I am curious whether such a modification could reduce memory usage.

I don’t think so, as it would create the same computation graph in the end, just in a “sequential” way instead of multiple “parallel” forward graphs. Autograd would thus backprop through all “steps” and use the same amount of memory.
Again, I would recommend checking the needed activations for your ideas and comparing them to the baseline of using the complete image.

Thanks for your advice! I am wondering whether there are any tools, functions, etc. provided by PyTorch that could help check the intermediate activation tensors needed for backward propagation?