Hello there. I am trying to build a network that processes high-resolution images, so I decided to crop each input image into smaller tiles and then concatenate the corresponding feature maps in the spatial order of the tiles. Intuitively, this should reduce video (GPU) memory usage, but in my experiments it does not seem to work. Here is my code; could anyone help me figure out why?

Original Network:

```
class DarkNet53(nn.Module):
    """Backbone that runs the modules from ``build_backbone_modules`` in order.

    While iterating, the outputs of selected modules are collected as
    "route layers", and two of them are concatenated back into the running
    tensor along the channel dimension after later modules.
    """

    def __init__(self):
        super().__init__()
        self.module_list = build_backbone_modules()

    def forward(self, x):
        """Run every module sequentially.

        Args:
            x: input tensor fed to the first module.

        Returns:
            tuple: ``(route_layers, x)`` where ``route_layers`` is the list of
            outputs collected after modules 6, 8, 17, 24 and 32 (in that
            order) and ``x`` is the output of the last module.
        """
        # Module indices whose output is stored for later use / returned.
        route_indices = (6, 8, 17, 24, 32)
        # After module <key>, concatenate route_layers[<value>] on dim 1:
        # module 19 is joined with the output of module 8, module 26 with
        # the output of module 6.
        concat_source = {19: 1, 26: 0}

        route_layers = []
        for i, module in enumerate(self.module_list):
            x = module(x)
            if i in route_indices:
                route_layers.append(x)
            elif i in concat_source:
                x = torch.cat((x, route_layers[concat_source[i]]), 1)
        return route_layers, x
```

Refined Network:

```
class DarkNet53_LargeScale(nn.Module):
    """Run ``DarkNet53`` tile by tile and stitch the per-tile outputs together.

    The input is split into a ``crop_idx x crop_idx`` grid of tiles (row-major:
    left to right, then top to bottom). Each tile is passed through the same
    ``DarkNet53`` instance, and every returned feature map is reassembled in
    the spatial order of the tiles.

    NOTE(review): when gradients are enabled, autograd keeps the intermediate
    activations of every tile alive until backward, so tiling by itself is not
    expected to lower peak training memory. Consider ``torch.no_grad()`` for
    inference or ``torch.utils.checkpoint`` for training.

    Args:
        crop_idx (int): number of tiles along each spatial axis. Must be >= 1.

    Raises:
        ValueError: if ``crop_idx`` is smaller than 1.
    """

    def __init__(self, crop_idx=4):
        super().__init__()
        if crop_idx < 1:
            raise ValueError(f"crop_idx must be >= 1, got {crop_idx}")
        self.darknet = DarkNet53()
        self.crop_idx = crop_idx

    def get_local_input(self, x):
        """Split ``x`` into ``crop_idx * crop_idx`` tiles.

        Tiles are taken over the last two dimensions, in row-major order
        (left to right, then top to bottom). The tile size is
        ``x.shape[2] // crop_idx`` by ``x.shape[3] // crop_idx``; if a
        dimension is not divisible by ``crop_idx`` the trailing remainder
        rows/columns are not covered by any tile.

        Returns:
            tuple: ``(crop_tensor_list, h_index_list, w_index_list)`` where the
            index lists hold the ``crop_idx + 1`` tile boundaries per axis.
        """
        k = self.crop_idx
        tile_h = x.shape[2] // k
        tile_w = x.shape[3] // k
        h_index_list = [tile_h * i for i in range(k + 1)]
        w_index_list = [tile_w * j for j in range(k + 1)]
        crop_tensor_list = [
            x[..., h_index_list[i]:h_index_list[i + 1], w_index_list[j]:w_index_list[j + 1]]
            for i in range(k)
            for j in range(k)
        ]
        return crop_tensor_list, h_index_list, w_index_list

    def forward_backbone(self, crop_tensor_list):
        """Run the backbone on every tile.

        Returns:
            tuple: ``(route_layers_list, x_list)``; entry ``i`` of each list is
            what ``self.darknet`` returned for tile ``i``.
        """
        route_layers_list = []
        x_list = []
        for tile in crop_tensor_list:
            route_layers, x = self.darknet(tile)
            route_layers_list.append(route_layers)
            x_list.append(x)
        return route_layers_list, x_list

    def _stitch(self, tiles):
        """Reassemble row-major ``tiles`` into one map (width first, then height)."""
        k = self.crop_idx
        rows = [torch.cat(tiles[r * k:(r + 1) * k], dim=-1) for r in range(k)]
        return torch.cat(rows, dim=-2)

    def forward_post_process(self, route_layers_list, x_list):
        """Stitch per-tile outputs back together in the tiles' spatial order.

        ``torch.cat`` is used so the results keep the device and dtype of the
        tile outputs (a fresh ``torch.zeros`` buffer would be created on the
        CPU in the default dtype).

        Args:
            route_layers_list: one list of route tensors per tile.
            x_list: one final output tensor per tile.

        Returns:
            tuple: ``(route_layers, x_out)`` with one stitched tensor per
            route level and the stitched final output.

        Raises:
            ValueError: if the number of tiles is not ``crop_idx ** 2``.
        """
        expected = self.crop_idx * self.crop_idx
        if len(x_list) != expected or len(route_layers_list) != expected:
            raise ValueError(
                f"expected {expected} tiles, got {len(x_list)} outputs and "
                f"{len(route_layers_list)} route-layer lists"
            )
        x_out = self._stitch(list(x_list))
        num_levels = len(route_layers_list[0])
        route_layers_out = [
            self._stitch([tile_routes[level] for tile_routes in route_layers_list])
            for level in range(num_levels)
        ]
        return route_layers_out, x_out

    def forward_original(self, x):
        """Run the wrapped backbone on the full, untiled input.

        The modules live on ``self.darknet``; this class has no
        ``module_list`` of its own, so the call is delegated.
        """
        return self.darknet(x)

    def forward(self, x):
        """Tile ``x``, run the backbone per tile, and stitch the results.

        Returns:
            tuple: ``(route_layers, x)`` as produced by
            ``forward_post_process``.
        """
        crop_tensor_list, _h_index_list, _w_index_list = self.get_local_input(x)
        route_layers_list, x_list = self.forward_backbone(crop_tensor_list)
        route_layers, x = self.forward_post_process(route_layers_list, x_list)
        return route_layers, x
```

