I’m using repeat and rearrange from the einops library.
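For context, the kind of reshaping involved looks roughly like this (a simplified sketch using the shapes from the summaries below, not my exact code):

import torch
from einops import rearrange, repeat

# Illustrative shapes only: flatten the input grid into a sequence of "pixels"
# and broadcast a latent array across the batch, Perceiver-style.
x = torch.randn(16, 128, 128, 1)                         # (batch, height, width, channels)
tokens = rearrange(x, 'b h w c -> b (h w) c')            # -> (16, 16384, 1)

latents = torch.randn(128, 256)                          # (num_latents, latent_dim)
latents = repeat(latents, 'n d -> b n d', b=x.shape[0])  # -> (16, 128, 256)

(The 16384-long sequence and the 128 x 256 latent array in the summaries come from exactly this kind of flattening and broadcasting.)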
The output shapes of my model are:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
LayerNorm-1 [16, 128, 256] 512
LayerNorm-2 [16, 16384, 1] 2
Linear-3 [16, 128, 16] 4,096
Linear-4 [16, 16384, 16] 16
Linear-5 [16, 16384, 16] 16
Linear-6 [16, 128, 256] 4,352
Dropout-7 [16, 128, 256] 0
FEPA_Attention-8 [16, 128, 256] 0
PreNorm-9 [16, 128, 256] 0
LayerNorm-10 [16, 128, 256] 512
Linear-11 [16, 128, 2048] 526,336
GEGLU-12 [16, 128, 1024] 0
Dropout-13 [16, 128, 1024] 0
Linear-14 [16, 128, 256] 262,400
FeedForward-15 [16, 128, 256] 0
PreNorm-16 [16, 128, 256] 0
LayerNorm-17 [16, 128, 256] 512
Linear-18 [16, 128, 256] 65,536
Linear-19 [16, 128, 512] 131,072
Linear-20 [16, 128, 256] 65,792
Dropout-21 [16, 128, 256] 0
RBPA_Attention-22 [16, 128, 256] 0
PreNorm-23 [16, 128, 256] 0
LayerNorm-24 [16, 128, 256] 512
Linear-25 [16, 128, 2048] 526,336
GEGLU-26 [16, 128, 1024] 0
Dropout-27 [16, 128, 1024] 0
Linear-28 [16, 128, 256] 262,400
FeedForward-29 [16, 128, 256] 0
PreNorm-30 [16, 128, 256] 0
LayerNorm-31 [16, 128, 256] 512
LayerNorm-32 [16, 16384, 1] 2
Linear-33 [16, 128, 16] 4,096
Linear-34 [16, 16384, 16] 16
Linear-35 [16, 16384, 16] 16
Linear-36 [16, 128, 256] 4,352
Dropout-37 [16, 128, 256] 0
FEPA_Attention-38 [16, 128, 256] 0
PreNorm-39 [16, 128, 256] 0
LayerNorm-40 [16, 128, 256] 512
Linear-41 [16, 128, 2048] 526,336
GEGLU-42 [16, 128, 1024] 0
Dropout-43 [16, 128, 1024] 0
Linear-44 [16, 128, 256] 262,400
FeedForward-45 [16, 128, 256] 0
PreNorm-46 [16, 128, 256] 0
LayerNorm-47 [16, 128, 256] 512
Linear-48 [16, 128, 256] 65,536
Linear-49 [16, 128, 512] 131,072
Linear-50 [16, 128, 256] 65,792
Dropout-51 [16, 128, 256] 0
RBPA_Attention-52 [16, 128, 256] 0
PreNorm-53 [16, 128, 256] 0
LayerNorm-54 [16, 128, 256] 512
Linear-55 [16, 128, 2048] 526,336
GEGLU-56 [16, 128, 1024] 0
Dropout-57 [16, 128, 1024] 0
Linear-58 [16, 128, 256] 262,400
FeedForward-59 [16, 128, 256] 0
PreNorm-60 [16, 128, 256] 0
LayerNorm-61 [16, 128, 256] 512
LayerNorm-62 [16, 16384, 1] 2
Linear-63 [16, 128, 16] 4,096
Linear-64 [16, 16384, 16] 16
Linear-65 [16, 16384, 16] 16
Linear-66 [16, 128, 256] 4,352
Dropout-67 [16, 128, 256] 0
FEPA_Attention-68 [16, 128, 256] 0
PreNorm-69 [16, 128, 256] 0
LayerNorm-70 [16, 128, 256] 512
Linear-71 [16, 128, 2048] 526,336
GEGLU-72 [16, 128, 1024] 0
Dropout-73 [16, 128, 1024] 0
Linear-74 [16, 128, 256] 262,400
FeedForward-75 [16, 128, 256] 0
PreNorm-76 [16, 128, 256] 0
LayerNorm-77 [16, 128, 256] 512
Linear-78 [16, 128, 256] 65,536
Linear-79 [16, 128, 512] 131,072
Linear-80 [16, 128, 256] 65,792
Dropout-81 [16, 128, 256] 0
RBPA_Attention-82 [16, 128, 256] 0
PreNorm-83 [16, 128, 256] 0
LayerNorm-84 [16, 128, 256] 512
Linear-85 [16, 128, 2048] 526,336
GEGLU-86 [16, 128, 1024] 0
Dropout-87 [16, 128, 1024] 0
Linear-88 [16, 128, 256] 262,400
FeedForward-89 [16, 128, 256] 0
PreNorm-90 [16, 128, 256] 0
LayerNorm-91 [16, 128, 256] 512
LayerNorm-92 [16, 16384, 1] 2
Linear-93 [16, 128, 16] 4,096
Linear-94 [16, 16384, 16] 16
Linear-95 [16, 16384, 16] 16
Linear-96 [16, 128, 256] 4,352
Dropout-97 [16, 128, 256] 0
FEPA_Attention-98 [16, 128, 256] 0
PreNorm-99 [16, 128, 256] 0
LayerNorm-100 [16, 128, 256] 512
Linear-101 [16, 128, 2048] 526,336
GEGLU-102 [16, 128, 1024] 0
Dropout-103 [16, 128, 1024] 0
Linear-104 [16, 128, 256] 262,400
FeedForward-105 [16, 128, 256] 0
PreNorm-106 [16, 128, 256] 0
LayerNorm-107 [16, 128, 256] 512
Linear-108 [16, 128, 256] 65,536
Linear-109 [16, 128, 512] 131,072
Linear-110 [16, 128, 256] 65,792
Dropout-111 [16, 128, 256] 0
RBPA_Attention-112 [16, 128, 256] 0
PreNorm-113 [16, 128, 256] 0
LayerNorm-114 [16, 128, 256] 512
Linear-115 [16, 128, 2048] 526,336
GEGLU-116 [16, 128, 1024] 0
Dropout-117 [16, 128, 1024] 0
Linear-118 [16, 128, 256] 262,400
FeedForward-119 [16, 128, 256] 0
PreNorm-120 [16, 128, 256] 0
LayerNorm-121 [16, 128, 256] 512
Linear-122 [16, 128, 256] 65,536
Linear-123 [16, 128, 512] 131,072
Linear-124 [16, 256] 65,792
Dropout-125 [16, 256] 0
CollapsingAttention-126 [16, 256] 0
PreNorm-127 [16, 256] 0
LayerNorm-128 [16, 256] 512
Linear-129 [16, 2048] 526,336
GEGLU-130 [16, 1024] 0
Dropout-131 [16, 1024] 0
Linear-132 [16, 256] 262,400
FeedForward-133 [16, 256] 0
PreNorm-134 [16, 256] 0
================================================================
Total params: 8,453,768
Trainable params: 8,453,768
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 1129.75
Params size (MB): 32.25
Estimated Total Size (MB): 1163.00
----------------------------------------------------------------
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
LayerNorm-1 [16, 256] 512
Linear-2 [16, 2048] 526,336
GEGLU-3 [16, 1024] 0
Dropout-4 [16, 1024] 0
Linear-5 [16, 256] 262,400
FeedForward-6 [16, 256] 0
PreNorm-7 [16, 256] 0
LayerNorm-8 [16, 256] 512
Linear-9 [16, 128, 256] 65,536
Linear-10 [16, 128, 512] 131,072
Linear-11 [16, 128, 256] 65,792
Dropout-12 [16, 128, 256] 0
ExplodingAttention-13 [16, 128, 256] 0
PreNorm-14 [16, 128, 256] 0
LayerNorm-15 [16, 128, 256] 512
Linear-16 [16, 128, 256] 65,536
Linear-17 [16, 128, 512] 131,072
Linear-18 [16, 128, 256] 65,792
Dropout-19 [16, 128, 256] 0
RBPA_Attention-20 [16, 128, 256] 0
PreNorm-21 [16, 128, 256] 0
LayerNorm-22 [16, 128, 256] 512
Linear-23 [16, 128, 2048] 526,336
GEGLU-24 [16, 128, 1024] 0
Dropout-25 [16, 128, 1024] 0
Linear-26 [16, 128, 256] 262,400
FeedForward-27 [16, 128, 256] 0
PreNorm-28 [16, 128, 256] 0
LayerNorm-29 [16, 16384, 1] 2
LayerNorm-30 [16, 128, 256] 512
Linear-31 [16, 16384, 16] 16
Linear-32 [16, 128, 16] 4,096
Linear-33 [16, 128, 16] 4,096
Linear-34 [16, 16384, 1] 17
Dropout-35 [16, 16384, 1] 0
FEPA_Attention-36 [16, 16384, 1] 0
PreNorm-37 [16, 16384, 1] 0
LayerNorm-38 [16, 16384, 1] 2
Linear-39 [16, 16384, 8] 16
GEGLU-40 [16, 16384, 4] 0
Dropout-41 [16, 16384, 4] 0
Linear-42 [16, 16384, 1] 5
FeedForward-43 [16, 16384, 1] 0
PreNorm-44 [16, 16384, 1] 0
LayerNorm-45 [16, 128, 256] 512
Linear-46 [16, 128, 256] 65,536
Linear-47 [16, 128, 512] 131,072
Linear-48 [16, 128, 256] 65,792
Dropout-49 [16, 128, 256] 0
RBPA_Attention-50 [16, 128, 256] 0
PreNorm-51 [16, 128, 256] 0
LayerNorm-52 [16, 128, 256] 512
Linear-53 [16, 128, 2048] 526,336
GEGLU-54 [16, 128, 1024] 0
Dropout-55 [16, 128, 1024] 0
Linear-56 [16, 128, 256] 262,400
FeedForward-57 [16, 128, 256] 0
PreNorm-58 [16, 128, 256] 0
LayerNorm-59 [16, 16384, 1] 2
LayerNorm-60 [16, 128, 256] 512
Linear-61 [16, 16384, 16] 16
Linear-62 [16, 128, 16] 4,096
Linear-63 [16, 128, 16] 4,096
Linear-64 [16, 16384, 1] 17
Dropout-65 [16, 16384, 1] 0
FEPA_Attention-66 [16, 16384, 1] 0
PreNorm-67 [16, 16384, 1] 0
LayerNorm-68 [16, 16384, 1] 2
Linear-69 [16, 16384, 8] 16
GEGLU-70 [16, 16384, 4] 0
Dropout-71 [16, 16384, 4] 0
Linear-72 [16, 16384, 1] 5
FeedForward-73 [16, 16384, 1] 0
PreNorm-74 [16, 16384, 1] 0
LayerNorm-75 [16, 128, 256] 512
Linear-76 [16, 128, 256] 65,536
Linear-77 [16, 128, 512] 131,072
Linear-78 [16, 128, 256] 65,792
Dropout-79 [16, 128, 256] 0
RBPA_Attention-80 [16, 128, 256] 0
PreNorm-81 [16, 128, 256] 0
LayerNorm-82 [16, 128, 256] 512
Linear-83 [16, 128, 2048] 526,336
GEGLU-84 [16, 128, 1024] 0
Dropout-85 [16, 128, 1024] 0
Linear-86 [16, 128, 256] 262,400
FeedForward-87 [16, 128, 256] 0
PreNorm-88 [16, 128, 256] 0
LayerNorm-89 [16, 16384, 1] 2
LayerNorm-90 [16, 128, 256] 512
Linear-91 [16, 16384, 16] 16
Linear-92 [16, 128, 16] 4,096
Linear-93 [16, 128, 16] 4,096
Linear-94 [16, 16384, 1] 17
Dropout-95 [16, 16384, 1] 0
FEPA_Attention-96 [16, 16384, 1] 0
PreNorm-97 [16, 16384, 1] 0
LayerNorm-98 [16, 16384, 1] 2
Linear-99 [16, 16384, 8] 16
GEGLU-100 [16, 16384, 4] 0
Dropout-101 [16, 16384, 4] 0
Linear-102 [16, 16384, 1] 5
FeedForward-103 [16, 16384, 1] 0
PreNorm-104 [16, 16384, 1] 0
LayerNorm-105 [16, 128, 256] 512
Linear-106 [16, 128, 256] 65,536
Linear-107 [16, 128, 512] 131,072
Linear-108 [16, 128, 256] 65,792
Dropout-109 [16, 128, 256] 0
RBPA_Attention-110 [16, 128, 256] 0
PreNorm-111 [16, 128, 256] 0
LayerNorm-112 [16, 128, 256] 512
Linear-113 [16, 128, 2048] 526,336
GEGLU-114 [16, 128, 1024] 0
Dropout-115 [16, 128, 1024] 0
Linear-116 [16, 128, 256] 262,400
FeedForward-117 [16, 128, 256] 0
PreNorm-118 [16, 128, 256] 0
LayerNorm-119 [16, 16384, 1] 2
LayerNorm-120 [16, 128, 256] 512
Linear-121 [16, 16384, 16] 16
Linear-122 [16, 128, 16] 4,096
Linear-123 [16, 128, 16] 4,096
Linear-124 [16, 16384, 1] 17
Dropout-125 [16, 16384, 1] 0
FEPA_Attention-126 [16, 16384, 1] 0
PreNorm-127 [16, 16384, 1] 0
LayerNorm-128 [16, 16384, 1] 2
Linear-129 [16, 16384, 8] 16
GEGLU-130 [16, 16384, 4] 0
Dropout-131 [16, 16384, 4] 0
Linear-132 [16, 16384, 1] 5
FeedForward-133 [16, 16384, 1] 0
PreNorm-134 [16, 16384, 1] 0
DeceiverLT-135 [16, 128, 128, 1] 0
================================================================
Total params: 5,295,848
Trainable params: 5,295,848
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 824.66
Params size (MB): 20.20
Estimated Total Size (MB): 844.87
----------------------------------------------------------------
I think the error happens in the backward pass between the FeedForward layer and the to_out layer of the FEPA_Attention module.
Printing the module and the gradient shapes via a backward hook (a sketch of the hook is at the end of this post) seems to confirm this:
Module: [FEPA_Attention(
(to_q): Linear(in_features=1, out_features=16, bias=False)
(to_k): Linear(in_features=256, out_features=16, bias=False)
(to_v): Linear(in_features=256, out_features=16, bias=False)
(to_out): Sequential(
(0): Linear(in_features=16, out_features=1, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
), grad_in: [[torch.Size([16, 16384, 1]), torch.Size([1])]], grad_out: [[torch.Size([16, 16384, 1])]]
Epoch 0: 0%| | 0/1123 [00:00<?, ?it/s]
2021-09-24 11:54:11,306 ERROR Something went wrong in the experiment
Traceback (most recent call last):
File "C:/Users/Maxim/Documents/Uni/Bachelorarbeit/impress/experiment/main.py", line 88, in main
experiment.run()
File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\experiment\experiment_perceiver.py", line 265, in run
self.training(setup=setup, epoch=epoch, menu=self.keyboard_menu)
File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\experiment\experiment_perceiver.py", line 353, in __call__
loss = self.train(models, losses, optimizers, imgs)
File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\experiment\experiment_perceiver.py", line 328, in train
loss.backward()
File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\venv\lib\site-packages\torch\_tensor.py", line 255, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "C:\Users\Maxim\Documents\Uni\Bachelorarbeit\impress\venv\lib\site-packages\torch\autograd\__init__.py", line 147, in backward
Variable._execution_engine.run_backward(
RuntimeError: Function MmBackward returned an invalid gradient at index 0 - got [2048, 256] but expected shape compatible with [2048, 265]
Process finished with exit code 0
But I can’t see where the error actually happens.
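For reference, the backward hook was registered roughly like this (a minimal sketch, assuming a top-level model variable; my actual code differs in the details):

def print_grad_shapes(module, grad_input, grad_output):
    # Print the module together with the shapes of its incoming and outgoing
    # gradients, skipping entries that are None.
    grad_in = [g.shape for g in grad_input if g is not None]
    grad_out = [g.shape for g in grad_output if g is not None]
    print(f"Module: [{module}], grad_in: [{grad_in}], grad_out: [{grad_out}]")

for module in model.modules():
    module.register_full_backward_hook(print_grad_shapes)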