import math
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class MRSI_Dataset(Dataset):
    def __init__(self, data, engine):
        self.engine = engine
        # initialize dataset
        self.data = data
        self.t = torch.from_numpy(self.engine.t[0:data.shape[1]].T).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]  # the post had self.Data here, which is never defined
        return sample

The problem is that using my custom dataset makes the results quantitatively worse, as if some precision were being lost. Is there any difference between TensorDataset and a custom dataset?

TensorDataset just indexes every tensor passed to it and returns the results as a tuple, as seen in its implementation.
Your custom Dataset won't work, since you are indexing an undefined self.Data attribute (the attribute you set is self.data). Also, what is self.t used for?
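For comparison, here is a minimal sketch of a corrected custom dataset next to TensorDataset. It assumes data is already a torch.Tensor of shape (m, n) (the post does not show its type or dtype) and leaves out engine and self.t; the class name PlainDataset is hypothetical. Indexing either dataset returns the same values with the same dtype; the only difference is that TensorDataset wraps each sample in a tuple.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class PlainDataset(Dataset):
    """Hypothetical corrected version of the custom dataset (no augmentation)."""

    def __init__(self, data: torch.Tensor):
        self.data = data  # shape (m, n): m signals of length n

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]  # a single tensor of shape (n,)


data = torch.randn(6, 16, dtype=torch.float64)
custom = PlainDataset(data)
wrapped = TensorDataset(data)

x_custom = custom[2]
(x_wrapped,) = wrapped[2]  # TensorDataset returns a 1-tuple for a single tensor
print(torch.equal(x_custom, x_wrapped), x_custom.dtype)  # True torch.float64
```

Neither dataset changes the dtype, so if two pipelines still disagree, printing x.dtype in each path is a quick first check.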

Thank you. I modified TensorDataset; here is the working code:

from typing import Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset


class MRSI_Dataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor, engine) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors
        self.engine = engine  # keyword-only: pass as MRSI_Dataset(x, engine=engine)
        self.t = torch.from_numpy(self.engine.t[0:tensors[0].shape[1]].T).float()

    def __getitem__(self, index):
        # get_augment is the poster's augmentation method (not shown in the thread)
        return tuple(self.get_augment(tensor[index]) for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

Sorry for not giving a full description: self.t is a time vector used for sampling in the augmentation process.
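A minimal sketch of how a time vector like self.t could drive per-sample augmentation inside a Dataset. The real get_augment is not shown in the thread, so the augmentation here (random exponential damping, exp(-alpha * t)), the class name AugmentedSignals, and the max_damping parameter are all hypothetical, and t is passed in directly instead of being read from engine.t.

```python
import torch
from torch.utils.data import Dataset


class AugmentedSignals(Dataset):
    """Hypothetical dataset that augments one signal per __getitem__ call."""

    def __init__(self, signals: torch.Tensor, t: torch.Tensor, max_damping: float = 5.0):
        assert signals.shape[1] == t.shape[0], "time vector must match signal length"
        self.signals = signals  # (m, n): m signals of length n
        self.t = t              # (n,) sampling times
        self.max_damping = max_damping

    def get_augment(self, signal: torch.Tensor) -> torch.Tensor:
        # hypothetical augmentation: multiply by exp(-alpha * t) with a fresh random alpha
        alpha = torch.rand(()) * self.max_damping
        return signal * torch.exp(-alpha * self.t)

    def __len__(self):
        return self.signals.shape[0]

    def __getitem__(self, idx):
        return self.get_augment(self.signals[idx])


t = torch.linspace(0.0, 1.0, 32)
ds = AugmentedSignals(torch.ones(4, 32), t)
sample = ds[0]  # shape (32,), a decaying curve that starts at 1.0
```

Each call draws a new alpha, so every epoch sees different augmentations, but the work runs one signal at a time in Python.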

Before getting your answer, I found a workaround that uses vmap to apply the augmentation in the training step, as follows:

Apparently, it is faster. I'd appreciate your comments on this approach. Is it a standard method, or should I use a custom dataset?
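The vmap code from the post is not included above, so the following is only a hypothetical sketch of the general idea: augmenting a whole batch in the training step with torch.func.vmap (available in PyTorch 2.0+). The damping augmentation and the names damp and augment_batch are assumptions for illustration. The random parameters are drawn outside vmap, since vmap raises an error on random operations under its default randomness="error" setting.

```python
import torch
from torch.func import vmap


def damp(signal: torch.Tensor, t: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # hypothetical per-signal augmentation: exponential damping exp(-alpha * t)
    return signal * torch.exp(-alpha * t)


def augment_batch(batch: torch.Tensor, t: torch.Tensor, max_damping: float = 5.0) -> torch.Tensor:
    # draw one random alpha per signal outside vmap
    alpha = torch.rand(batch.shape[0], device=batch.device) * max_damping
    # map over dim 0 of batch and alpha; t is shared by every signal (in_dims=None)
    return vmap(damp, in_dims=(0, None, 0))(batch, t, alpha)


batch = torch.randn(8, 64)  # 8 signals of length 64
t = torch.linspace(0.0, 1.0, 64)
augmented = augment_batch(batch, t)
```

For an element-wise augmentation like this one, vmap gives the same result as plain broadcasting (batch * torch.exp(-alpha[:, None] * t)); either way the batch is processed in one vectorized call instead of one Python call per sample, which is a plausible reason it runs faster than per-sample augmentation in __getitem__.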

I don't fully understand the indexing: it seems you are indexing each tensor with the passed index, while I would have assumed the self.tensors object itself would be indexed. Could you explain what exactly is stored in self.tensors?

Thank you.
self.tensors contains a 2D matrix (m × n) in which each row is a signal (m signals of length n).
When I modified TensorDataset, it worked. Does the tuple(…) make a difference?
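To illustrate what the tuple changes, here is a small sketch with made-up data comparing a dataset whose __getitem__ returns a bare tensor (the hypothetical BareTensorDataset) with TensorDataset, which returns a tuple. DataLoader's default collate function stacks the samples in both cases, but with a tuple each batch arrives as a sequence of stacked tensors, so it has to be unpacked before use.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class BareTensorDataset(Dataset):
    """Hypothetical dataset returning a bare tensor per sample."""

    def __init__(self, data: torch.Tensor):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


data = torch.randn(10, 8)  # 10 signals of length 8

bare_batch = next(iter(DataLoader(BareTensorDataset(data), batch_size=4)))
tuple_batch = next(iter(DataLoader(TensorDataset(data), batch_size=4)))

print(type(bare_batch).__name__, tuple(bare_batch.shape))  # Tensor (4, 8)
print(len(tuple_batch))  # 1: a sequence holding one stacked tensor

(signals,) = tuple_batch  # unpack the single tensor from the batch
print(torch.equal(signals, bare_batch))  # True
```

The values are identical either way; the tuple only changes the container, so a training loop written for one form needs the matching unpacking for the other.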