Here is the code that builds the network, in case you need it:

```
def build_backbone_modules():
    """Build the DarkNet53 + YOLOv3 layer modules.

    Takes no arguments. The trailing ``# n`` comments give each module's
    index in the returned list; the forward pass that consumes this list
    refers to modules by those indices.

    Returns:
        nn.ModuleList: the 33 modules, in execution order.
    """
    layers = [
        # DarkNet53
        add_conv(in_ch=3, out_ch=32, ksize=3, stride=1),  # 0
        add_conv(in_ch=32, out_ch=64, ksize=3, stride=2),  # 1
        resblock(ch=64),  # 2
        add_conv(in_ch=64, out_ch=128, ksize=3, stride=2),  # 3
        resblock(ch=128, nblocks=2),  # 4
        add_conv(in_ch=128, out_ch=256, ksize=3, stride=2),  # 5
        resblock(ch=256, nblocks=8),  # 6, shortcut 1 from here
        add_conv(in_ch=256, out_ch=512, ksize=3, stride=2),  # 7
        resblock(ch=512, nblocks=8),  # 8, shortcut 2 from here
        add_conv(in_ch=512, out_ch=1024, ksize=3, stride=2),  # 9
        resblock(ch=1024, nblocks=4),  # 10
        # YOLOv3
        resblock(ch=1024, nblocks=1, shortcut=False),  # 11
        add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1),  # 12
        # SPP Layer
        SPPLayer(),  # 13
        add_conv(in_ch=2048, out_ch=512, ksize=1, stride=1),  # 14
        add_conv(in_ch=512, out_ch=1024, ksize=3, stride=1),  # 15
        DropBlock(block_size=1, keep_prob=1),  # 16
        add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1),  # 17
        # 1st yolo branch
        add_conv(in_ch=512, out_ch=256, ksize=1, stride=1),  # 18
        upsample(scale_factor=2, mode='nearest'),  # 19
        # in_ch=768: presumably 256 (module 19) + 512 (shortcut 2) after the
        # channel concat done in the forward pass - TODO confirm.
        add_conv(in_ch=768, out_ch=256, ksize=1, stride=1),  # 20
        add_conv(in_ch=256, out_ch=512, ksize=3, stride=1),  # 21
        DropBlock(block_size=1, keep_prob=1),  # 22
        resblock(ch=512, nblocks=1, shortcut=False),  # 23
        add_conv(in_ch=512, out_ch=256, ksize=1, stride=1),  # 24
        # 2nd yolo branch
        add_conv(in_ch=256, out_ch=128, ksize=1, stride=1),  # 25
        upsample(scale_factor=2, mode='nearest'),  # 26
        # in_ch=384: presumably 128 (module 26) + 256 (shortcut 1) after the
        # channel concat done in the forward pass - TODO confirm.
        add_conv(in_ch=384, out_ch=128, ksize=1, stride=1),  # 27
        add_conv(in_ch=128, out_ch=256, ksize=3, stride=1),  # 28
        DropBlock(block_size=1, keep_prob=1),  # 29
        resblock(ch=256, nblocks=1, shortcut=False),  # 30
        add_conv(in_ch=256, out_ch=128, ksize=1, stride=1),  # 31
        add_conv(in_ch=128, out_ch=256, ksize=3, stride=1),  # 32
    ]
    return nn.ModuleList(layers)
```

Could anyone help me to figure it out?