Masking the intermediate 5D Conv2D output

@ptrblck I followed your post to get the following.

However, because of the loops, the solution is slow when we have multiple groups. How should I get the 5D output without the loops for the case when the mask_convolution function itself has groups > 1

def mask_convolution(input, kernel_wt, padding= 0, stride= 1, \
        groups= 1, mask= None, mask_patches= False, debug= False):

    output       = cast_to_cpu_cuda_tensor(torch.zeros((batch, out_channels, Hout, Wout)), reference_tensor= input)
    index_input  = torch.arange(groups+1)*(in_channels//groups)
    index_output = torch.arange(groups+1)*(out_channels//groups)

    for i in range(groups):
        inp_group             = input[:, index_input[i]:index_input[i+1]]
        norm_kernel_wt_interm = norm_kernel_wt[index_output[i]:index_output[i+1]]\
                                        .reshape(out_channels_grp * in_channels_grp, 1, K, K)
        
        # B x Cout*Cin x Hout x Wout
        output_grp            = F.conv2d(inp_group, norm_kernel_wt_interm, padding= padding, stride= stride, groups= in_channels_grp)
        output_grp            = output_grp.view(batch, in_channels_grp, out_channels_grp, Hout, Wout)\
                                        .permute(0, 2, 1, 3, 4)\
                                        .reshape(batch, out_channels_grp, in_channels_grp, Hout, Wout)
        if mask_patches:
             output_final  = output_grp * (mask.unsqueeze(1))
        else:
             output_final  = output_grp

        output_final = torch.sum(output_final, dim= 2) # B x Cout x Hout x Wout

        output[:, index_output[i]:index_output[i+1]] = output_div_by_energy

    return output