Why doesn't the DataLoader load the patches?

Hi, I'm trying to extract (32, 32) patches with a step of 4 from a (256, 256) image, so the patches overlap. It works when I don't feed the image through the DataLoader, but when I use a DataLoader to load the image, the function doesn't work at all: the length of mr_patches is 0.
I don't know why. Could anyone help me?

Any help would be greatly appreciated.
Here is my code to extract the patches:

patch_h, patch_w = 32, 32    # patch size
stride_h, stride_w = 4, 4    # step between neighbouring patches

def ext_patches_2d_ovl(imgs):
    p = []
    img_h = imgs.shape[1]  # height of the full image
    img_w = imgs.shape[0]  # width of the full image
    for w in range((img_w - patch_w) // stride_w + 1):  # extract by row
        for h in range((img_h - patch_h) // stride_h + 1):
            patch = imgs[(w*stride_w):(w*stride_w + patch_w), (h*stride_h):(h*stride_h + patch_h)]
            p.append(patch)
    return p
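For a single 2D slice, the same overlapping grid can be produced without Python loops. The sketch below is an alternative to the function above, not the poster's code: it assumes NumPy >= 1.20 (for `sliding_window_view`) and uses a synthetic 256x256 array as a stand-in for one MRI slice. It also shows the expected patch count: (256 - 32) // 4 + 1 = 57 positions per axis, so 57 * 57 = 3249 patches.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # requires NumPy >= 1.20

patch_h, patch_w = 32, 32   # patch size from the question
stride_h, stride_w = 4, 4   # step between neighbouring patches

# Synthetic stand-in for one 2D slice (assumption, not real MRI data).
img = np.arange(256 * 256, dtype=np.float32).reshape(256, 256)

# Every 32x32 window at stride 1, then keep every 4th window along each axis.
windows = sliding_window_view(img, (patch_h, patch_w))[::stride_h, ::stride_w]
print(windows.shape)  # (57, 57, 32, 32), since (256 - 32) // 4 + 1 == 57

# Flatten the grid row by row, i.e. the same order as the nested loop above.
patches = windows.reshape(-1, patch_h, patch_w)
print(patches.shape)  # (3249, 32, 32)

# Spot check one patch against explicit slicing.
n_cols = windows.shape[1]
r, c = 3, 5
expected = img[r * stride_h : r * stride_h + patch_h, c * stride_w : c * stride_w + patch_w]
assert np.array_equal(patches[r * n_cols + c], expected)
```

`windows` is a read-only view into `img`; the `reshape` copies the data, which is fine for 3249 small patches.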

import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

test_mri = os.path.join(test_dir, "mri_data")
for root, dirs, files in os.walk(test_mri):
    for name in files:
        test_data = MyDatasetPaired2d(data_dir=os.path.join(root, name))
        test_loader = DataLoader(dataset=test_data, batch_size=1)
        with torch.no_grad():
            for i, input_mr_slices in enumerate(test_loader):
                mr_patches = ext_patches_2d_ovl(input_mr_slices)
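The zero length can be reproduced without PyTorch. The sketch below assumes `read_img` returns an (N, 256, 256) array, so each sample from `__getitem__` is (1, 256, 256) after `np.newaxis`. PyTorch's default collate stacks samples along a new leading axis; `np.stack` mimics that here (the real DataLoader also converts the result to a `torch.Tensor`).

```python
import numpy as np

# Hypothetical stand-in for one sample from MyDatasetPaired2d.__getitem__:
# a (256, 256) slice with a channel axis added -> (1, 256, 256).
sample = np.zeros((256, 256), dtype=np.float32)[np.newaxis, :, :]

# Default collate adds a batch axis in front (batch_size=1).
batch = np.stack([sample])
print(batch.shape)  # (1, 1, 256, 256)

patch_h = patch_w = 32
stride_h = stride_w = 4

# The indexing used in ext_patches_2d_ovl now reads the wrong axes:
img_h = batch.shape[1]  # 1: the channel axis, not the height
img_w = batch.shape[0]  # 1: the batch axis, not the width
n_rows = len(range((img_w - patch_w) // stride_w + 1))  # (1 - 32) // 4 + 1 == -7, empty range
n_cols = len(range((img_h - patch_h) // stride_h + 1))
print(n_rows * n_cols)  # 0
```

Both loops iterate over `range(-7)`, which is empty, so no patch is ever appended.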

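class MyDatasetPaired2d(Dataset):

    def __init__(self, data_dir, transform=None):
        # read_img is my own loader; assumed to return an array of shape (N, H, W)
        self.data_info_mri = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        mri_img = self.data_info_mri[index, :, :]   # one (H, W) slice
        mri_img = mri_img[np.newaxis, :, :]         # add channel axis -> (1, H, W)
        if self.transform is not None:
            mri_img = self.transform(mri_img)
        return mri_img

    def __len__(self):
        return len(self.data_info_mri)

    @staticmethod
    def get_img_info(data_dir):
        # called as self.get_img_info(...), so it must not expect `self` as data_dir
        mri_data = read_img(data_dir)
        return mri_data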

Your DataLoader returns a batch of samples and thus adds a batch dimension.
The height and width calculation in ext_patches_2d_ovl will thus be wrong:

img_h = imgs.shape[1]  #height of the full image
img_w = imgs.shape[0] #width of the full image

so you might need to change them to index 2 and index 1, respectively, i.e. shift each index by one to skip the batch dimension.
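Since `__getitem__` also adds a channel axis via `np.newaxis`, each batch here would likely be (1, 1, 256, 256) (assuming `read_img` returns an (N, 256, 256) array), and the slicing inside the loop also has to target the spatial axes. A minimal sketch that sidesteps the axis bookkeeping reads the height and width from the last two axes and keeps any leading axes with `...`. The name `ext_patches_2d_ovl_any` is hypothetical, and the NumPy input stands in for the DataLoader batch; the same `shape[-2]` and `[..., a:b, c:d]` indexing also works on a `torch.Tensor`.

```python
import numpy as np

def ext_patches_2d_ovl_any(imgs, patch_h=32, patch_w=32, stride_h=4, stride_w=4):
    """Overlapping patches taken from the last two (spatial) axes.

    Works for (H, W), (C, H, W) or (N, C, H, W) inputs: leading axes are kept
    via Ellipsis indexing, so each patch has shape (..., patch_h, patch_w).
    """
    img_h, img_w = imgs.shape[-2], imgs.shape[-1]
    patches = []
    for top in range(0, img_h - patch_h + 1, stride_h):        # rows
        for left in range(0, img_w - patch_w + 1, stride_w):   # columns
            patches.append(imgs[..., top:top + patch_h, left:left + patch_w])
    return patches

# Stand-in for one batch as produced in this thread (batch_size=1, one channel).
batch = np.zeros((1, 1, 256, 256), dtype=np.float32)
mr_patches = ext_patches_2d_ovl_any(batch)
print(len(mr_patches), mr_patches[0].shape)  # 3249 (1, 1, 32, 32)
```

Alternatively, with `batch_size=1` and a single channel, indexing the batch as `input_mr_slices[0, 0]` gives back a 2D (256, 256) tensor that the original `ext_patches_2d_ovl` handles unchanged.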

Thanks so much! It works! 😃

Hi ptrblck, could you help me? I posted a new question about multi-label classification.
I've tried, but it's really hard to understand what my problem is.