Custom Dataset raising NotImplementedError on __getitem__

Hello!

I’m not trying to do anything fancy, so I’m quite confused about what’s going wrong. I’m loading time series data from a CSV into a custom Dataset, but `__getitem__` is raising a NotImplementedError. Here’s the Dataset creation code:

```python
import torch
import torch.nn as nn
import numpy as np

# Data loading stuff
import pandas as pd
from torch.utils.data import DataLoader, Dataset


class WeatherData(Dataset):
    def __init__(self, csv_file, window, transform=None):
        self.data = pd.read_csv(csv_file)
        self.window = window
        self.transform = transform

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Input window, and the same window shifted one step ahead (column 0 skipped)
        x = self.data.iloc[idx:idx + self.window, 1:].values
        y = self.data.iloc[idx + 1:idx + self.window + 1, 1:].values
        if self.transform is not None:
            x, y = self.transform(x), self.transform(y)
        return x, y

    def __len__(self):
        return len(self.data) - self.window - 1


class ToTensor(object):
    def __call__(self, sample):
        # float32 keeps fractional readings; torch.long would truncate them
        return torch.tensor(sample, dtype=torch.float32)
```
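As a quick check of the windowing logic that needs neither the real CSV nor the separate module, here is a minimal sketch. It assumes, hypothetically, a CSV whose first column is a date and whose other columns are numeric readings; the in-memory data and the `get_window` helper are made up for illustration. It shows that `y` is `x` advanced by one row, and why a float dtype suits readings like these better than `torch.long`.

```python
import io

import pandas as pd
import torch

# Hypothetical stand-in for data/weather.csv: a date column plus two numeric columns
csv_text = "date,temp,humidity\n" + "\n".join(
    f"2020-01-{d:02d},{10 + d * 0.5},{60 + d}" for d in range(1, 11)
)
data = pd.read_csv(io.StringIO(csv_text))
window = 3


def get_window(idx):
    # Same slicing as WeatherData.__getitem__: skip column 0, shift the target by one row
    x = data.iloc[idx:idx + window, 1:].values
    y = data.iloc[idx + 1:idx + window + 1, 1:].values
    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)


x, y = get_window(0)
print(x.shape, y.shape)            # torch.Size([3, 2]) torch.Size([3, 2])
print(torch.equal(x[1:], y[:-1]))  # True: y is x advanced one time step

# With dtype=torch.long the fractional part of each reading is dropped
truncated = torch.tensor(data.iloc[0:1, 1:].values, dtype=torch.long)
print(truncated)                   # tensor([[10, 61]]): 10.5 became 10
```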

I’m then running this code to test the dataset:

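```python
# Pass an instance: transform=ToTensor would call ToTensor(x) and raise a TypeError
mydataset = WeatherData('data/weather.csv', 30, transform=ToTensor())

l = len(mydataset)
print(l)

for i, s in enumerate(mydataset):
    print(i, s)
    if i == 2:
        break
```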

It gives the following output/error message:

```
52665

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [41], in <cell line: 12>()
      9 l = len(mydataset)
     10 print(l)
---> 12 for i, s in enumerate(mydataset):
     13     print(i, s)
     14     if i ==2:

File ~/.local/lib/python3.10/site-packages/torch/utils/data/dataset.py:53, in Dataset.__getitem__(self, index)
     52 def __getitem__(self, index) -> T_co:
---> 53     raise NotImplementedError

NotImplementedError: 
```

So as far as I can tell, the data loading is fine, but something about my __getitem__ function is broken. Thanks for any help!
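Reading the traceback, the frame that raised is torch’s own `Dataset.__getitem__` (dataset.py, line 53), not the `__getitem__` written in `WeatherData`. That suggests the class actually in use had no `__getitem__` of its own, so the base-class placeholder ran. The error surfaces on the `for` line because `Dataset` defines no `__iter__`, so Python iterates by calling `ds[0]`, `ds[1]`, ... until an `IndexError`. A minimal sketch with two hypothetical classes, `Missing` and `Present`, reproduces both points; `"__getitem__" in WeatherData.__dict__` is the same check applied to the real class.

```python
from torch.utils.data import Dataset


class Missing(Dataset):
    """Hypothetical Dataset with __len__ but no __getitem__ of its own."""

    def __len__(self):
        return 3


class Present(Dataset):
    """Hypothetical Dataset that does override __getitem__."""

    def __len__(self):
        return 3

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(idx)  # this is what ends a plain for-loop over the dataset
        return idx * 10


# Dataset defines no __iter__, so `for s in ds` calls ds[0], ds[1], ... in turn
try:
    for s in Missing():
        pass
    missing_error = None
except NotImplementedError as exc:
    # Same failure as in the traceback: the inherited Dataset.__getitem__ ran
    missing_error = exc

print(type(missing_error).__name__)       # NotImplementedError
print("__getitem__" in Missing.__dict__)  # False: only the inherited method exists
print("__getitem__" in Present.__dict__)  # True
print(list(Present()))                    # [0, 10, 20]
```

One caveat for the original `WeatherData`: its `__getitem__` only slices with `iloc`, and out-of-range `iloc` slices return a short or empty frame instead of raising `IndexError`, so without the `break` a bare `for` loop would not stop at `len(mydataset)`. Iterating over `range(len(mydataset))`, or wrapping the dataset in a `DataLoader` (whose default sequential sampler uses `len()`), keeps the indices in bounds.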

Update: I was able to get my code to work, though I don’t understand what was wrong to begin with. Before, I defined WeatherData in its own .py file, and imported it with

```python
from weatherdataloading import WeatherData
```

In fact, any Dataset class I define in a separate file and import gives me the same error. However, defining the Dataset class in the main file and otherwise running the code unchanged works fine.
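One plausible explanation, which the thread itself does not confirm: the traceback’s `Input In [41]` suggests a Jupyter/IPython session, and Python executes a module only on its first import, caching it in `sys.modules`. Later `import` or `from ... import` statements reuse that cached module, so if `weatherdataloading` was first imported before `__getitem__` was written or saved, the notebook would keep using the old class until the module is reloaded. The sketch below simulates this with a throwaway module; the name `weatherdemo` and its contents are hypothetical.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Avoid a stale .pyc being reused when the file is rewritten within the same second
sys.dont_write_bytecode = True

tmp = Path(tempfile.mkdtemp())
module_file = tmp / "weatherdemo.py"  # hypothetical module name
module_file.write_text("VERSION = 'old'\n")
sys.path.insert(0, str(tmp))
importlib.invalidate_caches()

import weatherdemo

first = weatherdemo.VERSION

# Simulate editing and saving the file while the kernel keeps running
module_file.write_text("VERSION = 'new, edited'\n")
from weatherdemo import VERSION as second  # served from the sys.modules cache

importlib.reload(weatherdemo)  # re-executes the edited source
third = weatherdemo.VERSION

print(first, "|", second, "|", third)  # old | old | new, edited
```

If this is the cause, restarting the kernel, or running `import weatherdataloading`, then `importlib.reload(weatherdataloading)`, then repeating `from weatherdataloading import WeatherData`, should pick up the edited class; IPython’s `%load_ext autoreload` followed by `%autoreload 2` automates this. Objects such as `mydataset` that were created from the old class need to be recreated either way.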