hi, I’m trying to make patch (32,32) for every step 4 from a (256,256) image, it means that the patches are overlapped, I’ve got success when not feed the image to the Dataloader, but when I use Datalodaer to load the image, it shows, the function not work at all. the length of mr_patches return to 0.
I don’t know why, could anyone help me?
Any help would be greatly appreciated.
here is my code to get patches:
def ext_patches_2d_ovl(imgs):
patch_h=32
patch_w=32
stride_h=4
stride_w=4
p = []
img_h = imgs.shape[1] #height of the full image
img_w = imgs.shape[0] #width of the full image
i = 0
for w in range((img_w-patch_w)//stride_w+1): #extract by row
for h in range((img_h-patch_h)//stride_h+1):
patch = imgs[(w*stride_w) : (w*stride_w + patch_w), (h*stride_h):(h*stride_h + patch_h)]
patches= patch
i +=1 #total
p.append(patches)
# print(p[0].shape)
return p
test_mri = os.path.join(test_dir, "mri_data")
for root, dir, file in os.walk(test_mri):
for name in file:
test_data = MyDatasetPaired2d(data_dir=os.path.join(root,name))
test_loader = DataLoader(dataset=test_data, batch_size=1)
with torch.no_grad():
for iter, input_mr_slices in enumerate(test_loader):
mr_patches = ext_patches_2d_ovl(input_mr_slices)
print(len(mr_patches))
class MyDatasetPaired2d(Dataset):
def __init__(self, data_dir, transform=None): #???
self.data_info_mri= self.get_img_info(data_dir)
self.transform = transform
def __getitem__(self, index): # ??index ????paired??
mri_img = self.data_info_mri[index,:,:]
mri_img = mri_img[np.newaxis,:,:]
return mri_img
def __len__(self):
print(len(self.data_info_mri))
return len(self.data_info_mri)
@staticmethod
def get_img_info(data_dir):
mri_data = read_img(data_dir)
return mri_data