Hi, I am quite new to PyTorch and FastAI, and I am getting this error while using fastai’s unet_learner and according to the traceback, the error is generated at learn.lr_find()
. Can you please let me know how I can fix this? Thanks in advance.
Here’s a part of my code:
class UNIZDataset(Dataset):
def __init__(self, folderlist):
self.c = 5
count = 0
for folder in tqdm(os.listdir(data_path)):
images, labels = self.extract_data(os.path.join(data_path,folder))
tensor_x = torch.from_numpy(np.array(images))
tensor_y = torch.from_numpy(np.array(labels))
if count == 0:
self.x = tensor_x
self.y = tensor_y
count+=1
else:
self.x = torch.cat((self.x,tensor_x), dim = 0)
self.y = torch.cat((self.y,tensor_y), dim = 0)
def extract_data(self, path):
x = []
y = []
data_x = os.path.join(path,'DATA_X')
data_y = os.path.join(path,'DATA_Y')
fileList = glob.glob(os.path.join(data_x,'*.tif'))
for file in sorted(fileList):
x.append(cv2.imread(os.path.join(data_x,file)))
y.append(cv2.imread(os.path.join(data_y,file)))
return (x,y)
def __getitem__(self, index):
image = self.x[index,:,:]
image = image.permute(2,0,1)
label = self.y[index,:,:]
label = label.permute(2,0,1)
return (image,label)
def __len__(self):
return self.x.size()[0]
train_dataset = UNIZDataset(train_files)
valid_dataset = UNIZDataset(valid_files)
databnch = DataBunch.create(train_dataset,valid_dataset, bs = 18)
learn = unet_learner(data=databnch, arch=models.resnet34)
learn.lr_find()