How to store and load a custom PyTorch 'Dataset'?

I have a dataframe with a single column named ‘address’, which holds address strings for different places. I am going to feed this data to RoBERTa for pretraining on the masked language modelling task, so I am trying to convert it into a PyTorch Dataset object, as follows:

import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self, df, tokenizer):
        # or use the RobertaTokenizer from `transformers` directly.

        self.examples = []
        
        # MAX_LEN is assumed to be defined earlier in the script.
        for example in df.values:
            x = tokenizer.encode_plus(example, max_length=MAX_LEN, truncation=True, padding=True)
            self.examples.append(x.input_ids)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        # We’ll pad at the batch level.
        return torch.tensor(self.examples[i])

# Create the train and evaluation dataset
train_dataset = CustomDataset(train['address'], tokenizer)
eval_dataset = CustomDataset(test['address'], tokenizer)

Everything works fine up to this point. Now I want to store ‘train_dataset’ and ‘eval_dataset’ so that I can use them later on a GPU instance. I tried to do it with torch.save, as follows:

torch.save(train_dataset, './train.pt')
torch.save(eval_dataset, './test.pt')

This works fine too. But when I try to load the saved dataset into a new variable on another instance, it throws an AttributeError. The code I used to load the dataset is as follows:

train_dataset = torch.load('./train.pt')

It throws the following error:

AttributeError: Can't get attribute 'CustomDataset' on <module '__main__'>

I think this error is caused by the custom dataset class that I created. I want to save these datasets to a file so that I can use them on another instance with a GPU. Could someone help with the code required to load this data from a ‘.pt’ file?

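What the error is telling you is that Python cannot find the CustomDataset class while loading the file. torch.save pickles the object, and a pickle stores only a reference to the class (its module and name), not the class definition itself. So the class has to be defined or importable in the script that calls torch.load.

I think if you import the class into the script that does the loading, it should work.

So it would look something like: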

from YOUR_FILE_LOCATION import CustomDataset

And then you can load it.
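
To make the layout concrete, here is a minimal sketch of that approach, under a few assumptions. The module name my_datasets.py is hypothetical and stands in for YOUR_FILE_LOCATION. The class is simplified to take already-tokenized ID lists, and the token IDs are made up. The module is written to a temporary directory only so that the snippet is self-contained; in practice it would be a normal file that both the saving and the loading script import. Recent PyTorch releases default torch.load to weights_only=True, which refuses custom classes, so the sketch passes weights_only=False. Drop that argument on older versions that do not accept it.

```python
import os
import sys
import tempfile
import textwrap

import torch

workdir = tempfile.mkdtemp()

# Hypothetical shared module holding the class definition (simplified:
# it receives token-ID lists instead of a dataframe and a tokenizer).
module_source = textwrap.dedent(
    """
    import torch
    from torch.utils.data import Dataset


    class CustomDataset(Dataset):
        def __init__(self, examples):
            self.examples = examples

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, i):
            return torch.tensor(self.examples[i])
    """
)
with open(os.path.join(workdir, "my_datasets.py"), "w") as f:
    f.write(module_source)

# Make the module importable; both scripts would use this same import.
sys.path.insert(0, workdir)
from my_datasets import CustomDataset

# Saving side: the pickle records "my_datasets.CustomDataset" by name.
path = os.path.join(workdir, "train.pt")
torch.save(CustomDataset([[0, 100, 2], [0, 200, 300, 2]]), path)

# Loading side: the class is importable, so unpickling can resolve it.
# weights_only=False allows arbitrary pickled objects; use it only on
# files you trust.
train_dataset = torch.load(path, weights_only=False)
```

Because the class lives in its own module rather than in the script being run, the saved file refers to it by that module name, and any script that can import the module can load the dataset.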