Thanks prtblck!
I did as you have advised, I am getting error, which I could not figure out by googling, so let me ask in this post.
I have my numpy arrays in the shape (1, 42, 58), as mentioned in two folders with identical names for each pair.
I am getting following error now:
linear(): argument ‘input’ (position 1) must be Tensor, not list
Could you please advise what causing it?
class En_De_coder_dataset(Dataset):
def __init__(self, in_dir, out_dir):
self.in_dir = in_dir
self.out_dir = out_dir
def __len__(self):
return len(os.listdir(self.in_dir))
def __getitem__(self, index):
input = torch.from_numpy(np.load(self.in_dir + '/grid_aug_' + str(index) + '.npy')),float()
output = torch.from_numpy(np.load(self.out_dir + '/grid_aug_' + str(index) + '.npy')),float()
return (input, output)
dataset = En_De_coder_dataset(in_dir = 'augmented_inputs', out_dir = 'augmented_outputs')
train_set, test_set = torch.utils.data.random_split(dataset, [100, 44])
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=True)
class En_De_coder(nn.Module):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 128), # (N, 784) -> (N, 128)
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 12),
nn.ReLU(),
nn.Linear(12, 3) # -> N, 3
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 28 * 28),
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
model = En_De_coder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-3,
weight_decay=1e-5)
num_epochs = 10
outputs = []
for epoch in range(num_epochs):
for batch_idx, (data, targets) in enumerate(train_loader):
recon = model(data)
loss = criterion(recon, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
outputs.append((epoch, targets, recon))
my data looks like this