group = 4
# input = (batch_size, in_channel, width, height)
# weight = (out_channel, in_channel, kernel_size, kernel_size)
input_unfold = torch.nn.functional.unfold(input, (weight.size(2), weight.size(3)), padding=1, stride=1)
input_unfold_split = torch.split(input_unfold, group*weight.size(2)*weight.size(3), 1)
input_group = torch.stack(input_unfold_split).transpose(2, 3)
weight_view = weight.view(weight.size(0),-1).t()
weight_view_split = torch.split(weight_view, group*weight.size(2)*weight.size(3), 0)
weight_view_stack = torch.stack(weight_view_split)
weight_group = torch.unsqueeze(weight_view_stack, 1)
output = input_group.matmul(weight_group).transpose(2,3)
# output = (in_channel / group, batch_size, out_channel, width*height)

I want to compute a convolution in which the input channels are summed in groups of four (group = 4) instead of all together, so the output is (in_channel / group) times larger than the output of a normal conv2d.
Can anyone simplify or speed up my code while keeping the same functionality?
I have been trying to improve it for a long time without success, but I need it to be faster.
Also, the output may be reshaped to (in_channel / group, batch_size, out_channel, width, height); that result is acceptable too.
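
One possible direction, offered only as a sketch: the same partial sums seem to be expressible with the `groups` argument of `torch.nn.functional.conv2d`, by rearranging the weight so that each block of `group` input channels gets its own copy of all `out_channel` filters. The function name `partial_sum_conv2d` and its `padding` parameter are made up for this illustration, and the sketch assumes `in_channel` is divisible by `group`. It returns the 5-D layout (in_channel / group, batch_size, out_channel, followed by the two spatial dimensions). It has not been benchmarked against the unfold version.

```python
import torch
import torch.nn.functional as F

def partial_sum_conv2d(input, weight, group=4, padding=1):
    # Hypothetical helper, not part of PyTorch.
    # input:  (batch_size, in_channel, spatial_1, spatial_2)
    # weight: (out_channel, in_channel, kernel_h, kernel_w)
    batch_size, in_channel = input.shape[:2]
    out_channel, _, kh, kw = weight.shape
    assert in_channel % group == 0, "in_channel must be divisible by group"
    n_groups = in_channel // group

    # (O, C, kh, kw) -> (O, G, group, kh, kw) -> (G, O, group, kh, kw)
    # -> (G*O, group, kh, kw), so filter g*O + o sees only channel block g.
    w = (weight.reshape(out_channel, n_groups, group, kh, kw)
               .transpose(0, 1)
               .reshape(n_groups * out_channel, group, kh, kw))

    # Grouped convolution: each block of `group` input channels is
    # convolved and summed separately. Result: (B, G*O, out_1, out_2).
    out = F.conv2d(input, w, padding=padding, groups=n_groups)

    # Split the channel axis back into (G, O) and move G to the front.
    out = out.reshape(batch_size, n_groups, out_channel, out.size(2), out.size(3))
    return out.transpose(0, 1)  # (G, B, O, out_1, out_2)
```

Summing the result over its first dimension should reproduce a normal `conv2d` with the original weight, which is a quick way to check the rearrangement on your own data. If the flattened layout is preferred, `out.flatten(3)` gives (in_channel / group, batch_size, out_channel, width*height).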