Implementation of the Dense Synthesizer

I’m trying to understand the Synthesizer paper (https://arxiv.org/pdf/2005.00743.pdf), which describes a Dense Synthesizer mechanism that is meant to replace the dot-product attention used in the traditional Transformer architecture.

The Dense Synthesizer is described as such (Equations (1)–(3) in the paper, roughly):

    B_i = F(X_i)                                # (1)
    F(X_i) = W_2(σ_R(W_1(X_i) + b_1)) + b_2     # (2), with σ_R = ReLU
    Y = Softmax(B) G(X)                         # (3)

where X is the [l x d] input, B is the synthesized [l x l] attention matrix, and G(X) is a parameterized function of X analogous to V (the values) in standard attention.

So I tried to implement the layer and it looks like this, but I’m not sure whether I’m getting it right:

import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseSynthesizer(nn.Module):
    def __init__(self, l, d):
        super(DenseSynthesizer, self).__init__()
        # F(X) = W_2(ReLU(W_1(X) + b_1)) + b_2, projecting d -> l per token
        self.linear1 = nn.Linear(d, l)
        self.linear2 = nn.Linear(l, l)

    def forward(self, x, v):
        # Equations (1) and (2): synthesize the attention matrix from x alone
        # Shape: [l x l]
        b = self.linear2(F.relu(self.linear1(x)))
        # Equation (3): softmax over the last dimension, then weight the values
        # [l x l] x [l x d] -> [l x d]
        return torch.matmul(F.softmax(b, dim=-1), v)

Usage:

l, d = 4, 5

x, v = torch.rand(l, d), torch.rand(l, d)

synthesis = DenseSynthesizer(l, d)
synthesis(x, v) 

Example:

x and v are tensors:

x = tensor([[0.0844, 0.2683, 0.4299, 0.1827, 0.1188],
            [0.2793, 0.0389, 0.3834, 0.9897, 0.4197],
            [0.1420, 0.8051, 0.1601, 0.3299, 0.3340],
            [0.8908, 0.1066, 0.1140, 0.7145, 0.3619]])

v = tensor([[0.3806, 0.1775, 0.5457, 0.6746, 0.4505],
            [0.6309, 0.2790, 0.7215, 0.4283, 0.5853],
            [0.7548, 0.6887, 0.0426, 0.1057, 0.7895],
            [0.1881, 0.5334, 0.6834, 0.4845, 0.1960]])

Passing them through a forward pass of the dense synthesizer returns:

>>> synthesis = DenseSynthesizer(l, d)
>>> synthesis(x, v) 

tensor([[0.5371, 0.4528, 0.4560, 0.3735, 0.5492],
        [0.5426, 0.4434, 0.4625, 0.3770, 0.5536],
        [0.5362, 0.4477, 0.4658, 0.3769, 0.5468],
        [0.5430, 0.4461, 0.4559, 0.3755, 0.5551]], grad_fn=<MmBackward>)
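
As a quick sanity check of my own (not from the paper), the output keeps the shape of v and each row of the synthesized attention matrix sums to 1, which at least matches the shapes implied by Equation (3):

out = synthesis(x, v)
assert out.shape == (l, d)  # output is [l x d], like v

# recompute the synthesized weights to inspect them
b = synthesis.linear2(F.relu(synthesis.linear1(x)))
weights = F.softmax(b, dim=-1)
assert weights.shape == (l, l)  # [l x l] "attention" matrix
print(weights.sum(dim=-1))      # each row sums to ~1.0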

Is the implementation and understanding of the dense synthesizer correct?

Theoretically, how is this different from a multi-layer perceptron that takes in two different inputs and makes use of them at different points in the forward propagation?
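
For concreteness, the kind of two-input MLP I have in mind is something like this (my own sketch, not from the paper), which looks almost identical apart from the softmax:

# (uses the same torch / nn / F imports as above)
class TwoInputMLP(nn.Module):
    def __init__(self, l, d):
        super().__init__()
        self.linear1 = nn.Linear(d, l)
        self.linear2 = nn.Linear(l, l)

    def forward(self, x, v):
        # same two linear layers over x as in DenseSynthesizer above ...
        h = self.linear2(F.relu(self.linear1(x)))  # [l x l]
        # ... but v is mixed in without softmax-normalising the weights
        return torch.matmul(h, v)                  # [l x d]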

Comparatively, the original Transformer attention mechanism looks like this (from https://nlp.seas.harvard.edu/2018/04/03/attention.html):

import math


def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'"""
    d_k = query.size(-1)  # hidden dimension of the query tensor
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

Usage:

>>> l, d = 4, 5
>>> q, k, v = torch.rand(l, d), torch.rand(l, d), torch.rand(l, d)
>>> attention(q, k, v)
(tensor([[0.5511, 0.6426, 0.6486, 0.4216, 0.3083],
         [0.5560, 0.6441, 0.6665, 0.4169, 0.3312],
         [0.5590, 0.6426, 0.6647, 0.4173, 0.3283],
         [0.5678, 0.6450, 0.6291, 0.4012, 0.2998]]),
 tensor([[0.2579, 0.2811, 0.2288, 0.2322],
         [0.2724, 0.2851, 0.2585, 0.1840],
         [0.2993, 0.2884, 0.2382, 0.1740],
         [0.2769, 0.2465, 0.2166, 0.2600]]))
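
To make the comparison concrete, this is how I understand the two [l x l] weight matrices being produced before the softmax (my reading of the two snippets above, happy to be corrected):

# Dense Synthesizer: weights are synthesized from x alone via two linear layers
b_dense = synthesis.linear2(F.relu(synthesis.linear1(x)))            # [l x l]

# Dot-product attention: weights come from the query-key interaction
scores_dot = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)     # [l x l]

Both are then softmaxed and multiplied with the values, so the question boils down to whether replacing the query-key dot product with an MLP over x is really all there is to the Dense Synthesizer.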