For each training image, I want to randomly crop four 224*224*3 patches and together with their mirrored ones compose a 8 * 224 * 224 * 3 minibatch. Certainly, the mask is processed in the same way.
I implement a class dataset, and I concat 8 patches in it's function '__getitem__' and return it. In the dataloader, I get the tensor [ 1 * 8 * 224 * 224 * 3], so i need to Squeeze it. But after that, I find that all the patches are changed which are not fliped. Meanwhile, I find that the RandomCrop is not a 'Really Random'.
This is my code.
class dataset(data.Dataset):
def __init__(self, img_path, mask_path):
self.img_path = img_path
self.mask_path = mask_path
def transform(self, image, mask):
image_res, mask_res = [], []
for i in range(1):
i, j, h, w = transforms.RandomCrop.get_params(
image, output_size=(224, 224)
)
print(i, j, h, w)
image = tvf.crop(image, i, j, h, w)
mask = tvf.crop(mask, i, j, h, w)
image_flip = tvf.hflip(image)
#image.show()
#image_flip.show()
mask_flip = tvf.hflip(mask)
image_res.append(np.array(image))
image_res.append(np.array(image_flip))
mask_res.append(np.array(mask))
mask_res.append(np.array(mask_flip))
return np.array(image_res), np.array(mask_res)
def __getitem__(self, item):
image = Image.open(self.img_path)
mask = Image.open(self.mask_path)
print(image.size, mask.size)
image_res, mask_res = self.transform(image, mask)
for i in range(image_res.shape[0]):
img = Image.fromarray(image_res[i])
#img.show()
#totensor = transforms.ToTensor()
#image_res = totensor(image_res)
#mask_res = totensor(mask_res)
return torch.tensor(image_res), torch.tensor(mask_res)
def __len__(self):
return 1
for i, (x, y) in enumerate(dataloader):
print(x.size(), y.size(), '--')
img = Image.fromarray(x[0, 0].numpy())
# print(type(img))
img.show()
img = Image.fromarray(x[0, 1].numpy())
# print(type(img))
img.show()
x = torch.randn(1, 2, 2, 2)
print(x[0])
x = torch.squeeze(x, dim=0)
y = torch.squeeze(y, dim=0)
print(x)
for j in range(x.size()[0]):
tx = x[i]
ty = y[i]
img = Image.fromarray(tx.numpy())
#print(type(img))
#img.show()