Attention in ConvGRU

Hello, the output of my decoder GRU has size (4, 1, 256, 4, 4) and the hidden states from my encoder have size (4, 5, 256, 4, 4), where the order of the dimensions is (batch, time_step, image_depth, height, width). I use the following piece of code to compute the attention:

attn_energies = torch.sum(decoder_output * encoder_output, dim=2)  # broadcasts over time -> (4, 5, 4, 4)
attn_energies = attn_energies.permute(0, 3, 2, 1)                  # (4, 4, 4, 5)
final_output = F.softmax(attn_energies, dim=1).unsqueeze(1)        # (4, 1, 4, 4, 5)

The sum gives a new tensor of size (4, 5, 4, 4); the permute and unsqueeze then rearrange it as annotated above. Is this correct? I am following the official PyTorch Chatbot tutorial: https://pytorch.org/tutorials/beginner/chatbot_tutorial.html#decoder

How do I apply the next step, namely the batch matrix multiplication (torch.bmm in the tutorial)? It is a bit confusing, because here I have image feature maps rather than word embeddings.
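To make the question concrete, here is my best guess at the feature-map equivalent of the tutorial's bmm step: normalise the energies over the time dimension and take a weighted sum of the encoder states (this is an assumption on my part, not code from the tutorial):

import torch
import torch.nn.functional as F

decoder_output = torch.randn(4, 1, 256, 4, 4)  # (batch, time_step, image_depth, height, width)
encoder_output = torch.randn(4, 5, 256, 4, 4)

# Score each encoder time step at every spatial location.
attn_energies = torch.sum(decoder_output * encoder_output, dim=2)  # (4, 5, 4, 4)
# Normalise over time so the five encoder steps compete with each other.
attn_weights = F.softmax(attn_energies, dim=1)                     # (4, 5, 4, 4)
# A weighted sum over time stands in for the tutorial's torch.bmm.
context = torch.sum(attn_weights.unsqueeze(2) * encoder_output, dim=1, keepdim=True)
# context has size (4, 1, 256, 4, 4)

Is this a reasonable translation?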

Edit: This is the method I use to calculate attention. Please tell me whether this is right:

# hidden[-1] has size (4, 1, 256, 4, 4); all_hidden has size (4, 5, 256, 4, 4)
attn_weights = self.attn(hidden[-1], all_hidden)
# attn_weights has size (4, 1, 256, 4, 4)
# self.relu(p_tmp) is the output of all the GRU cells in the encoder stage, size (4, 256, 4, 4)
# self.relu(p_tmp).unsqueeze(1) has size (4, 1, 256, 4, 4)
context = attn_weights.matmul(self.relu(p_tmp).unsqueeze(1))
# context has size (4, 1, 256, 4, 4)
concat_input = torch.cat((hidden[-1], context), 1)
# concat_input has size (4, 2, 256, 4, 4)
(B, N, C, H, W) = concat_input.shape
concat_input = concat_input.view(B * N, C, H, W)
# concat_input now has size (8, 256, 4, 4)
concat_output = torch.tanh(self.concat(concat_input))
# concat_output has size (8, 256, 4, 4)
hidden = self.out(concat_output)
# hidden has size (8, 256, 4, 4)
hidden = hidden.view(B, N, C, H, W)
# hidden has size (4, 2, 256, 4, 4)
hidden = hidden.permute(0, 2, 1, 3, 4)
# hidden has size (4, 256, 2, 4, 4)
hidden = F.avg_pool3d(hidden, (self.last_duration, 1, 1), stride=(1, 1, 1))
# the pooling collapses the time dimension: (4, 256, 1, 4, 4) with self.last_duration = 2
hidden = hidden.permute(0, 2, 1, 3, 4)
# hidden has size (4, 1, 256, 4, 4)
hidden = hidden[:, -1, :]
# final size of hidden is (4, 256, 4, 4)

where:

def attn(self, hidden, encoder_outputs):
    # hidden: (4, 1, 256, 4, 4); encoder_outputs: (4, 5, 256, 4, 4)
    attn_energies = torch.sum(hidden * encoder_outputs, dim=1)  # sum over time -> (4, 256, 4, 4)
    first = F.softmax(attn_energies, dim=2)   # softmax over height
    second = F.softmax(first, dim=3)          # softmax over width of the already-normalised map
    return second.unsqueeze(1)                # (4, 1, 256, 4, 4)
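As an aside, since I am unsure about applying two softmaxes in a row, I also sketched a variant with a single softmax over the flattened H * W grid; attn_spatial is a hypothetical name of mine, not something from the tutorial:

def attn_spatial(self, hidden, encoder_outputs):
    # hidden: (4, 1, 256, 4, 4); encoder_outputs: (4, 5, 256, 4, 4)
    energies = torch.sum(hidden * encoder_outputs, dim=1)   # (4, 256, 4, 4)
    B, C, H, W = energies.shape
    # One distribution over all 16 spatial positions per channel.
    weights = F.softmax(energies.view(B, C, H * W), dim=2)
    return weights.view(B, C, H, W).unsqueeze(1)            # (4, 1, 256, 4, 4)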
self.concat = nn.Sequential(
    nn.Conv2d(self.param['feature_size'], self.param['feature_size'], kernel_size=1, padding=0),
    nn.ReLU(inplace=True),
)
self.out = nn.Sequential(
    nn.Conv2d(self.param['feature_size'], self.param['feature_size'], kernel_size=1, padding=0),
    nn.ReLU(inplace=True),
)
# self.param['feature_size'] is 256
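To sanity-check the shapes listed above, here is a self-contained version of the whole pipeline with dummy tensors. self.param['feature_size'] = 256 comes from the snippet above; last_duration = 2 is inferred from the shape comments, so treat it as an assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, C, H, W = 4, 5, 256, 4, 4
all_hidden = torch.randn(B, T, C, H, W)   # encoder hidden states
dec_hidden = torch.randn(B, 1, C, H, W)   # hidden[-1] from the decoder
p_tmp = torch.randn(B, C, H, W)           # output of the encoder GRU cells

concat = nn.Sequential(nn.Conv2d(C, C, kernel_size=1, padding=0), nn.ReLU(inplace=True))
out = nn.Sequential(nn.Conv2d(C, C, kernel_size=1, padding=0), nn.ReLU(inplace=True))

attn_energies = torch.sum(dec_hidden * all_hidden, dim=1)                       # (4, 256, 4, 4)
attn_weights = F.softmax(F.softmax(attn_energies, dim=2), dim=3).unsqueeze(1)   # (4, 1, 256, 4, 4)
context = attn_weights.matmul(F.relu(p_tmp).unsqueeze(1))                       # (4, 1, 256, 4, 4)
concat_input = torch.cat((dec_hidden, context), 1).view(B * 2, C, H, W)         # (8, 256, 4, 4)
hidden = out(torch.tanh(concat(concat_input))).view(B, 2, C, H, W)              # (4, 2, 256, 4, 4)
hidden = hidden.permute(0, 2, 1, 3, 4)                                          # (4, 256, 2, 4, 4)
hidden = F.avg_pool3d(hidden, (2, 1, 1), stride=(1, 1, 1))                      # (4, 256, 1, 4, 4)
hidden = hidden.permute(0, 2, 1, 3, 4)[:, -1, :]                                # (4, 256, 4, 4)
print(hidden.shape)  # torch.Size([4, 256, 4, 4])

This reproduces the commented shapes end to end.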