ValueError: Expected input batch_size (4) to match target batch_size (1)

Hello, I am trying to implement a simple attention module for image classification, but after adding the module I get a batch size error. How can I fix it?

Here is how I load my dataset:

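from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

X_train, X_test = train_test_split(dataset, test_size=0.4, random_state=42)

# No batch_size is passed, so DataLoader uses its default batch_size=1
train_loader = DataLoader(X_train)
test_loader = DataLoader(X_test)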
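The loaders above are created without a batch_size argument, so each batch holds a single (image, label) pair. The minimal sketch below checks this; the list of zero images is a hypothetical stand-in for X_train, and the 3x64x64 image size is an assumption, since the question does not state it.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical stand-in for X_train: a list of (image, label) pairs,
# assuming 3x64x64 images (the real image size is not given).
X_train = [(torch.zeros(3, 64, 64), i % 10) for i in range(8)]

# With no batch_size argument, DataLoader yields batches of size 1.
default_loader = DataLoader(X_train)
first_images, first_labels = next(iter(default_loader))
print(first_images.shape, first_labels.shape)  # torch.Size([1, 3, 64, 64]) torch.Size([1])

# Passing batch_size explicitly (and shuffling the training split) is more typical.
train_loader = DataLoader(X_train, batch_size=4, shuffle=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4])
```

So with the question's loaders, the target tensor always has batch size 1, which is the "target batch_size (1)" in the error.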

My attention module:

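import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionModule(nn.Module):
    def __init__(self, input_dim):
        super(AttentionModule, self).__init__()
        # Small MLP that maps each flattened feature vector to one score
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # one weight in (0, 1) per row

        return x.view(x.size(0), -1)  # shape (N, 1)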
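Because fc3 outputs a single value, the module returns one weight per row, and multiplying by it broadcasts over the feature dimension. The sketch below copies the module from the question and runs it on a small random batch; input_dim=16 is an arbitrary choice to keep the example light.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionModule(nn.Module):
    # Same layers as in the question.
    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x.view(x.size(0), -1)

features = torch.randn(4, 16)              # (batch, input_dim)
attention = AttentionModule(input_dim=16)
weights = attention(features)              # (batch, 1): one weight per sample
gated = features * weights                 # broadcasts (4, 16) * (4, 1) -> (4, 16)
print(weights.shape, gated.shape)          # torch.Size([4, 1]) torch.Size([4, 16])
```

The multiplication only lines up because both tensors share the same first (batch) dimension, which is why the flatten step before this module has to preserve it.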

My classifier:

class ImageClassifier(nn.Module):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        # padding=1 keeps the spatial size, so the flattened size
        # 64 * 32 * 32 assumes 32x32 input images
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.attention = AttentionModule(input_dim=64 * 32 * 32)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # -1 lets PyTorch infer the first dimension from the total element count
        x = x.view(-1, 64 * 32 * 32)
        attention_weights = self.attention(x)  # shape (N, 1)
        x = x * attention_weights  # scale each row by its weight
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # class scores, shape (N, 10)
        return x

Error:

ValueError: Expected input batch_size (4) to match target batch_size (1)
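To see where the 4 can come from, the minimal sketch below traces shapes through the two conv layers and the view call. It assumes 3x64x64 images, which is a guess: the question does not give the image size, but 64x64 is one size that produces exactly this error with a batch of one image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumption: images are 3x64x64 (not stated in the question).
conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

x = torch.randn(1, 3, 64, 64)          # one image, as with DataLoader's default batch_size=1
x = F.relu(conv2(F.relu(conv1(x))))    # padding=1 keeps 64x64
print(x.shape)                         # torch.Size([1, 64, 64, 64])

rows = x.view(-1, 64 * 32 * 32)        # 64*64*64 values / 65536 per row = 4 rows
print(rows.shape)                      # torch.Size([4, 65536])

# The loss then compares 4 predictions with 1 target label.
logits = torch.randn(rows.size(0), 10)
target = torch.tensor([3])
message = None
try:
    F.cross_entropy(logits, target)
except ValueError as err:
    message = str(err)
print(message)
```

The view call does not fail here because the element count divides evenly; it silently turns one image into four rows.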

Replace this line in ImageClassifier.forward:

x = x.view(-1, 64 * 32 * 32)

with:

x = x.view(x.size(0), -1)

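to keep the first dimension equal to the number of images in the batch. With view(-1, 64 * 32 * 32), PyTorch infers the first dimension from the total number of elements, so if your images are not 32x32 the batch dimension silently changes. For example, a 64x64 image gives 64 * 64 * 64 values after the convolutions, which view reshapes into 4 rows of 64 * 32 * 32, while the DataLoader (default batch_size=1) provides only 1 target label. After the change, each row has 64 * H * W features for H x W images, so also set the in_features of fc1 and the input_dim of the attention module to that value; otherwise those layers will raise a shape mismatch error.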
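Putting the fix together, here is a minimal sketch of the classifier with a batch-preserving flatten and layer sizes derived from the image size. It assumes square images of a known side length; image_size and num_classes are hypothetical constructor parameters that are not in the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionModule(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x.view(x.size(0), -1)          # (batch, 1)

class ImageClassifier(nn.Module):
    # image_size (hypothetical parameter): side length of the square input images.
    # Both convs use padding=1, so the spatial size is unchanged.
    def __init__(self, image_size=32, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        flat_dim = 64 * image_size * image_size
        self.fc1 = nn.Linear(flat_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.attention = AttentionModule(input_dim=flat_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)             # keep the batch dimension
        x = x * self.attention(x)             # (batch, flat_dim) * (batch, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # (batch, num_classes)

model = ImageClassifier(image_size=16)        # small size to keep the demo light
images = torch.randn(4, 3, 16, 16)
labels = torch.randint(0, 10, (4,))
logits = model(images)
loss = F.cross_entropy(logits, labels)
print(logits.shape, loss.item())
```

For 64x64 images you would construct ImageClassifier(image_size=64). x.flatten(1) is an equivalent way to write the batch-preserving flatten. If an image of the wrong size is passed, the Linear layers now raise a RuntimeError instead of the batch dimension silently changing.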