RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x1536 and 512x4)

Could anyone help with the following error in this code?

len(x) =  2
x[0].shape =  torch.Size([1024, 1, 512])
x[1].shape =  torch.Size([1, 1024, 1024])

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-47-1eb7eeebb607> in <cell line: 51>()
     55     optimizer.zero_grad()
     56     inputs = torch.randint(low=0, high=1, size=(1024,1)).to(torch.long)
---> 57     logits = model(inputs)
     58     labels = torch.randint(low=0, high=1, size=(1,))
     59     loss = loss_fn(logits, labels)


/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-47-1eb7eeebb607> in forward(self, x)
     38         x = torch.cat((x[0], x[1]), dim=-1)
     39 
---> 40         x = self.fc(x)
     41 
     42         return x

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x1536 and 512x4)

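The error is raised by a linear layer because its expected number of input features doesn’t match the feature size of the incoming activation. The input (mat1, 1024x1536) has 1536 features, while mat2 (512x4) is the layer's transposed weight, which means the layer was created with in_features=512 and out_features=4.
Based on the stack trace, self.fc raises the error, so check its in_features argument and set it to 1536, the size of the last dimension after torch.cat (512 + 1024).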
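Here is a minimal sketch that reproduces the mismatch and shows two ways to fix it. The input is a random stand-in tensor of shape [1024, 1, 1536], assumed to match the concatenated activation from the traceback (1024 rows of 512 + 1024 features). The names bad_fc, fc and lazy_fc are hypothetical and not taken from the original model.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the concatenated activation. The traceback
# reports mat1 as 1024x1536, i.e. 1024 rows of 512 + 1024 = 1536 features.
x = torch.randn(1024, 1, 1536)

# Reproduce the error: nn.Linear stores its weight as
# [out_features, in_features] and multiplies the input by weight.T,
# so "512x4" in the message means in_features=512, out_features=4.
bad_fc = nn.Linear(512, 4)
try:
    bad_fc(x)
    error_msg = None
except RuntimeError as e:
    error_msg = str(e)
print(error_msg)

# Fix 1: in_features must equal the size of the input's last dimension.
fc = nn.Linear(512 + 1024, 4)
out = fc(x)
print(out.shape)  # torch.Size([1024, 1, 4])

# Fix 2 (alternative): nn.LazyLinear infers in_features from the first input.
lazy_fc = nn.LazyLinear(4)
lazy_out = lazy_fc(x)
print(lazy_fc.in_features, lazy_out.shape)
```

Setting in_features explicitly is usually the clearer choice because a wrong shape then fails loudly. If you use nn.LazyLinear instead, run one forward pass before creating the optimizer so that the layer's parameters are materialized when the optimizer collects them.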