How to efficiently normalize groups of elements in a tensor

Context:

I am trying to replicate Hinton’s “Matrix capsules with EM routing” (https://openreview.net/forum?id=HJWLfGWRb).

At some point, convolution-like operations are performed (in the sense that an output tensor is connected to an input tensor, and each element of the output tensor is influenced only by the input elements contained in a 2D window of size K×K).

Input
Tensor x of shape (w_in, w_in)
where

  • w_in=14

Intermediate tensor mapped to the input
Tensor x_mapped of shape (w_out, w_out, K, K)
where

  • K=3 is the convolution kernel size
  • w_out=6, resulting from convolution with stride=2

Summing over dimensions 2 and 3 (both of size K) means summing over the input elements connected to the output element whose location is given by dimensions 0 and 1.
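
In code, that is just a reduction over the last two dimensions (per_output_sums is only an illustrative name):

# sum over the K x K window feeding each output location -> shape (w_out, w_out)
per_output_sums = x_mapped.sum(dim=(2, 3))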

Question:

How can I efficiently normalize (to 1) groups of elements in x_mapped, based on their location in the input tensor x?

For example:
x_mapped(0,0,2,2)
x_mapped(1,0,0,2)
x_mapped(0,1,2,0)
x_mapped(1,1,0,0)
are all connected to x(2,2) (the mapping is i_out*stride + K_index = i_in; for i_in = 2, stride = 2 and K = 3 the only solutions are (i_out, K_index) = (0, 2) and (1, 0), and likewise for j, hence these 2 × 2 = 4 elements). For that reason, I would like the sum of those 4 elements to be 1.

And I would like to do that for all the groups of elements in x_mapped that are “connected” to the same element in x.

I can figure out how to do it by:

  1. Building a dictionary with the input location as key and the list of connected x_mapped elements as value
  2. Looping over the dictionary, summing the elements in the list for a given input location and dividing them by that sum

but that seems really inefficient to me.

I solved this in the following way:

  1. Creating a dictionary with a 2-tuple (the coordinates in x) as key and a list of index tuples into x_mapped as value.
  2. A single loop over the dictionary, zipping the index tuples of each entry into advanced-indexing lists, then normalizing.

Here is the code:

from collections import defaultdict
import torch

ho = 6       # output height (w_out)
wo = 6       # output width (w_out)
stride = 2
K = 3        # kernel size

# dictionary: input location (i_in, j_in) -> list of connected x_mapped index tuples
d = defaultdict(list)

# dummy data of shape (w_out, w_out, K, K)
x_mapped = torch.arange(0, ho * wo * K * K, dtype=torch.double).view(ho, wo, K, K)

for i_out in range(ho):
    for j_out in range(wo):
        for K_i in range(K):
            for K_j in range(K):
                i_in = i_out * stride + K_i
                j_in = j_out * stride + K_j

                d[(i_in, j_in)].append((i_out, j_out, K_i, K_j))

# for each input location, gather the connected elements with advanced indexing
# and divide them by their sum so that every group sums to 1
for _, value in d.items():
    ho_list, wo_list, K_i_list, K_j_list = zip(*value)
    group = x_mapped[ho_list, wo_list, K_i_list, K_j_list]
    x_mapped[ho_list, wo_list, K_i_list, K_j_list] = group / torch.sum(group)
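
As a quick sanity check (reusing the dictionary built above), every group should now sum to 1; note that with this dummy data the group of x(0,0) contains only x_mapped[0,0,0,0], whose value was 0, so that single group normalizes to 0/0 = NaN.

# each printed sum should be 1.0 (except the single-element group at (0, 0),
# which is NaN here because its only value was 0 before normalization)
for key, value in d.items():
    ho_list, wo_list, K_i_list, K_j_list = zip(*value)
    print(key, torch.sum(x_mapped[ho_list, wo_list, K_i_list, K_j_list]).item())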

The clean solution I found was to use torch.nn.functional.fold(), which sums overlapping blocks back into the input space in a single vectorized call.
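
Here is a minimal sketch of that fold-based approach, assuming x_mapped is laid out as above (the dummy data starts at 1 so that no group sum is zero): torch.nn.functional.fold() sums all overlapping kernel windows back into input space, which yields exactly the per-input-location group sums, and torch.nn.functional.unfold() reads those sums back into the (w_out, w_out, K, K) layout so that the normalization becomes a single element-wise division.

import torch
import torch.nn.functional as F

ho, wo, stride, K = 6, 6, 2, 3
w_in = 14  # input height/width from the question

# dummy data starting at 1 so that no group sums to zero
x_mapped = torch.arange(1, ho * wo * K * K + 1, dtype=torch.double).view(ho, wo, K, K)

# rearrange to the (N, K*K, L) layout used by fold/unfold, with L = ho*wo blocks
cols = x_mapped.permute(2, 3, 0, 1).reshape(1, K * K, ho * wo)

# fold() sums overlapping blocks: sums[0, 0, i_in, j_in] is the sum of all
# x_mapped elements connected to the input location (i_in, j_in)
sums = F.fold(cols, output_size=(w_in, w_in), kernel_size=K, stride=stride)

# unfold() reads those per-input-location sums back into block layout, so each
# element of cols is paired with the sum of its own group
sums_cols = F.unfold(sums, kernel_size=K, stride=stride)

# normalize in one shot and restore the (w_out, w_out, K, K) shape
x_normalized = (cols / sums_cols).reshape(K, K, ho, wo).permute(2, 3, 0, 1)

# sanity check: folding the normalized tensor back gives 1 at every covered input
# location (the last input row/column, index 13, is never reached by a 3x3 kernel
# with stride 2, so it stays 0)
check = F.fold(x_normalized.permute(2, 3, 0, 1).reshape(1, K * K, ho * wo),
               output_size=(w_in, w_in), kernel_size=K, stride=stride)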