Hello, I am trying to implement a simple attention module to classify images. But when I implemented that module I am getting batch size error. How can I solve that?
Here how I load my dataset:
X_train, X_test = train_test_split(dataset ,test_size=0.4, random_state=42)
train_loader = DataLoader(X_train)
test_loader = DataLoader(X_test)
My attention module:
class AttentionModule(nn.Module):
def __init__(self, input_dim):
super(AttentionModule, self).__init__()
self.fc1 = nn.Linear(input_dim, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 1)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = torch.sigmoid(self.fc3(x))
return x.view(x.size(0), -1)
My classifier:
class ImageClassifier(nn.Module):
def __init__(self):
super(ImageClassifier, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(64 * 32 * 32, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
self.attention = AttentionModule(input_dim=64 * 32 * 32)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = x.view(-1, 64 * 32 * 32)
attention_weights = self.attention(x)
x = x * attention_weights
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
Error:
ValueError: Expected input batch_size (4) to match target batch_size (1)