Hey all,
I am using an IterableDataset because the lengths of my samples are not uniform and vary a lot. Now I am trying to implement a validation step in my training loop, which requires the length of the dataset, but since the dataset is iterable, len() cannot be used.
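To make the problem concrete, here is a minimal sketch (the class and helper names are hypothetical) showing that `len()` fails on an `IterableDataset` that does not define `__len__`. It also shows that a subclass can define `__len__` itself, but only when the number of samples happens to be known in advance, which is an assumption that may not hold for a true stream. Recent PyTorch versions raise `TypeError` in the first case; the helper also catches `NotImplementedError` purely as a precaution.

```python
import torch
from torch.utils.data import IterableDataset


class SequenceStream(IterableDataset):
    """Hypothetical stream of variable-length samples; defines no __len__."""

    def __init__(self, lengths):
        self.lengths = lengths

    def __iter__(self):
        for n in self.lengths:
            yield torch.ones(n)  # each sample has its own length


class SizedSequenceStream(SequenceStream):
    """Same stream, for the case where the sample count is known up front."""

    def __len__(self):
        return len(self.lengths)


def safe_len(dataset):
    """Return len(dataset), or None if the dataset does not define a length."""
    try:
        return len(dataset)
    except (TypeError, NotImplementedError):
        return None


print(safe_len(SequenceStream([3, 7, 2])))       # None
print(safe_len(SizedSequenceStream([3, 7, 2])))  # 3
```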
import torch

def val(self, epoch):
    val_loss = 0.0
    correct = 0
    total = 0  # running count of validation samples
    self.model.eval()
    with torch.no_grad():
        for data, target in self.val_loader:
            data, target = data.to(self.device), target.to(self.device)
            predictions = self.model(data.float())
            correct += (predictions == target).sum().item()
            total += target.size(0)
            val_loss += self.criterion(predictions.float(), target.float()).item()
    # len() fails here: the IterableDataset does not define __len__
    val_loss /= len(self.val_loader.dataset)
    accuracy = 100. * correct / len(self.val_loader.dataset)
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(self.val_loader.dataset), accuracy))
    # val_summary_writer and summary are defined elsewhere (TensorFlow-style summary API)
    with val_summary_writer.as_default():
        summary.scalar('val_loss', val_loss, step=self.globaliter)
        summary.scalar('accuracy', accuracy, step=self.globaliter)
This is my validation function.
Is there a way to get len() for an IterableDataset?
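One way around needing the length, sketched under assumptions: the loop already accumulates `total += target.size(0)`, so the averages can be taken over that running count instead of `len(dataset)`. The names `validate` and `ToyStream` are hypothetical, and the sketch assumes a multi-class classifier with integer class targets and a criterion that returns the batch mean (such as `nn.CrossEntropyLoss`), so each batch loss is multiplied back by the batch size before averaging.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset


def validate(model, loader, criterion, device="cpu"):
    """Return (mean loss, accuracy in %, sample count) for one pass over `loader`.

    The sample count is accumulated while iterating, so len(dataset) is never needed.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            logits = model(data.float())
            batch_size = target.size(0)
            # criterion is assumed to return the batch mean; undo that to get a per-sample sum
            loss_sum += criterion(logits, target).item() * batch_size
            correct += (logits.argmax(dim=1) == target).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("the loader yielded no samples")
    return loss_sum / total, 100.0 * correct / total, total


class ToyStream(IterableDataset):
    """Hypothetical stream of (features, class index) pairs with no __len__."""

    def __init__(self, items):
        self.items = items

    def __iter__(self):
        for x, y in self.items:
            yield torch.tensor(x), y


if __name__ == "__main__":
    items = [([2.0, 0.0], 0), ([0.0, 3.0], 1), ([1.0, 0.0], 1)]
    loader = DataLoader(ToyStream(items), batch_size=2)
    loss, acc, n = validate(nn.Identity(), loader, nn.CrossEntropyLoss())
    print(n, round(acc, 2))  # 3 66.67
```

Multiplying each batch loss by its batch size keeps the average correct when the last batch is smaller than the others, and the returned count is effectively the dataset length for that pass.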
Thanks!