Hey everyone
I’m trying to train a network to do an image processing operation. Since I was having some trouble with that, I moved to just trying to train a network to do nothing, and it still doesn’t converge to anything good.
Does anything looks off in my code? Any idea/remark will be gladly appreciated
Thanks!
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
self.conv5 = nn.Conv2d(256, 256, 3, padding=1)
self.conv6 = nn.Conv2d(256, 128, 3, padding=1)
self.conv7 = nn.Conv2d(128, 64, 3, padding=1)
self.conv8 = nn.Conv2d(64, 3, 3, padding=1)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = F.relu(self.conv5(x))
x = F.relu(self.conv6(x))
x = F.relu(self.conv7(x))
x = F.relu(self.conv8(x))
return x
class SingleImageDataset(Dataset):
def __init__(self, path):
self.image = Image.open(path)
tensor_trans = transforms.ToTensor()
tens = tensor_trans(self.image)
self.stds = [tens[i].std().item() for i in range(3)]
self.means = [tens[i].mean().item() for i in range(3)]
normalize = transforms.Normalize(mean=self.means, std=self.stds)
self.final_data = normalize(tens)
def __len__(self):
return 1
def __getitem__(self, idx):
return self.final_data
naruto_dataset = SingleImageDataset('./image/naruto.jpg')
naruto_loader = torch.utils.data.DataLoader(naruto_dataset, batch_size=1,
shuffle=False, num_workers=1)
net = Net()
net.to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
losses = []
for epoch in tqdm.tqdm(range(500)):
running_loss = 0.0
for i, data in enumerate(naruto_loader, 0):
inputs = data.cuda()
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, inputs)
losses.append(loss.item())
# print(loss.item())
loss.backward()
optimizer.step()
And the loss converges on 0.4 (which is horrible, since I’m using L1Loss which is a mean)