Index problem with the DataLoader after reshaping the dataset

Hello,

I have a problem with iterating over my data. I am roughly aware of why this is happening, but I have no solution for it.
My task is to pass a dataset of hyperspectral images through a convolutional autoencoder. Since I am not interested in the local (spatial) information, I want to build a 1D convolutional autoencoder. My problem now is loading the data in the shape this requires.

The dataset exists in the following form:
[18000 x 91] = [pixels, channels]

Now I transform this dataset back into the original form of the images:
[45 x 91 x 400] = [images, channels, pixels per image]

The code for loading the data is shown below.


import numpy as np
import torch
from torch.utils.data import Dataset


class SpectralDataset(Dataset):
    def __init__(self):
        xy = np.loadtxt(Dataset_Path, delimiter=",", dtype=np.float32)
        self.len = xy.shape[0]  # length taken from the raw 2D array (18000 rows)
        # pylint: disable=E1101  # suppresses a false positive in VS Code
        x = torch.from_numpy(xy[:, 2:])
        x = torch.reshape(x, (45, 400, 91))
        self.x = x.permute(0, 2, 1)  # -> [images, channels, pixels]
        # pylint: enable=E1101  # suppresses a false positive in VS Code

    def __getitem__(self, index):
        return self.x[index]

    def __len__(self):
        return self.len


dataset = SpectralDataset()

Since I want to do batch training on shuffled data, I also use PyTorch's DataLoader. But it now gives me index values that are far beyond the range [0, 45).
The DataLoader uses index values based on the original dataset, i.e. [0, 18000).
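
For illustration, a minimal sketch of how I iterate (the batch size is just an example value); the shuffled sampler draws its indices from range(len(dataset)), which here is 18000:

from torch.utils.data import DataLoader

# len(dataset) is 18000, so the sampler draws indices up to 17999,
# but self.x only has 45 entries along dimension 0.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for batch in loader:  # fails with an IndexError once an index >= 45 is drawn
    print(batch.shape)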

How can you avoid this problem?

Hello,
just found my mistake.

The DataLoader works with the __len__ function, and I had tied it to the previous data shape (the 18000-row array instead of the 45 reshaped images). So my problem is solved.

The code should look like the following.

class SpectralDataset(Dataset):
    def __init__(self):
        xy = np.loadtxt(Dataset_Path, delimiter=",", dtype=np.float32)
        # pylint: disable=E1101  # suppresses a false positive in VS Code
        x = torch.from_numpy(xy[:, 2:])
        x = torch.reshape(x, (45, 400, 91))
        self.x = x.permute(0, 2, 1)  # [images, channels, pixels]
        self.len = x.shape[0]  # now 45, taken from the reshaped tensor
        # pylint: enable=E1101  # suppresses a false positive in VS Code

    def __getitem__(self, index):
        return self.x[index]

    def __len__(self):
        return self.len


dataset = SpectralDataset()
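
A quick sanity check (batch_size=5 is just an example value) now iterates cleanly and yields batches of the expected shape:

from torch.utils.data import DataLoader

# len(dataset) is now 45, so only valid indices are drawn and
# each batch has the shape [batch, channels, pixels] = [5, 91, 400].
loader = DataLoader(dataset, batch_size=5, shuffle=True)

for batch in loader:
    print(batch.shape)  # torch.Size([5, 91, 400])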