I have been stuck with the following error:
RuntimeError: Given groups=1, weight of size [3, 1, 5, 5], expected input[16, 3, 50, 50] to have 1 channels, but got 3 channels instead
Here is my code:
class MyDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
self.annotations = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.annotations) # +- 500
def __getitem__(self, index):
img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
image = io.imread(img_path)
PIL_image = Image.fromarray(image)
y_label = torch.tensor(int(self.annotations.iloc[index, 1]))
if self.transform:
image = self.transform(PIL_image)
return (image, y_label)
class ConvNet(torch.nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
f2 = 3
self.layer2 = nn.Sequential(
nn.Conv2d(1, f2, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.BatchNorm2d(f2),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc1 = nn.Linear(50 * 50 * f2, 200)
self.fc2 = nn.Linear(200, 20)
self.fc3 = nn.Linear(20, 1)
def forward(self, x):
x = self.layer2(x)
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
dataset = MyDataset(
csv_file='dataset.csv',
root_dir='tmp',
transform=transforms.Compose([
transforms.Resize(50),
transforms.ToTensor()
])
)
train_set, test_set = torch.utils.data.random_split(dataset, lengths=[500, 70])
train_loader = DataLoader(dataset=train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=16, shuffle=True)
model = ConvNet()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(20):
losses = []
for batch_idx, (data, targets) in enumerate(train_loader):
data = data.to(device=device)
targets = targets.to(device=device)
# forward
scores = model(data)
loss = criterion(scores, targets)
losses.append(loss.item())
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Cost: {0} = {1}'.format(epoch, sum(losses)/len(losses)))
I am looking for help to debug this. Thanks