I am trying to reconstruct a rather complicated neural network for shape & appearance disentanglement. It involves several transformations on the tensors along the way. Now I am wondering whether, and how, to correctly send data to the GPU so that the pipeline is efficient without wasting GPU memory.
The first question arises within the Dataset class. It currently looks as follows:
from torch.utils.data import Dataset
from torchvision import transforms
import kornia.augmentation as K

# tps_parameters, make_input_tps_param and ThinPlateSpline come from my own module

class ImageDataset(Dataset):
    def __init__(self, images, arg):
        super(ImageDataset, self).__init__()
        self.device = arg.device
        self.bn = arg.bn
        self.brightness = arg.brightness_var
        self.contrast = arg.contrast_var
        self.saturation = arg.saturation_var
        self.hue = arg.hue_var
        self.scal = arg.scal
        self.tps_scal = arg.tps_scal
        self.rot_scal = arg.rot_scal
        self.off_scal = arg.off_scal
        self.scal_var = arg.scal_var
        self.augm_scal = arg.augm_scal
        self.images = images
        self.transforms = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize([0.5], [0.5])])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Select image
        image = self.images[index]
        # Get parameters for the TPS transformation
        tps_param_dic = tps_parameters(1, self.scal, self.tps_scal, self.rot_scal, self.off_scal,
                                       self.scal_var, self.augm_scal)
        coord, vector = make_input_tps_param(tps_param_dic)
        # Apply the spatial (TPS) transformation
        x_spatial_transform = self.transforms(image).unsqueeze(0).to(self.device)
        x_spatial_transform, t_mesh = ThinPlateSpline(x_spatial_transform, coord,
                                                      vector, 128, self.device)
        x_spatial_transform = x_spatial_transform.squeeze(0)
        # Apply the appearance (color jitter) transformation
        x_appearance_transform = K.ColorJitter(self.brightness, self.contrast, self.saturation, self.hue)(
            self.transforms(image).unsqueeze(0)).squeeze(0)
        original = self.transforms(image)
        coord, vector = coord[0], vector[0]
        return original, x_spatial_transform, x_appearance_transform, coord, vector
The ThinPlateSpline function is rather complicated; it performs a TPS transformation. During the process some tensors are created, and since I need the function again later, I have to specify the device. As an example, it contains things like this:
def ThinPlateSpline(U, coord, vector, out_size, device, move=None, scal=None):
    coord = torch.flip(coord, [2])
    vector = torch.flip(vector, [2])
    num_batch, channels, height, width = U.shape
    out_height = out_size
    out_width = out_size
    height_f = torch.tensor([height], dtype=torch.float32).to(device)
    width_f = torch.tensor([width], dtype=torch.float32).to(device)
    num_point = coord.shape[1]
    # ...
The tensors height_f and width_f are therefore created at every call of the function, and I wonder whether this is a problem. Is there a better way to perform such operations on my data within the architecture?
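One alternative I considered is creating these constant tensors only once and caching them by size and device, roughly like this (just a sketch; the helper name and cache are made up by me):

```python
import torch

# Sketch: cache the per-size constant tensors instead of recreating them
# on every call to ThinPlateSpline. Cache key is (height, width, device).
_size_cache = {}

def get_size_tensors(height, width, device):
    key = (height, width, str(device))
    if key not in _size_cache:
        _size_cache[key] = (
            torch.tensor([height], dtype=torch.float32, device=device),
            torch.tensor([width], dtype=torch.float32, device=device),
        )
    return _size_cache[key]
```

ThinPlateSpline would then call get_size_tensors(height, width, device) instead of building the tensors itself, but I am not sure if the saving is even measurable for tensors this small.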
Also, should I send the data to the GPU within the Dataset class at all?
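For context, the alternative I am aware of would be to keep everything on the CPU inside the Dataset and move each batch to the GPU only in the training loop, something like this sketch (with placeholder data instead of my real images):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Sketch: Dataset stays on the CPU, the DataLoader pins memory, and each
# batch is moved to the GPU inside the training loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = TensorDataset(torch.randn(8, 3, 128, 128))  # placeholder data
loader = DataLoader(dataset, batch_size=4, pin_memory=True)

for (batch,) in loader:
    # non_blocking=True allows an asynchronous copy from pinned memory
    batch = batch.to(device, non_blocking=True)
    # ... forward/backward pass would go here ...
```

But with my current design the ThinPlateSpline call already happens inside __getitem__, so I am unsure whether that pattern even applies here.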
Thanks for your help!