Sorry for my lack of understanding. My model accepts single inputs. Since my ConcatDataset returns a tuple containing the tensor and its label, I am still confused about whether I need to change the ConcatDataset or change my model so that it forms a batch. I looked for a collate_fn function, but I can't find it (in torch.utils.data).
class our_AE(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The encoder compresses a batch shaped (N, 1, H, W) into a 64-channel code;
    the decoder maps that code back to a single-channel image in [0, 1]
    (final Sigmoid). For 28x28 input the code is (N, 64, 1, 1) and the
    reconstruction is (N, 1, 28, 28).
    """

    def __init__(self):
        # Must be ``__init__`` (not ``init``) and must call nn.Module.__init__
        # first, otherwise the submodules below are never created/registered.
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # H -> H/2
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # H/2 -> H/4
            nn.ReLU(),
            nn.Conv2d(32, 64, 7),                       # 7x7 -> 1x1 for 28x28 input
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),              # 1x1 -> 7x7
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # x2
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # x2
            nn.Sigmoid(),                               # squash output into [0, 1]
        )

    def forward(self, x):
        """Encode then decode ``x``.

        Args:
            x: a *batch* of images shaped (N, 1, H, W); nn.Conv2d layers
               operate on batched input, so N may be any size.

        Returns:
            The reconstruction, with values in [0, 1].
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x
# Model on the target device (``device`` must be defined before this point),
# an Adam optimizer over its parameters, and the reconstruction loss.
AE = our_AE().to(device)
optimizer = optim.Adam(AE.parameters(), lr=1e-4)
# reduction='sum' adds the squared error over every element of the batch;
# the training/evaluation code divides by a sample count to report per-sample loss.
loss_fn = nn.MSELoss(reduction='sum')
def train(epoch, device):
    """Run one training epoch of the autoencoder ``AE`` over ``train_loader``.

    Every 10 batches, appends the per-sample loss to ``train_losses`` and the
    number of training samples consumed so far (across epochs) to
    ``train_counter``. Every 100 batches, prints a progress line.

    NOTE(review): assumes ``train_loader`` is a torch DataLoader using the
    default collate_fn, which stacks the per-sample (tensor, label) tuples
    from the dataset into batched (images, labels) tensors. If so, neither the
    dataset nor the model needs to build batches itself. Confirm.

    Args:
        epoch: 1-based epoch number, used for bookkeeping and logging.
        device: device the image batches are moved to before the forward pass.
    """
    AE.train()
    n_train = len(train_loader.dataset)
    for batch_idx, (images, _) in enumerate(train_loader):
        # Labels are ignored: the autoencoder's target is its own input.
        optimizer.zero_grad()
        images = images.to(device)
        output = AE(images)
        loss = loss_fn(output, images)  # squared error between reconstruction and input
        loss.backward()
        optimizer.step()

        # loss_fn is built with reduction='sum', so divide by the actual
        # number of samples; the final batch may be smaller than the rest.
        per_sample_loss = loss.item() / len(images)
        # Samples consumed before this batch. Assumes train_loader was built
        # with batch_size=batch_size_train. TODO confirm.
        seen = batch_idx * batch_size_train

        if batch_idx % 10 == 0:  # record every 10 batches
            train_losses.append(per_sample_loss)
            train_counter.append(seen + (epoch - 1) * n_train)
        if batch_idx % 100 == 0:  # report every 100 batches
            print(f'Epoch {epoch}: [{seen}/{n_train}] Loss: {per_sample_loss}')
def test(epoch, device):
    """Evaluate ``AE`` on ``test_loader`` and record the average loss.

    Appends the per-sample reconstruction loss to ``test_losses``. It also
    appends the cumulative number of training samples after ``epoch`` epochs
    to ``test_counter``, so test points line up with ``train_counter``.

    Args:
        epoch: 1-based epoch number just completed.
        device: device the image batches are moved to before the forward pass.
    """
    AE.eval()
    test_loss = 0.0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, _ in test_loader:
            images = images.to(device)
            output = AE(images)
            test_loss += loss_fn(output, images).item()
    # loss_fn is built with reduction='sum', so dividing the total by the
    # dataset size gives the average loss per test sample.
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    test_counter.append(len(train_loader.dataset) * epoch)
    print(f'Test result on epoch {epoch}: Avg loss is {test_loss}')
# Histories appended to by train() (losses + sample counters every 10 batches)
# and test() (one entry per epoch).
train_losses, train_counter = [], []
test_losses, test_counter = [], []

max_epoch = 3
# Epochs are numbered from 1; train() uses (epoch - 1) when computing its
# cumulative sample counter.
epoch_range = range(1, max_epoch + 1)
for epoch in epoch_range:
    train(epoch, device=device)
    test(epoch, device=device)