Doubt about einsum formulation

Hi,

I just want to know: is there any difference in the output of the two einsum formulations below?

torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
torch.einsum("bhrd,lrd->bhlr", query_layer, positional_embedding)

Any help is much appreciated!

Jay

Yes, there is, as the third axis of the first input tensor is aligned with different axes in the second input and in the output.
If I want to find out about these things, I usually choose random inputs with all-different dimensions, e.g.

query_layer = torch.randn(2, 3, 4, 5) # b h l d
positional_embedding = torch.randn(4, 6, 5)  # l r d

With this, you see that only the first einsum will work.
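Concretely, running both formulations on these tensors:

torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding).shape
# torch.Size([2, 3, 4, 6]) -- l of the query is matched with l of the embedding

torch.einsum("bhrd,lrd->bhlr", query_layer, positional_embedding)
# raises a RuntimeError, because r would be 4 in the query but 6 in the embedding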
Now, in terms of what it should be, it's not entirely clear to me.
The typical dimensions at work in attention layers are batch, head, query position, key position, and feature. If we take b, h, and d to be the batch, head, and feature dimensions, that leaves l and r for the query and key positions. But then it is very atypical to have lrd as the second factor: assuming l = query and r = key, hrd, or more commonly bhrd, would be the usual shape for the second factor.
In my small toroidal library, I implemented a standard (BERT, GPT, ViT, …) attention module with einsum. I number the key and query positions s and t and the features c (for channel), and then use bthc,bshc->bhts after ordering the dimensions.
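A minimal sketch of what that score computation looks like, with made-up shapes (illustrative only, not the actual library code):

import torch

def attention_scores(q, k):
    # q: (batch, query positions, heads, channels) -> b t h c
    # k: (batch, key positions, heads, channels)   -> b s h c
    return torch.einsum("bthc,bshc->bhts", q, k) / (q.shape[-1] ** 0.5)

q = torch.randn(2, 7, 3, 8)  # b t h c
k = torch.randn(2, 9, 3, 8)  # b s h c
print(attention_scores(q, k).shape)  # torch.Size([2, 3, 7, 9])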

Best regards

Thomas

Thank you!

Also, one more einsum, for a Fourier transform, is not working as expected.

import torch
import torch.nn as nn
from scipy import linalg

class FourierMMLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.dft_mat_seq = torch.tensor(linalg.dft(512))
        self.dft_mat_hidden = torch.tensor(linalg.dft(768))

    def forward(self, hidden_states):
        # hidden_states has shape torch.Size([2, 9, 768]) before the Fourier transform
        hidden_states_complex = hidden_states.type(torch.complex128)
        return torch.einsum(
            "...ij,...jk,...ni->...nk",
            hidden_states_complex,
            self.dft_mat_hidden,
            self.dft_mat_seq
        ).real.type(torch.float32)
Traceback (most recent call last):
  File "inference.py", line 22, in <module>
    obj1.forward(input_ids, token_type_ids)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 124, in forward
    self.encoder(input_ids, type_ids)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 113, in forward
    sequence_output = self.encoder(embedding_output)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 94, in forward
    hidden_states = layer_module(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 80, in forward
    fft_output = self.fft(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 62, in forward
    return torch.einsum(
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 9, 768]->[2, 1, 1, 9, 768] [768, 768]->[1, 1, 768, 1, 768] [512, 512]->[1, 512, 1, 512, 1]

Can you please help?

Your dimension i is 9 in the first input and 512 in the last; that doesn't seem right.

The shape is correct, @tom.
Also, that is the input embedding shape [batch_size, seq_len, embeddings]:
512 is the number of positions, with 768 embedding features each.
That seems OK to me.

I have implemented it from this source

I am hesitant to insist, but the error message literally means you have a shape mismatch.
The best way to look into this is to print out the dimensions of your inputs and write the corresponding letter from the einsum formula next to each one. Then it should become clearer what is going on.
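For example, annotating the shapes from the error message with the letters from your formula (using the names from your module):

hidden_states_complex  # (2, 9, 768)  -> "...ij": i = 9,   j = 768
self.dft_mat_hidden    # (768, 768)   -> "...jk": j = 768, k = 768
self.dft_mat_seq       # (512, 512)   -> "...ni": n = 512, i = 512

# i is 9 in the first operand but 512 in the third, so einsum cannot match them:
# the sequence DFT matrix would have to match the actual sequence length
# (or the input would have to be padded to 512 positions) for this to work.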

Best regards

Thomas