I am implementing cycle Gans .
My dataloader is written as : class MyDataset(Dataset) :
def init(self , patch_CT, patch_PET):
‘characterizes a dataset for pytorch’
self.patch_CT = patch_CT
self.patch_PET= patch_PET
self.transforms = transforms.ToTensor
def __len__(self):
'denotes the total number of samples'
return len(self.patch_CT)
print(len(self.patch_PET))
def __getitem__(self,index):
'Generates one sample of data'
#select sample
x = self.patch_CT[index]
y = self.patch_PET[index]
# Unsqueeze channel dimension
#x = x.unsqueeze(0)
#y = y.unsqueeze(0)
return x,y
train_A_dataset, test_A_dataset = torch.utils.data.random_split(
patch_CT ,[int(0.7len(patch_CT)), int(len(patch_CT) - int(0.7len(patch_CT)))])
train_B_dataset, test_B_dataset = torch.utils.data.random_split(
patch_PET ,[int(0.7len(patch_PET)), int(len(patch_PET) - int(0.7len(patch_PET)))])
and for dataloader:
a_loader = torch.utils.data.DataLoader(train_A_dataset, batch_size=5, shuffle=False
b_loader = torch.utils.data.DataLoader(train_B_dataset, batch_size=5, shuffle=False)
for i,batch in enumerate((a_loader, b_loader)):
# step
step = epoch * min(len(a_loader), len(b_loader)) + i + 1
# Generator Computations
##################################################
set_grad([self.Da, self.Db], False)
self.g_optimizer.zero_grad()
a_real = Variable(batch[0])
b_real = Variable(batch[0])
a_real, b_real = utils.cuda([a_real, b_real])
The error I get is
a_real = Variable(batch[0])
TypeError: ‘DataLoader’ object does not support indexing
How can I solve this?
Thank you in advance.