len() for IterableDataset

Hey all,

I am using an IterableDataset because the lengths of my samples are not uniform and vary a lot. Now I am trying to implement a validation step in my training loop that requires the length of the dataset, but since the dataset is iterable, len() cannot be used.

def val(self, epoch):
    val_loss = 0.0
    correct = 0
    total = 0  # running count of samples seen so far

    with torch.no_grad():
        for data, target in self.val_loader:
            data, target = data.to(self.device), target.to(self.device)
            predictions = self.model(data.float())
            correct += (predictions == target).sum().item()
            total += target.size(0)
            val_loss += self.criterion(predictions.float(), target.float()).item()

        # These len() calls are the problem: an IterableDataset has no length.
        val_loss /= len(self.test_loader.dataset)
        accuracy = 100. * correct / len(self.test_loader.dataset)

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            val_loss, correct, len(self.test_loader.dataset), accuracy))

        with val_summary_writer.as_default():
            # val_loss is already a Python float here, so no .item() call.
            summary.scalar('val_loss', val_loss, step=self.globaliter)
            summary.scalar('accuracy', accuracy, step=self.globaliter)

This is my validation function.

Is there a way to get len() for an IterableDataset?


I think your approach of storing the number of samples in total is valid for your use case.
Instead of dividing by len(self.test_loader.dataset), just divide by total.
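
To make the "divide by total" idea concrete, here is a rough, framework-agnostic sketch in plain Python. The names `summarize` and `fake_loader` are made up for illustration, and the generator only stands in for a DataLoader over an IterableDataset (a generator has no len() either). It also assumes your criterion returns a per-batch mean, which is the default reduction for most PyTorch losses, so each batch loss is weighted by its batch size before averaging.

```python
from typing import Iterable, Iterator, Tuple

# Each item is (mean_loss_of_batch, num_correct_in_batch, batch_size).
BatchStats = Tuple[float, int, int]


def summarize(batches: Iterable[BatchStats]) -> Tuple[float, float, int]:
    """Return (average loss per sample, accuracy in percent, total samples)."""
    loss_sum = 0.0
    correct = 0
    total = 0  # counted during iteration, so len() is never needed
    for mean_loss, n_correct, batch_size in batches:
        # Weight the batch mean by its size so uneven batches average correctly.
        loss_sum += mean_loss * batch_size
        correct += n_correct
        total += batch_size
    if total == 0:
        raise ValueError("no samples were seen")
    return loss_sum / total, 100.0 * correct / total, total


def fake_loader() -> Iterator[BatchStats]:
    """Hypothetical stand-in for a DataLoader over an IterableDataset."""
    yield (0.5, 3, 4)
    yield (1.0, 1, 2)


avg_loss, accuracy, total = summarize(fake_loader())
print(avg_loss, accuracy, total)
```

In your val() function the three numbers per batch would be the criterion output's .item(), (predictions == target).sum().item(), and target.size(0). If your criterion already sums over the batch, you would skip the multiplication by the batch size.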