Masking the intermediate 5D Conv2D output

Hi PyTorch Team,
I have an input tensor of shape (B, C_in, H, W), a convolution nn.Conv2d(C_in, C_out, K, padding=K//2) and a mask of shape (B, C_in, H, W). I want to multiply the intermediate 5D output of the convolution, of shape (B, C_out, C_in, H, W), by my mask before carrying out the final summation over the input channels.
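For reference, here is the operation written out. This is a sketch that assumes stride 1, no bias and PyTorch's cross-correlation convention. $W$ is the weight of shape (C_out, C_in, K, K), $\tilde{x}$ is the input zero-padded by $P = \lfloor K/2 \rfloor$ on each side, and $M$ is the mask:

$$
\begin{aligned}
Y_{b,o,c,h,w} &= \sum_{i=0}^{K-1}\sum_{j=0}^{K-1} W_{o,c,i,j}\,\tilde{x}_{b,c,h+i,w+j} \\
\mathrm{out}_{b,o,h,w} &= \sum_{c=0}^{C_{in}-1} Y_{b,o,c,h,w} \\
\mathrm{out}^{M}_{b,o,h,w} &= \sum_{c=0}^{C_{in}-1} M_{b,c,h,w}\,Y_{b,o,c,h,w}
  = \sum_{c=0}^{C_{in}-1}\sum_{i=0}^{K-1}\sum_{j=0}^{K-1} W_{o,c,i,j}\,\bigl(M_{b,c,h,w}\,\tilde{x}_{b,c,h+i,w+j}\bigr)
\end{aligned}
$$

$Y$ is the 5D tensor of shape (B, C_out, C_in, H, W). The last equality holds because $M_{b,c,h,w}$ does not depend on $o$, $i$ or $j$, so the mask can equally be applied to each input patch before the usual contraction.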

Here are pictures of the usual convolution and of the convolution with masking.
The usual convolution operates as follows:
[Figure: the usual convolution]

However, I want to mask the intermediate 5D output with the mask before the summation across the channels, as shown in the following updated figure:
[Figure: convolution with the intermediate 5D output masked before the summation across the channels]

Figures courtesy:

I first tried to obtain the intermediate 5D tensor of shape (B, C_out, C_in, H, W), but I could not extract it. Here is what I tried:
  • I tried a grouped convolution:

import torch
import torch.nn as nn
in_channels = 5
conv = nn.Conv2d(in_channels, 20, 3, groups=in_channels, padding=1)
x = torch.randn(1, in_channels, 64, 64)
output = conv(x)
# torch.Size([1, 20, 64, 64])
  • I also tried unfolding the input:
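inp = torch.randn(1, in_channels, 10, 12)
w   = torch.randn(2, in_channels, 3, 3)
inp_unf = torch.nn.functional.unfold(inp, (3, 3))  # torch.Size([1, 45, 80])
# The matmul sums over Cin*K*K, so the per-channel terms are already gone here
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
# out_unf: torch.Size([1, 2, 80])
out = out_unf.view(1, 2, 8, 10)
# out: torch.Size([1, 2, 8, 10])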
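One way to keep the unfold approach without losing the C_in axis (a sketch, assuming padding=K//2, stride 1 and no bias): split the unfolded rows into (C_in, K*K) and contract only the K*K window with torch.einsum instead of the full matmul. The sizes below are arbitrary choices for illustration. Because the mask depends only on (b, c, h, w), the last lines also show that the same masked result can be obtained by scaling the unfolded patches, without building the 5D tensor at all.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C_in, C_out, K, H, W = 1, 5, 2, 3, 10, 12
inp = torch.randn(B, C_in, H, W)
w = torch.randn(C_out, C_in, K, K)
mask = torch.rand(B, C_in, H, W)          # same spatial size as the output (padding = K//2)

# B x (C_in*K*K) x L with L = H*W; rows are ordered (channel, kh, kw)
cols = F.unfold(inp, (K, K), padding=K // 2)
cols = cols.view(B, C_in, K * K, H * W)   # separate C_in from the K*K window

# Contract only the K*K window and keep C_in: B x C_out x C_in x L
out5d = torch.einsum("bckl,ock->bocl", cols, w.view(C_out, C_in, K * K))
out5d = out5d.reshape(B, C_out, C_in, H, W)

# Mask the per-channel terms, then do the channel summation
masked_out = (out5d * mask.unsqueeze(1)).sum(dim=2)   # B x C_out x H x W

# Same result without the 5D tensor: scale each unfolded patch by the mask
# of its (b, c, output position), then do the full contraction
cols_masked = cols * mask.view(B, C_in, 1, H * W)
masked_out_2 = torch.einsum("bckl,ock->bol", cols_masked, w.view(C_out, C_in, K * K))
masked_out_2 = masked_out_2.reshape(B, C_out, H, W)
```

Keep in mind that out5d stores C_out * C_in values per pixel, so it grows quickly for wide layers; the patch-masking variant only needs the unfolded input.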

I could not find a working solution on the forums, so I decided to post here. It would be great if you could answer this question. If it is a duplicate, would you mind pointing me to the correct answer?

This post, which uses the grouped conv approach and reduces the output manually afterwards, might be helpful.


@ptrblck Thank you for the pointer. I implemented the masked convolution using the unfold function, as shown below. However, as you might have guessed, this code does not work if the convolution uses a groups argument other than 1.

I have a related question: can the following unfold-based code be extended to conv2d with groups > 1 as well?

import torch
import torch.nn.functional as F

def mask_convolution(input, kernel_wt, padding=0, stride=1, mask=None, mask_patches=False, debug=False):
    if debug:
        print("\nInside function ...")

    # kernel_wt with shape Cout x Cin x K x K
    out_channels, in_channels, K, _ = kernel_wt.shape
    batch, _, H, W                  = input.shape
    # Standard conv output size (dilation 1); also correct for stride > 1
    Hout = (H + 2*padding - K)//stride + 1
    Wout = (W + 2*padding - K)//stride + 1

    norm_kernel_wt        = kernel_wt

    inp_unf               = F.unfold(input, (K, K), padding=padding, stride=stride)  # B x Cin*K*K x Hout*Wout
    inp_unf_interm        = inp_unf.transpose(1, 2).unsqueeze(1)  # B x 1 x Hout*Wout x Cin*K*K

    norm_kernel_wt_interm = norm_kernel_wt.view(out_channels, -1).unsqueeze(0).unsqueeze(2)  # 1 x Cout x 1 x Cin*K*K

    out_unf = torch.mul(inp_unf_interm, norm_kernel_wt_interm).reshape(batch, out_channels, inp_unf_interm.shape[2], in_channels, -1)  # B x Cout x Hout*Wout x Cin x K*K
    output  = torch.sum(out_unf, dim=4).transpose(2, 3).reshape(batch, out_channels, in_channels, Hout, Wout)  # B x Cout x Cin x Hout x Wout
    if debug:
        print("out_unf.shape before sum", out_unf.shape)
        print("output.shape before sum", output.shape)

    if mask_patches:
        # mask: B x Cin x Hout x Wout, broadcast over the Cout axis
        output_final = output * mask.unsqueeze(1)
    else:
        output_final = output

    output_final = torch.sum(output_final, dim=2)  # B x Cout x Hout x Wout
    if debug:
        print("output_final shape", output_final.shape)

    return output_final
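The unfold version can probably be extended to groups > 1 without a Python loop by adding a group axis: view the unfolded rows as (groups, C_in/groups, K*K) and the weight as (groups, C_out/groups, C_in/groups, K*K), then contract only the window. This is a sketch that assumes dilation 1 and no bias; masked_grouped_conv_5d is a hypothetical name, and the mask is assumed to have the output's spatial size, as in the function above.

```python
import torch
import torch.nn.functional as F

def masked_grouped_conv_5d(x, weight, groups=1, padding=0, stride=1, mask=None):
    """Hypothetical helper: pre-summation tensor of F.conv2d(x, weight, groups=groups).

    Returns B x C_out x (C_in // groups) x H_out x W_out; summing dim=2 gives the conv output.
    mask, if given, has shape B x C_in x H_out x W_out.
    """
    B, C_in, H, W = x.shape
    C_out, C_in_g, K, _ = weight.shape
    G = groups
    C_out_g = C_out // G
    H_out = (H + 2 * padding - K) // stride + 1
    W_out = (W + 2 * padding - K) // stride + 1

    # B x (C_in*K*K) x L, rows ordered (channel, kh, kw); split channels into groups
    cols = F.unfold(x, (K, K), padding=padding, stride=stride)
    cols = cols.view(B, G, C_in_g, K * K, H_out * W_out)
    # Filters of group g only see the input channels of group g
    w = weight.view(G, C_out_g, C_in_g, K * K)

    # Contract the K*K window only: B x G x C_out_g x C_in_g x L
    out = torch.einsum("bgckl,gock->bgocl", cols, w)
    if mask is not None:
        # B x C_in x H_out x W_out -> B x G x 1 x C_in_g x L, broadcast over C_out_g
        out = out * mask.reshape(B, G, 1, C_in_g, H_out * W_out)
    return out.reshape(B, C_out, C_in_g, H_out, W_out)


torch.manual_seed(0)
x = torch.randn(2, 6, 9, 11)
weight = torch.randn(6, 2, 3, 3)          # groups=3, so C_in // groups = 2
out5d = masked_grouped_conv_5d(x, weight, groups=3, padding=1, stride=2)
out = out5d.sum(dim=2)                    # should match F.conv2d(x, weight, groups=3, padding=1, stride=2)
```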

The idea of the linked post is to use a grouped convolution to avoid the reduction over the in_channels dimension, so that you can manipulate the activations manually before applying the sum.
But yes, unfold should also work.
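A minimal sketch of the grouped-conv idea for a regular (groups=1) weight, assuming stride 1, padding K//2 and no bias: rearrange the weight so that a convolution with groups=C_in convolves every input channel with all C_out filter slices that belong to it, then view the result as 5D. The shapes are arbitrary, and this is one reading of the idea, not the code from the linked post.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C_in, C_out, K, H, W = 2, 5, 4, 3, 16, 16
x = torch.randn(B, C_in, H, W)
weight = torch.randn(C_out, C_in, K, K)        # regular conv weight (groups=1)
mask = torch.rand(B, C_in, H, W)

# Filter c*C_out + o must hold weight[o, c], so transpose before flattening
w_dw = weight.transpose(0, 1).reshape(C_in * C_out, 1, K, K)
# groups=C_in: each input channel is convolved with its own C_out filters,
# so no summation over the input channels happens inside the conv
y = F.conv2d(x, w_dw, padding=K // 2, groups=C_in)            # B x C_in*C_out x H x W
out5d = y.view(B, C_in, C_out, H, W).permute(0, 2, 1, 3, 4)   # B x C_out x C_in x H x W

# Apply the mask to the per-channel terms, then reduce manually
masked_out = (out5d * mask.unsqueeze(1)).sum(dim=2)           # B x C_out x H x W
```

Without the mask, out5d.sum(dim=2) reproduces F.conv2d(x, weight, padding=K // 2).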


@ptrblck I followed your post and wrote the following.

However, because of the loop over the groups, the solution is slow when there are multiple groups. How can I get the 5D output without the loop when the mask_convolution function itself has groups > 1?

import torch
import torch.nn.functional as F

def mask_convolution(input, kernel_wt, padding=0, stride=1,
                     groups=1, mask=None, mask_patches=False, debug=False):
    # kernel_wt: Cout x Cin/groups x K x K (same layout as for F.conv2d)
    out_channels, in_channels_grp, K, _ = kernel_wt.shape
    batch, in_channels, H, W = input.shape
    out_channels_grp = out_channels // groups
    Hout = (H + 2*padding - K)//stride + 1
    Wout = (W + 2*padding - K)//stride + 1

    output = torch.zeros((batch, out_channels, Hout, Wout), device=input.device, dtype=input.dtype)

    for i in range(groups):
        in_sl  = slice(i*in_channels_grp, (i+1)*in_channels_grp)
        out_sl = slice(i*out_channels_grp, (i+1)*out_channels_grp)
        inp_group = input[:, in_sl]
        # Cin_grp*Cout_grp x 1 x K x K; transpose first so that filter c*Cout_grp + o is kernel_wt[o, c]
        norm_kernel_wt_interm = kernel_wt[out_sl].transpose(0, 1)\
                                         .reshape(in_channels_grp * out_channels_grp, 1, K, K)
        # B x Cin_grp*Cout_grp x Hout x Wout
        output_grp = F.conv2d(inp_group, norm_kernel_wt_interm, padding=padding, stride=stride, groups=in_channels_grp)
        # B x Cout_grp x Cin_grp x Hout x Wout
        output_grp = output_grp.view(batch, in_channels_grp, out_channels_grp, Hout, Wout)\
                               .permute(0, 2, 1, 3, 4)
        if mask_patches:
            # use only this group's channels of the B x Cin x Hout x Wout mask
            output_grp = output_grp * mask[:, in_sl].unsqueeze(1)

        output_final = torch.sum(output_grp, dim=2)  # B x Cout_grp x Hout x Wout

        output[:, out_sl] = output_final

    return output

I’m not sure why the groups loop is needed as my code snippet doesn’t use it.

Thank you for your reply.

I wanted to get the intermediate 5D output for the original convolution with groups>1.

For example, similar to your post, consider an input activation of shape [batch_size=1, channels=6, height=24, width=24] and a standard conv layer with a weight tensor of shape [out_channels=6, in_channels=2, height=3, width=3].

We want the intermediate 5D output of F.conv2d(input, weight, groups=3, stride=1, bias=None, padding=1). So I kept the loop over the groups to iterate over the three slices of the input and output, in line with your figure. With an identity mask (all ones), the output of my function is the same as F.conv2d(input, weight, groups=3, stride=1, padding=1). However, my function is orders of magnitude slower because of the loop.

Can we get the intermediate 5D output with groups=3 without using a loop?
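A loop-free variant seems possible (a sketch, assuming dilation 1 and no bias) by reshaping the grouped weight (C_out, C_in/groups, K, K) so that a single F.conv2d with groups=C_in gives every input channel the C_out/groups filters of its own group. The result can then be viewed as B x groups x C_in/groups x C_out/groups x H_out x W_out, and the last two channel axes swapped. grouped_conv_5d is a hypothetical helper name; the mask is assumed to have shape B x C_in x H_out x W_out, and the example uses the 6-channel, groups=3 setup described above.

```python
import torch
import torch.nn.functional as F

def grouped_conv_5d(x, weight, groups, padding=0, stride=1, mask=None):
    """Hypothetical helper: loop-free 5D intermediate of F.conv2d(x, weight, groups=groups).

    Returns B x C_out x (C_in // groups) x H_out x W_out.
    """
    B, C_in, _, _ = x.shape
    C_out, C_in_g, K, _ = weight.shape
    G, C_out_g = groups, C_out // groups
    # G x C_out_g x C_in_g x K x K -> G x C_in_g x C_out_g x K x K, then flatten so that
    # filter (g*C_in_g + c)*C_out_g + o holds weight[g*C_out_g + o, c]
    w = weight.view(G, C_out_g, C_in_g, K, K).transpose(1, 2).reshape(C_in * C_out_g, 1, K, K)
    # groups=C_in: each input channel is convolved only with the filters of its own group
    y = F.conv2d(x, w, padding=padding, stride=stride, groups=C_in)
    H_out, W_out = y.shape[-2:]
    # B x G x C_in_g x C_out_g x H_out x W_out -> B x G x C_out_g x C_in_g x H_out x W_out
    y = y.view(B, G, C_in_g, C_out_g, H_out, W_out).transpose(2, 3)
    if mask is not None:
        # mask: B x C_in x H_out x W_out, broadcast over the C_out_g axis
        y = y * mask.reshape(B, G, 1, C_in_g, H_out, W_out)
    return y.reshape(B, C_out, C_in_g, H_out, W_out)


torch.manual_seed(0)
x = torch.randn(1, 6, 24, 24)
weight = torch.randn(6, 2, 3, 3)
mask = torch.rand(1, 6, 24, 24)
out5d = grouped_conv_5d(x, weight, groups=3, padding=1, mask=mask)   # 1 x 6 x 2 x 24 x 24
out = out5d.sum(dim=2)                                                # 1 x 6 x 24 x 24
```

With mask=None, summing over dim=2 reproduces F.conv2d(x, weight, groups=3, padding=1). The 5D tensor is C_in/groups times larger than the regular output, so for large layers memory may become the limit rather than the loop.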