Hi
I have tried to implement own my Dataset & DataLoader for my neural net, but I encounterd ‘NotImplementError’.
Below is my code:
class GenDataset(Dataset):
def __init__(self, pddataset, args):
super(GenDataset, self).__init__()
self.pddataset = pddataset
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
torch.set_default_tensor_type('torch.FloatTensor')
def __len__(self):
return len(self.pddataset)
def _getitem__(self, index):
return self.pddataset[index]
And Below is my PyTorch DataLoader code that I implemented.
import os, sys
sys.path.append('/home/byun/PROject/DEEPSEEKER/')
import argparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from FUNC.GenDataset import GenDataset
def DATAset_partition(args):
########################################
#
# Generate setting variables
#
########################################
pardir = args.pardir
p_dir = os.getcwd()
Data_path = f'{pardir}/DATA'
df = pd.read_csv(f'{pardir}/CSV/All_HistoData-1.csv')
print(f'# The loaded Data shape: {df.shape}')
train_set = df.iloc[: n_train, :]
val_set = df.iloc[: n_val, :]
test_set = df.iloc[: n_test, :]
print(f'# The train: {train_set.shape}')
print(f'# The validation: {val_set.shape}')
print(f'# The test: {test_set.shape}')
########################################
#
# Load Data
#
########################################
train_loader = GenDataset(train_set, args)
valid_loader = GenDataset(val_set, args)
test_loader = GenDataset(test_set, args)
partition = {'train': train_loader,
'val': valid_loader,
'test': test_loader}
return partition
if __name__ == '__main__':
parser = argparse.ArgumentParser()
args = parser.parse_args("")
args.dataset_mode = 'train'
args.num_total = 1000
args.val_size = 0.2
args.test_size = 0.1
args.train_batch_size = 50
args.pardir = '/home/byun/PROject/DEEPSEEKER'
partition = DATAset_partition(args)
trainloader = DataLoader(partition['train'], batch_size=args.train_batch_size, shuffle=True)
for idx, data in enumerate(trainloader):
print(data)
And the picture below is the error I met.
As I searched, the ‘NotImplementedError’ is due to indentation. But I couldn’t find any indentation typo.
Please let me know what am i miss…
Thank you for your reply in advance!