@ptrblck I followed your post to get the following.
However, because of the loop over groups, the solution is slow when there are multiple groups. How can I get the 5D output without the loop in the case where the mask_convolution function itself has groups > 1?
def mask_convolution(input, kernel_wt, padding= 0, stride= 1, \
        groups= 1, mask= None, mask_patches= False, debug= False):
    """Grouped 2-D convolution that materializes the per-input-channel
    contribution maps before summing them into the output channels.

    For each group, every (output-channel, input-channel) pair gets its own
    K x K convolution map (shape B x Cout_grp x Cin_grp x Hout x Wout); an
    optional mask is applied to those maps, and they are then summed over
    the input-channel axis, reproducing a standard grouped conv2d output.

    Args:
        input: tensor of shape (B, Cin, H, W).
        kernel_wt: weights of shape (Cout, Cin // groups, K, K) — the same
            layout F.conv2d expects for a grouped convolution.
        padding, stride: integer conv parameters (square kernel assumed).
        groups: number of convolution groups.
        mask: optional tensor multiplied onto the contribution maps when
            mask_patches is True.
            NOTE(review): assumes mask broadcasts against
            (B, 1, Cin_grp, Hout, Wout) after unsqueeze(1) — confirm with caller.
        mask_patches: whether to apply `mask` to the contribution maps.
        debug: unused; kept for interface compatibility.

    Returns:
        Tensor of shape (B, Cout, Hout, Wout); equal to
        F.conv2d(input, kernel_wt, padding=padding, stride=stride,
        groups=groups) when no mask is applied.
    """
    batch, in_channels = input.shape[0], input.shape[1]
    out_channels, in_channels_grp, K = kernel_wt.shape[0], kernel_wt.shape[1], kernel_wt.shape[2]
    out_channels_grp = out_channels // groups
    # Spatial output size for a square K x K kernel with integer
    # padding/stride — the same formula F.conv2d uses.
    Hout = (input.shape[2] + 2 * padding - K) // stride + 1
    Wout = (input.shape[3] + 2 * padding - K) // stride + 1

    # Allocate directly on the input's device/dtype (replaces the external
    # cast_to_cpu_cuda_tensor helper, which did exactly this).
    output = torch.zeros((batch, out_channels, Hout, Wout),
                         device=input.device, dtype=input.dtype)
    index_input = torch.arange(groups + 1) * (in_channels // groups)
    index_output = torch.arange(groups + 1) * out_channels_grp
    for i in range(groups):
        inp_group = input[:, index_input[i]:index_input[i + 1]]
        # BUG FIX: the kernel slice has shape (Cout_grp, Cin_grp, K, K); it
        # must be permuted to (Cin_grp, Cout_grp, K, K) BEFORE flattening so
        # that the depthwise conv below (groups=in_channels_grp) pairs input
        # channel c with the weight rows kernel_wt[:, c]. Without the
        # permute, each depthwise group received rows from mixed input
        # channels whenever Cin_grp > 1.
        # (Also uses the `kernel_wt` parameter — the original referenced an
        # undefined name `norm_kernel_wt`.)
        kernel_grp = kernel_wt[index_output[i]:index_output[i + 1]] \
            .permute(1, 0, 2, 3) \
            .reshape(in_channels_grp * out_channels_grp, 1, K, K)
        # B x (Cin_grp*Cout_grp) x Hout x Wout: one map per channel pair.
        output_grp = F.conv2d(inp_group, kernel_grp, padding=padding,
                              stride=stride, groups=in_channels_grp)
        # -> B x Cout_grp x Cin_grp x Hout x Wout. (The original's trailing
        # reshape to the same shape as the permute output was redundant.)
        output_grp = output_grp.view(batch, in_channels_grp, out_channels_grp,
                                     Hout, Wout).permute(0, 2, 1, 3, 4)
        if mask_patches:
            output_final = output_grp * (mask.unsqueeze(1))
        else:
            output_final = output_grp
        # Sum contributions over input channels -> B x Cout_grp x Hout x Wout.
        output_final = torch.sum(output_final, dim=2)
        # BUG FIX: the original assigned the undefined name
        # `output_div_by_energy` here, which raised NameError.
        output[:, index_output[i]:index_output[i + 1]] = output_final
    return output