Is the problem the network structure or the code itself?
The code picks a random image from a random class directory, derives the target from that directory's name, and runs one training step on it.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import os
# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"device: {device}")
class Network(nn.Module):
    """Fully connected classifier for flattened 28x28 single-channel images.

    Seven hidden ``nn.Linear`` layers (``L1``-``L7``) feed a two-unit output
    layer (``L8``) that returns raw, un-activated scores.

    The hidden layers originally used ``torch.sigmoid``.  Its derivative never
    exceeds 0.25, so back-propagating through seven of them in a row shrinks
    the gradient reaching the first layers to almost nothing, and plain SGD
    barely updates them.  The default hidden activation is therefore ReLU;
    pass ``activation=torch.sigmoid`` to get the previous behaviour back.

    Args:
        activation: Callable applied element-wise after each hidden layer.
    """

    def __init__(self, activation=F.relu):
        super().__init__()
        self.activation = activation
        # Attribute names stay L1..L8 so previously saved state dicts still load.
        self.L1 = nn.Linear(28 * 28, 1000)
        self.L2 = nn.Linear(1000, 1500)
        self.L3 = nn.Linear(1500, 1000)
        self.L4 = nn.Linear(1000, 500)
        self.L5 = nn.Linear(500, 250)
        self.L6 = nn.Linear(250, 100)
        self.L7 = nn.Linear(100, 50)
        self.L8 = nn.Linear(50, 2)

    def forward(self, x):
        """Map inputs of shape ``(..., 784)`` to scores of shape ``(..., 2)``."""
        hidden_layers = (
            self.L1, self.L2, self.L3, self.L4, self.L5, self.L6, self.L7,
        )
        for layer in hidden_layers:
            x = self.activation(layer(x))
        # No activation on the output layer: the loss receives raw scores.
        return self.L8(x)
# Bookkeeping for the run.
epochs = 3000       # number of iterations the training loop performs
loss_hist = list()  # one scalar loss value is appended per iteration

# Build the model first, then move its parameters onto the selected device.
net = Network()
net = net.to(device)

# Mean-squared-error loss, optimised with plain stochastic gradient descent.
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)
# Switch the module to training mode. This only changes the behaviour of
# layers such as dropout or batch-norm.
net.train()

DATA_ROOT = "data"
to_tensor = transforms.ToTensor()

# Index the dataset once instead of re-listing directories on every step.
# Expected layout: data/<class name>/<image file>.  Stray files in the root and
# empty class directories are skipped so random.choice can never land on them.
files_by_class = {}
for class_name in os.listdir(DATA_ROOT):
    class_dir = os.path.join(DATA_ROOT, class_name)
    if not os.path.isdir(class_dir):
        continue
    file_names = os.listdir(class_dir)
    if file_names:
        files_by_class[class_name] = file_names
if not files_by_class:
    raise RuntimeError(
        f"no class directories with images found under {DATA_ROOT!r}"
    )
class_names = list(files_by_class)

# NOTE: despite the name, each "epoch" is a single SGD step on one random image.
for _ in tqdm(range(epochs)):
    # Uniformly pick a class directory, then uniformly pick a file inside it.
    class_name = random.choice(class_names)
    image_path = os.path.join(
        DATA_ROOT, class_name, random.choice(files_by_class[class_name])
    )

    # The label comes from the parent directory name: a name starting with "0"
    # maps to [0, 1], anything else to [1, 0].  The tensor is created on
    # `device` so it lives where the network output does.
    target = torch.tensor(
        [0., 1.] if class_name.startswith("0") else [1., 0.],
        device=device,
    )

    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_path) as img:
        gray = transforms.functional.to_grayscale(img, 1)
        # ToTensor yields a float tensor of shape (1, H, W) scaled to [0, 1].
        image_tensor = to_tensor(gray)

    # Debug aid: uncomment to preview the sample together with its label.
    # plt.imshow(gray)
    # plt.title(class_name)
    # plt.show()

    # Flatten to the 784-long vector the first layer expects (assumes 28x28
    # images) and move it to the same device as the model.
    sample = torch.flatten(image_tensor).to(device)

    optimizer.zero_grad()
    out = net(sample)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()

    loss_hist.append(loss.item())
# Draw the recorded per-iteration losses on the current axes and display them.
loss_axes = plt.gca()
loss_axes.plot(loss_hist)
plt.show()