Performance bottleneck when normalizing groups of elements in a tensor

I have the following bottleneck in my model:
I need to normalize groups of elements in my input tensors with softmax. The groups of elements that must be normalized together are given by lists. These lists are specific to a given layer of the model and do not change when a different input tensor is processed.

In the example below, the input tensor ln_input is of shape [16,14], and there are 2 groups of elements to normalize using softmax:

  1. Elements: [0,0], [0,1], [1,0], [1,1]
  2. Elements: [10,10], [10,11]

Each of these 2 groups is saved as a list of tuples in the dictionary coord_mapping.

coord_mapping is fixed, while there are many different ln_input tensors for which I need to repeat the operation. Therefore any expensive one-time transformation of the coord_mapping data structure is acceptable.
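Since the mapping never changes, one cheap first step is to convert each coordinate list into index tensors once, outside the per-input path; a minimal sketch (index_tensors is my own name, not part of the model):

```python
import torch
from collections import defaultdict

## groups of elements to normalize together (same groups as in the example below)
coord_mapping = defaultdict(list)
coord_mapping[0] = [(0, 0), (0, 1), (1, 0), (1, 1)]
coord_mapping[1] = [(10, 10), (10, 11)]

## one-time preprocessing: turn each list of (h, w) tuples into a pair of
## long tensors that can be reused for advanced indexing on every input
index_tensors = {
    key: (torch.tensor([h for h, _ in coords]),
          torch.tensor([w for _, w in coords]))
    for key, coords in coord_mapping.items()
}
```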

Here is a minimal working example:

import torch
import torch.nn as nn
from collections import defaultdict

## normalization function
softmax0 = nn.Softmax(dim=0)

## input tensor
h_in = 16               # input height
w_in = 14               # input width

ln_input = torch.zeros([h_in,w_in])
output = torch.zeros_like(ln_input)

## groups of elements to normalize together
coord_mapping = defaultdict(list)
coord_mapping[0] = [(0, 0), (0, 1), (1, 0), (1, 1)]
coord_mapping[1] = [(10, 10), (10, 11)]


## inefficient normalization: one softmax call per group
for coords in coord_mapping.values():
    h_list, w_list = zip(*coords)
    output[h_list, w_list] = softmax0(ln_input[h_list, w_list])

## after the loop, the elements of each group sum to 1

How could I make this more efficient?
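For comparison, here is a fully vectorized variant that handles all groups in a single pass, under the assumption that the fixed coord_mapping is preprocessed once into flat index tensors plus a group-id tensor. This is a sketch, not a drop-in replacement: grouped_softmax, h_idx, w_idx and group_id are my own names, and scatter_reduce requires PyTorch >= 1.12.

```python
import torch

## fixed mapping, preprocessed once
coord_mapping = {
    0: [(0, 0), (0, 1), (1, 0), (1, 1)],
    1: [(10, 10), (10, 11)],
}
h_idx = torch.tensor([h for coords in coord_mapping.values() for h, _ in coords])
w_idx = torch.tensor([w for coords in coord_mapping.values() for _, w in coords])
group_id = torch.tensor([g for g, coords in enumerate(coord_mapping.values())
                         for _ in coords])
num_groups = len(coord_mapping)

def grouped_softmax(ln_input):
    """Softmax over each group of elements, all groups in one pass."""
    vals = ln_input[h_idx, w_idx]          # gather all grouped elements at once
    # per-group maximum, for numerical stability
    gmax = torch.full((num_groups,), float('-inf')).scatter_reduce(
        0, group_id, vals, reduce='amax')
    exp = torch.exp(vals - gmax[group_id])
    # per-group normalization constant
    denom = torch.zeros(num_groups).scatter_add(0, group_id, exp)
    out = torch.zeros_like(ln_input)
    out[h_idx, w_idx] = exp / denom[group_id]   # scatter normalized values back
    return out
```

Each call then performs a constant number of tensor operations regardless of the number of groups, instead of one softmax call (and one Python loop iteration) per group.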