Dear community,
I’m struggling to correctly load an image dataset for an application. I have image pairs (geometry image and velocity flow image) that will be used to train a model to estimate the velocity flow around a geometry image. The application can be thought as geometry image as the sample and velocity flow image as the label. The training images are being resized into (64, 64) format.
from PIL import Image
import glob
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
# Get image paths
mesh_paths = glob.glob('./mesh_data/*.tif')
vel_paths = glob.glob('./vel_data/*.tif')
# Separate training samples
train_mesh_paths = mesh_paths[:int(total_samples*train_size)]
train_vel_paths = vel_paths[:int(total_samples*train_size)]
Then I define a custom data_loader class to load the data:
class mesh_vel_dataset(Dataset):
def __init__(self, meshes, veles, train=True):
self.meshes = meshes
self.vels = veles
def transform(self, mesh, vel):
resize_mesh = transforms.Resize(size = (64,64), interpolation=Image.NEAREST)
resize_vel = transforms.Resize(size = (64,64), interpolation=Image.NEAREST)
gray = transforms.Grayscale(num_output_channels=1)
mesh = TF.to_tensor(resize_mesh(mesh))
vel = TF.to_tensor(resize_vel(vel))
vel = vel.__ge__(0.7).type(torch.FloatTensor)
return mesh, vel
def __getitem__(self, idx):
'''
This function actually loads the data
'''
mesh = Image.open(self.meshes[idx])
vel = Image.open(self.vels[idx])
x, y = self.transform(mesh, vel) # mesh and velocity
return x, y
def __len__(self):
return len(self.meshes)
At last, I load the data:
train_data = mesh_vel_dataset(train_mesh_paths, train_vel_paths, train=True)
train_loader = DataLoader(train_data, batch_size = 5, shuffle = True)
When I try to inspect random items from the training_loader using the iter(train_loader)
i get the following error:
dataiter = iter(train_loader)
mesh, vel = dataiter.next()
<ipython-input-4-d806966cea0c> in transform (self, mesh, vel)
6
7 def transform( self, mesh, vel):
-----> 8 resize_mesh = transforms.Resize(size = (64,64), interpolation=Image.NEAREST)
9 resize_vel = transforms.Resize(size = (64,64), interpolation=Image.NEAREST)
AttributeError: 'NoneType' object has no attribute 'Resize'
Can anybory share some thoughts?
Thanks in advance.