"mat1 and mat2 shapes cannot be multiplied (448x4 and 576x10)" while doing Layer-Wise Relevance Propagation

Hello,
I am getting the following error message: “mat1 and mat2 shapes cannot be multiplied (448x4 and 576x10)”.

I am trying to apply Layer-Wise Relevance Propagation (LRP) by following this tutorial: Layer Wise Relevance Propagation In Pytorch – GiorgioML – Ph.D. Student in Computer Science, MSU

Here is the NN class:

```python
import torch
import torch.nn as nn


class Cnn(nn.Module):
    def __init__(self):
        super(Cnn, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # fc1 expects the flattened layer3 output: 64 channels x 3 x 3 = 576 features
        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.fc2 = nn.Linear(10, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, 576)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return out
```
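Because `fc1` is hard-coded to `3*3*64 = 576` input features, the output of `layer3` must be a 64 x 3 x 3 feature map. The post does not state the input image size, so the sketch below assumes 3 x 224 x 224 purely for illustration and traces the tensor shape after each block, which makes it easy to check whether the flattened length matches `fc1.in_features`.

```python
import torch
import torch.nn as nn


class Cnn(nn.Module):
    # Same architecture as in the post, written compactly
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
                                    nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
                                    nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
                                    nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2))
        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.fc2 = nn.Linear(10, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.layer3(self.layer2(self.layer1(x)))
        out = out.view(out.size(0), -1)
        return self.fc2(self.relu(self.fc1(out)))


model = Cnn().eval()             # eval(): BatchNorm uses its running statistics
x = torch.randn(1, 3, 224, 224)  # ASSUMED input size; the post does not state it

out = x
with torch.no_grad():
    for name in ("layer1", "layer2", "layer3"):
        out = getattr(model, name)(out)
        print(name, tuple(out.shape))
flat = out.view(out.size(0), -1)
print("flattened:", tuple(flat.shape), "fc1 expects:", model.fc1.in_features)
```

With a 224 x 224 input this prints `(1, 16, 55, 55)`, `(1, 32, 13, 13)`, `(1, 64, 3, 3)` and a flattened shape of `(1, 576)`. Each stride-2 convolution followed by a 2 x 2 max pool shrinks the spatial size by roughly a factor of 4, so a different input size changes the flattened length, and the 576 in `fc1` would have to change with it.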

I trained this network for binary classification and am now trying to run LRP on it. Here is the part of the LRP code that raises the error:

```python
import torch


def LRP_individual(model, X, device):
    # Get the list of layers of the network
    # (model.modules() yields the model itself first, hence the [1:])
    layers = [module for module in model.modules() if not isinstance(module, torch.nn.Sequential)][1:]

    # Propagate the input
    L = len(layers)
    A = [X] + [X] * L  # Create a list to store the activation produced by each layer

    for layer in range(L):
        A[layer + 1] = layers[layer].forward(A[layer])

    # Rest of the LRP function
```

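The layer list built from `model.modules()` contains only registered submodules, in the order they were assigned in `__init__`. The flatten step `out.view(out.size(0), -1)` lives in `forward()` and is not a module, so the loop feeds the 4-D output of the last `MaxPool2d` straight into `fc1`. `nn.Linear` multiplies along the last dimension, so an `(N, 64, H, W)` activation is treated as `N*64*H` rows of `W` features; with `fc1`'s weight appearing transposed as 576x10, the reported 448x4 means the last dimension reaching `fc1` was 4 rather than 576. A second mismatch: `self.relu` is registered after `fc2`, so in the list it follows `fc2` instead of sitting between `fc1` and `fc2`. The sketch below, assuming the `Cnn` class from the post and a hypothetical 2 x 3 x 224 x 224 input, reproduces the failure and then builds one possible layer list that mirrors `forward()`, with `nn.Flatten()` standing in for the `view` call.

```python
import torch
import torch.nn as nn


class Cnn(nn.Module):
    # Same architecture as in the post, written compactly
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2),
                                    nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2),
                                    nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=0, stride=2),
                                    nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2))
        self.fc1 = nn.Linear(3 * 3 * 64, 10)
        self.fc2 = nn.Linear(10, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.layer3(self.layer2(self.layer1(x)))
        out = out.view(out.size(0), -1)
        return self.fc2(self.relu(self.fc1(out)))


model = Cnn().eval()
X = torch.randn(2, 3, 224, 224)  # ASSUMED input size; the post does not state it

# The list built in the post: registered submodules in registration order.
# There is no flatten module, and self.relu ends up after fc2.
naive = [m for m in model.modules() if not isinstance(m, nn.Sequential)][1:]
print([type(m).__name__ for m in naive[-4:]])  # ['MaxPool2d', 'Linear', 'Linear', 'ReLU']

# Propagating through it fails because fc1 receives a 4-D tensor
naive_failed = False
try:
    with torch.no_grad():
        out = X
        for layer in naive:
            out = layer(out)
except RuntimeError as err:
    naive_failed = True
    print("naive list fails:", err)

# A list that mirrors forward(): nn.Flatten() replaces out.view(out.size(0), -1)
# and the ReLU sits between fc1 and fc2.
layers = [*model.layer1, *model.layer2, *model.layer3,
          nn.Flatten(), model.fc1, model.relu, model.fc2]

A = [X] + [None] * len(layers)  # A[0] is the input, A[i + 1] the output of layers[i]
with torch.no_grad():
    for i, layer in enumerate(layers):
        A[i + 1] = layer(A[i])
    reference = model(X)

print(torch.allclose(A[-1], reference))  # True
```

Listing the layers explicitly keeps the activations in `A` aligned with what `forward()` actually computes. If the rest of the LRP function walks this list backwards, the `nn.Flatten` entry only changes the tensor's shape, so the relevance reaching it could presumably just be reshaped back to that layer's input shape (e.g. `R.view(A[i].shape)`); how the tutorial's own relevance rules handle this is not covered here.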
I would be very grateful for an explanation of this error or a solution to it.