RuntimeError: Function MmBackward returned an invalid gradient at index 0 - got [2048, 256] but expected shape compatible with [2048, 265]

I’m training a Perceiver transformer network, and I’m trying to replace the explicitly added positional encoding with a positional encoding that is only added to the query and key vectors in the attention mechanism.
When I use this modified attention I get the error in the title.
I tried stepping through the backward pass to find out where the error occurs, but without success, as the backward pass is abstracted away in the C++ autograd engine.
The relevant code is:

import torch
from torch import nn, einsum
from math import pi, log
from einops import rearrange, repeat

# `exists` and `default` (not shown) are the usual small helpers:
# exists(x) checks that x is not None, default(x, d) falls back to d when x is None.


class FEPA_Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., max_freq=512, freq_base=2, num_freq_bands=4):
        super().__init__()
        inner_dim = dim_head * heads
        self.cross = context_dim is not None
        context_dim = default(context_dim, query_dim)
        self.encoding_dim = ((num_freq_bands * 2) + 1)
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.freq_base = freq_base

        self.query_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    @staticmethod
    def get_enc_pos(data, max_freq, num_freq_bands, freq_base):
        b, *axis, _, device = *data.shape, data.device
        axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device), axis))

        pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
        enc_pos = fourier_encode(pos, max_freq, num_freq_bands, freq_base)
        enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
        enc_pos = repeat(enc_pos, '... -> b ...', b=b)
        return enc_pos

    @staticmethod
    def concat_position_encoding(data, max_freq, num_freq_bands, freq_base):
        enc_pos = FEPA_Attention.get_enc_pos(data, max_freq, num_freq_bands, freq_base).detach().requires_grad_(False)
        data = torch.cat((data, enc_pos), dim=-1)
        data = rearrange(data, 'b ... d -> b (...) d')
        return data

    def forward(self, x, context=None, mask=None, get_attention_map=False):
        h = self.heads
        # Concatenate the Fourier position encoding to the query input before projecting.
        q = FEPA_Attention.concat_position_encoding(x, self.max_freq, self.num_freq_bands, self.freq_base)
        q = self.to_q(q)

        context = default(context, x)
        # The keys get the same concatenated position encoding; the values (below) do not.
        k = FEPA_Attention.concat_position_encoding(context, self.max_freq, self.num_freq_bands, self.freq_base)
        k = self.to_k(k)

        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        if get_attention_map:
            return self.to_out(out), attn
        return self.to_out(out)


def fourier_encode(x, max_freq, num_bands=4, base=2):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x

    # num_bands frequencies, log-spaced from `base` up to max_freq / 2.
    scales = torch.logspace(1., log(max_freq / 2) / log(base), num_bands, base=base, device=device, dtype=dtype)
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]

    x = x * scales * pi
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat((x, orig_x), dim=-1)
    return x

How can I tackle this problem and find a solution?

It also seems you are using a third-party library for rearrange? I don’t know how common the reported shapes are in your model, but you could try checking the shapes of all intermediate tensors to narrow down the operation(s) that might be failing.
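For example, here is a rough, untested sketch that prints every module’s output shape during the forward pass (`model` is assumed to be your Perceiver instance):

import torch

def print_shape_hook(name):
    # Forward hook that prints the output shape of the named module.
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            print(f"{name}: {tuple(output.shape)}")
    return hook

# Register the hook on every submodule, then run a single batch.
for name, module in model.named_modules():
    module.register_forward_hook(print_shape_hook(name))

Comparing the printed shapes against the 256 vs. 265 in the error message should help narrow down which operation produces the mismatch.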

I’m using repeat and rearrange from the einops library.
The output shapes of my model are:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
         LayerNorm-1             [16, 128, 256]             512
         LayerNorm-2             [16, 16384, 1]               2
            Linear-3              [16, 128, 16]           4,096
            Linear-4            [16, 16384, 16]              16
            Linear-5            [16, 16384, 16]              16
            Linear-6             [16, 128, 256]           4,352
           Dropout-7             [16, 128, 256]               0
    FEPA_Attention-8             [16, 128, 256]               0
           PreNorm-9             [16, 128, 256]               0
        LayerNorm-10             [16, 128, 256]             512
           Linear-11            [16, 128, 2048]         526,336
            GEGLU-12            [16, 128, 1024]               0
          Dropout-13            [16, 128, 1024]               0
           Linear-14             [16, 128, 256]         262,400
      FeedForward-15             [16, 128, 256]               0
          PreNorm-16             [16, 128, 256]               0
        LayerNorm-17             [16, 128, 256]             512
           Linear-18             [16, 128, 256]          65,536
           Linear-19             [16, 128, 512]         131,072
           Linear-20             [16, 128, 256]          65,792
          Dropout-21             [16, 128, 256]               0
   RBPA_Attention-22             [16, 128, 256]               0
          PreNorm-23             [16, 128, 256]               0
        LayerNorm-24             [16, 128, 256]             512
           Linear-25            [16, 128, 2048]         526,336
            GEGLU-26            [16, 128, 1024]               0
          Dropout-27            [16, 128, 1024]               0
           Linear-28             [16, 128, 256]         262,400
      FeedForward-29             [16, 128, 256]               0
          PreNorm-30             [16, 128, 256]               0
        LayerNorm-31             [16, 128, 256]             512
        LayerNorm-32             [16, 16384, 1]               2
           Linear-33              [16, 128, 16]           4,096
           Linear-34            [16, 16384, 16]              16
           Linear-35            [16, 16384, 16]              16
           Linear-36             [16, 128, 256]           4,352
          Dropout-37             [16, 128, 256]               0
   FEPA_Attention-38             [16, 128, 256]               0
          PreNorm-39             [16, 128, 256]               0
        LayerNorm-40             [16, 128, 256]             512
           Linear-41            [16, 128, 2048]         526,336
            GEGLU-42            [16, 128, 1024]               0
          Dropout-43            [16, 128, 1024]               0
           Linear-44             [16, 128, 256]         262,400
      FeedForward-45             [16, 128, 256]               0
          PreNorm-46             [16, 128, 256]               0
        LayerNorm-47             [16, 128, 256]             512
           Linear-48             [16, 128, 256]          65,536
           Linear-49             [16, 128, 512]         131,072
           Linear-50             [16, 128, 256]          65,792
          Dropout-51             [16, 128, 256]               0
   RBPA_Attention-52             [16, 128, 256]               0
          PreNorm-53             [16, 128, 256]               0
        LayerNorm-54             [16, 128, 256]             512
           Linear-55            [16, 128, 2048]         526,336
            GEGLU-56            [16, 128, 1024]               0
          Dropout-57            [16, 128, 1024]               0
           Linear-58             [16, 128, 256]         262,400
      FeedForward-59             [16, 128, 256]               0
          PreNorm-60             [16, 128, 256]               0
        LayerNorm-61             [16, 128, 256]             512
        LayerNorm-62             [16, 16384, 1]               2
           Linear-63              [16, 128, 16]           4,096
           Linear-64            [16, 16384, 16]              16
           Linear-65            [16, 16384, 16]              16
           Linear-66             [16, 128, 256]           4,352
          Dropout-67             [16, 128, 256]               0
   FEPA_Attention-68             [16, 128, 256]               0
          PreNorm-69             [16, 128, 256]               0
        LayerNorm-70             [16, 128, 256]             512
           Linear-71            [16, 128, 2048]         526,336
            GEGLU-72            [16, 128, 1024]               0
          Dropout-73            [16, 128, 1024]               0
           Linear-74             [16, 128, 256]         262,400
      FeedForward-75             [16, 128, 256]               0
          PreNorm-76             [16, 128, 256]               0
        LayerNorm-77             [16, 128, 256]             512
           Linear-78             [16, 128, 256]          65,536
           Linear-79             [16, 128, 512]         131,072
           Linear-80             [16, 128, 256]          65,792
          Dropout-81             [16, 128, 256]               0
   RBPA_Attention-82             [16, 128, 256]               0
          PreNorm-83             [16, 128, 256]               0
        LayerNorm-84             [16, 128, 256]             512
           Linear-85            [16, 128, 2048]         526,336
            GEGLU-86            [16, 128, 1024]               0
          Dropout-87            [16, 128, 1024]               0
           Linear-88             [16, 128, 256]         262,400
      FeedForward-89             [16, 128, 256]               0
          PreNorm-90             [16, 128, 256]               0
        LayerNorm-91             [16, 128, 256]             512
        LayerNorm-92             [16, 16384, 1]               2
           Linear-93              [16, 128, 16]           4,096
           Linear-94            [16, 16384, 16]              16
           Linear-95            [16, 16384, 16]              16
           Linear-96             [16, 128, 256]           4,352
          Dropout-97             [16, 128, 256]               0
   FEPA_Attention-98             [16, 128, 256]               0
          PreNorm-99             [16, 128, 256]               0
       LayerNorm-100             [16, 128, 256]             512
          Linear-101            [16, 128, 2048]         526,336
           GEGLU-102            [16, 128, 1024]               0
         Dropout-103            [16, 128, 1024]               0
          Linear-104             [16, 128, 256]         262,400
     FeedForward-105             [16, 128, 256]               0
         PreNorm-106             [16, 128, 256]               0
       LayerNorm-107             [16, 128, 256]             512
          Linear-108             [16, 128, 256]          65,536
          Linear-109             [16, 128, 512]         131,072
          Linear-110             [16, 128, 256]          65,792
         Dropout-111             [16, 128, 256]               0
  RBPA_Attention-112             [16, 128, 256]               0
         PreNorm-113             [16, 128, 256]               0
       LayerNorm-114             [16, 128, 256]             512
          Linear-115            [16, 128, 2048]         526,336
           GEGLU-116            [16, 128, 1024]               0
         Dropout-117            [16, 128, 1024]               0
          Linear-118             [16, 128, 256]         262,400
     FeedForward-119             [16, 128, 256]               0
         PreNorm-120             [16, 128, 256]               0
       LayerNorm-121             [16, 128, 256]             512
          Linear-122             [16, 128, 256]          65,536
          Linear-123             [16, 128, 512]         131,072
          Linear-124                  [16, 256]          65,792
         Dropout-125                  [16, 256]               0
CollapsingAttention-126                  [16, 256]               0
         PreNorm-127                  [16, 256]               0
       LayerNorm-128                  [16, 256]             512
          Linear-129                 [16, 2048]         526,336
           GEGLU-130                 [16, 1024]               0
         Dropout-131                 [16, 1024]               0
          Linear-132                  [16, 256]         262,400
     FeedForward-133                  [16, 256]               0
         PreNorm-134                  [16, 256]               0
================================================================
Total params: 8,453,768
Trainable params: 8,453,768
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 1129.75
Params size (MB): 32.25
Estimated Total Size (MB): 1163.00
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
         LayerNorm-1                  [16, 256]             512
            Linear-2                 [16, 2048]         526,336
             GEGLU-3                 [16, 1024]               0
           Dropout-4                 [16, 1024]               0
            Linear-5                  [16, 256]         262,400
       FeedForward-6                  [16, 256]               0
           PreNorm-7                  [16, 256]               0
         LayerNorm-8                  [16, 256]             512
            Linear-9             [16, 128, 256]          65,536
           Linear-10             [16, 128, 512]         131,072
           Linear-11             [16, 128, 256]          65,792
          Dropout-12             [16, 128, 256]               0
ExplodingAttention-13             [16, 128, 256]               0
          PreNorm-14             [16, 128, 256]               0
        LayerNorm-15             [16, 128, 256]             512
           Linear-16             [16, 128, 256]          65,536
           Linear-17             [16, 128, 512]         131,072
           Linear-18             [16, 128, 256]          65,792
          Dropout-19             [16, 128, 256]               0
   RBPA_Attention-20             [16, 128, 256]               0
          PreNorm-21             [16, 128, 256]               0
        LayerNorm-22             [16, 128, 256]             512
           Linear-23            [16, 128, 2048]         526,336
            GEGLU-24            [16, 128, 1024]               0
          Dropout-25            [16, 128, 1024]               0
           Linear-26             [16, 128, 256]         262,400
      FeedForward-27             [16, 128, 256]               0
          PreNorm-28             [16, 128, 256]               0
        LayerNorm-29             [16, 16384, 1]               2
        LayerNorm-30             [16, 128, 256]             512
           Linear-31            [16, 16384, 16]              16
           Linear-32              [16, 128, 16]           4,096
           Linear-33              [16, 128, 16]           4,096
           Linear-34             [16, 16384, 1]              17
          Dropout-35             [16, 16384, 1]               0
   FEPA_Attention-36             [16, 16384, 1]               0
          PreNorm-37             [16, 16384, 1]               0
        LayerNorm-38             [16, 16384, 1]               2
           Linear-39             [16, 16384, 8]              16
            GEGLU-40             [16, 16384, 4]               0
          Dropout-41             [16, 16384, 4]               0
           Linear-42             [16, 16384, 1]               5
      FeedForward-43             [16, 16384, 1]               0
          PreNorm-44             [16, 16384, 1]               0
        LayerNorm-45             [16, 128, 256]             512
           Linear-46             [16, 128, 256]          65,536
           Linear-47             [16, 128, 512]         131,072
           Linear-48             [16, 128, 256]          65,792
          Dropout-49             [16, 128, 256]               0
   RBPA_Attention-50             [16, 128, 256]               0
          PreNorm-51             [16, 128, 256]               0
        LayerNorm-52             [16, 128, 256]             512
           Linear-53            [16, 128, 2048]         526,336
            GEGLU-54            [16, 128, 1024]               0
          Dropout-55            [16, 128, 1024]               0
           Linear-56             [16, 128, 256]         262,400
      FeedForward-57             [16, 128, 256]               0
          PreNorm-58             [16, 128, 256]               0
        LayerNorm-59             [16, 16384, 1]               2
        LayerNorm-60             [16, 128, 256]             512
           Linear-61            [16, 16384, 16]              16
           Linear-62              [16, 128, 16]           4,096
           Linear-63              [16, 128, 16]           4,096
           Linear-64             [16, 16384, 1]              17
          Dropout-65             [16, 16384, 1]               0
   FEPA_Attention-66             [16, 16384, 1]               0
          PreNorm-67             [16, 16384, 1]               0
        LayerNorm-68             [16, 16384, 1]               2
           Linear-69             [16, 16384, 8]              16
            GEGLU-70             [16, 16384, 4]               0
          Dropout-71             [16, 16384, 4]               0
           Linear-72             [16, 16384, 1]               5
      FeedForward-73             [16, 16384, 1]               0
          PreNorm-74             [16, 16384, 1]               0
        LayerNorm-75             [16, 128, 256]             512
           Linear-76             [16, 128, 256]          65,536
           Linear-77             [16, 128, 512]         131,072
           Linear-78             [16, 128, 256]          65,792
          Dropout-79             [16, 128, 256]               0
   RBPA_Attention-80             [16, 128, 256]               0
          PreNorm-81             [16, 128, 256]               0
        LayerNorm-82             [16, 128, 256]             512
           Linear-83            [16, 128, 2048]         526,336
            GEGLU-84            [16, 128, 1024]               0
          Dropout-85            [16, 128, 1024]               0
           Linear-86             [16, 128, 256]         262,400
      FeedForward-87             [16, 128, 256]               0
          PreNorm-88             [16, 128, 256]               0
        LayerNorm-89             [16, 16384, 1]               2
        LayerNorm-90             [16, 128, 256]             512
           Linear-91            [16, 16384, 16]              16
           Linear-92              [16, 128, 16]           4,096
           Linear-93              [16, 128, 16]           4,096
           Linear-94             [16, 16384, 1]              17
          Dropout-95             [16, 16384, 1]               0
   FEPA_Attention-96             [16, 16384, 1]               0
          PreNorm-97             [16, 16384, 1]               0
        LayerNorm-98             [16, 16384, 1]               2
           Linear-99             [16, 16384, 8]              16
           GEGLU-100             [16, 16384, 4]               0
         Dropout-101             [16, 16384, 4]               0
          Linear-102             [16, 16384, 1]               5
     FeedForward-103             [16, 16384, 1]               0
         PreNorm-104             [16, 16384, 1]               0
       LayerNorm-105             [16, 128, 256]             512
          Linear-106             [16, 128, 256]          65,536
          Linear-107             [16, 128, 512]         131,072
          Linear-108             [16, 128, 256]          65,792
         Dropout-109             [16, 128, 256]               0
  RBPA_Attention-110             [16, 128, 256]               0
         PreNorm-111             [16, 128, 256]               0
       LayerNorm-112             [16, 128, 256]             512
          Linear-113            [16, 128, 2048]         526,336
           GEGLU-114            [16, 128, 1024]               0
         Dropout-115            [16, 128, 1024]               0
          Linear-116             [16, 128, 256]         262,400
     FeedForward-117             [16, 128, 256]               0
         PreNorm-118             [16, 128, 256]               0
       LayerNorm-119             [16, 16384, 1]               2
       LayerNorm-120             [16, 128, 256]             512
          Linear-121            [16, 16384, 16]              16
          Linear-122              [16, 128, 16]           4,096
          Linear-123              [16, 128, 16]           4,096
          Linear-124             [16, 16384, 1]              17
         Dropout-125             [16, 16384, 1]               0
  FEPA_Attention-126             [16, 16384, 1]               0
         PreNorm-127             [16, 16384, 1]               0
       LayerNorm-128             [16, 16384, 1]               2
          Linear-129             [16, 16384, 8]              16
           GEGLU-130             [16, 16384, 4]               0
         Dropout-131             [16, 16384, 4]               0
          Linear-132             [16, 16384, 1]               5
     FeedForward-133             [16, 16384, 1]               0
         PreNorm-134             [16, 16384, 1]               0
      DeceiverLT-135          [16, 128, 128, 1]               0
================================================================
Total params: 5,295,848
Trainable params: 5,295,848
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 824.66
Params size (MB): 20.20
Estimated Total Size (MB): 844.87
----------------------------------------------------------------

I think the error happens in the backward pass between the FeedForward layer and the to_out layer of the FEPA_Attention module.
Printing out the module and grads via a backward hook seems to confirm this.
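The hook is registered roughly like this (simplified sketch; `model` stands for the full network here):

def grad_shape_hook(module, grad_input, grad_output):
    # Print the shapes of the gradients flowing into and out of the module.
    grad_in = [g.shape for g in grad_input if g is not None]
    grad_out = [g.shape for g in grad_output if g is not None]
    print(f"Module: [{module}], grad_in: [{grad_in}], grad_out: [{grad_out}]")

for module in model.modules():
    if isinstance(module, FEPA_Attention):
        module.register_backward_hook(grad_shape_hook)

Running one batch then prints the following, followed by the traceback: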

Module: [FEPA_Attention(
  (to_q): Linear(in_features=1, out_features=16, bias=False)
  (to_k): Linear(in_features=256, out_features=16, bias=False)
  (to_v): Linear(in_features=256, out_features=16, bias=False)
  (to_out): Sequential(
    (0): Linear(in_features=16, out_features=1, bias=True)
    (1): Dropout(p=0.0, inplace=False)
  )
), grad_in: [[torch.Size([16, 16384, 1]), torch.Size([1])]], grad_out: [[torch.Size([16, 16384, 1])]]
Epoch 0:   0%|          | 0/1123 [00:00<?, ?it/s]
2021-09-24 11:54:11,306 ERROR Something went wrong in the experiment 
Traceback (most recent call last):
  File "C:/Users/Maxim/Documents/Uni/Bachelorarbeit/impress/experiment/main.py", line 88, in main
    experiment.run()
  File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\experiment\experiment_perceiver.py", line 265, in run
    self.training(setup=setup, epoch=epoch, menu=self.keyboard_menu)
  File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\experiment\experiment_perceiver.py", line 353, in __call__
    loss = self.train(models, losses, optimizers, imgs)
  File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\experiment\experiment_perceiver.py", line 328, in train
    loss.backward()
  File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\venv\lib\site-packages\torch\_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\venv\lib\site-packages\torch\autograd\__init__.py", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function MmBackward returned an invalid gradient at index 0 - got [2048, 256] but expected shape compatible with [2048, 265]

Process finished with exit code 0

But I can’t see where the error happens :confused: