torch.utils.data.Dataset.to()

Hi,
Is there a way to apply to() to a Dataset so that to() is applied to the data inside the Dataset?
It would be a lot more convenient if calling to() on an nn.Module to move it to some device also moved the data that the nn.Module uses.

The dataset is user-defined, so no.

I know I can write a to() function in the definition of the dataset class. What I mean is, I hope the recursive call of nn.Module.to() would automatically be applied to dataset.to() as well.
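For reference, the manual approach mentioned above might look like the following minimal sketch. It assumes the dataset keeps all of its data in memory as tensors; `TensorPairDataset` and its attributes are hypothetical names. Note that `nn.Module.to()` will never call this method, so it still has to be invoked separately.

```python
import torch
from torch.utils.data import Dataset


class TensorPairDataset(Dataset):
    """Hypothetical in-memory dataset that can move its tensors with to()."""

    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

    def to(self, *args, **kwargs):
        # Mirror nn.Module.to(): move/cast every stored tensor, return self for chaining
        self.inputs = self.inputs.to(*args, **kwargs)
        self.targets = self.targets.to(*args, **kwargs)
        return self
```

Usage would be two separate calls, e.g. `model.to(device)` followed by `dataset.to(device)`.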

But a module has no way of knowing which dataset object is used with it.


What about using multiple inheritance, i.e. myDataSet(nn.Module, Dataset), and overriding the to() method?
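To make the proposal concrete, here is a minimal sketch of what it would involve, assuming the data fits in memory as a single tensor. The mechanism it relies on is that `nn.Module.to()` recursively moves parameters and registered buffers of all submodules, so storing the data with `register_buffer` and attaching the dataset to the model as an attribute is enough; overriding `to()` is not actually required. `Model` is a hypothetical example model, and the buffer is registered with `persistent=False` so the data is not saved in `state_dict()`.

```python
import torch
from torch import nn
from torch.utils.data import Dataset


class MyDataSet(nn.Module, Dataset):
    """Dataset that is also an nn.Module, so its buffers follow Module.to()."""

    def __init__(self, data: torch.Tensor):
        super().__init__()
        # Buffers are moved/cast by nn.Module.to(), just like parameters;
        # persistent=False keeps the data out of state_dict()
        self.register_buffer("data", data, persistent=False)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]


class Model(nn.Module):
    """Hypothetical model that holds a reference to its dataset."""

    def __init__(self, dataset: MyDataSet):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        # Assigning an nn.Module attribute registers it as a submodule,
        # so model.to(...) recurses into the dataset as well
        self.dataset = dataset

    def forward(self, x):
        return self.linear(x)
```

Calling `model.to("cuda")` would then move the data along with the weights. The catch is that the dataset would hand out GPU tensors, which does not mix well with multiprocessing DataLoader workers.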

Dataset and module serve different purposes, and I don't think it makes sense to merge them. Moreover, most of the time you want the dataset to return CPU tensors, load them with a multiprocessing DataLoader, and then convert the fetched batches to GPU tensors before using them with the model. So I would suggest writing a custom class that wraps a dataloader and a module and, when it is asked for a dataloader iterator, returns something like:

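def iter_with_convert(self):
    # Assumes each batch is a single tensor; batches arrive on the CPU
    # (possibly from worker processes) and are moved to the device here
    for tensor in self.dataloader:
        yield tensor.to(self.device)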
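A fuller version of that wrapper might look like the sketch below. The class name `DeviceLoader`, the helper `move_to`, and the choice to move the model inside `to()` are assumptions for illustration. `move_to` also handles batches that are lists, tuples, or dicts of tensors, since the default collate function turns tuple samples (such as those from `TensorDataset`) into a list of batched tensors.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def move_to(batch, device):
    """Recursively move tensors (or lists/tuples/dicts of tensors) to `device`."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return [move_to(b, device) for b in batch]  # always returns a list
    if isinstance(batch, dict):
        return {k: move_to(v, device) for k, v in batch.items()}
    return batch


class DeviceLoader:
    """Hypothetical wrapper pairing a model with a (CPU) DataLoader."""

    def __init__(self, model: nn.Module, dataloader: DataLoader, device="cpu"):
        self.model = model
        self.dataloader = dataloader
        self.device = torch.device(device)
        self.model.to(self.device)

    def to(self, device):
        # Move the model and remember the device for future batches
        self.device = torch.device(device)
        self.model.to(self.device)
        return self

    def iter_with_convert(self):
        # Workers produce CPU tensors; the device transfer happens here
        for batch in self.dataloader:
            yield move_to(batch, self.device)

    def __iter__(self):
        return self.iter_with_convert()

    def __len__(self):
        return len(self.dataloader)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(10, 3), torch.randn(10, 1))
    runner = DeviceLoader(nn.Linear(3, 1), DataLoader(dataset, batch_size=4))
    runner.to("cuda" if torch.cuda.is_available() else "cpu")
    for x, y in runner:
        loss = F.mse_loss(runner.model(x), y)
```

This keeps the Dataset free of device logic, so it can still return CPU tensors to multiprocessing workers, while a single `to()` call on the wrapper covers both the model and the data it is fed.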