Could anyone help me with the following error? Below are the debug output and the full traceback:
len(x) = 2
x[0].shape = torch.Size([1024, 1, 512])
x[1].shape = torch.Size([1, 1024, 1024])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-47-1eb7eeebb607> in <cell line: 51>()
55 optimizer.zero_grad()
56 inputs = torch.randint(low=0, high=1, size=(1024,1)).to(torch.long)
---> 57 logits = model(inputs)
58 labels = torch.randint(low=0, high=1, size=(1,))
59 loss = loss_fn(logits, labels)
3 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
<ipython-input-47-1eb7eeebb607> in forward(self, x)
38 x = torch.cat((x[0], x[1]), dim=-1)
39
---> 40 x = self.fc(x)
41
42 return x
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
112
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
115
116 def extra_repr(self) -> str:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x1536 and 512x4)