I am training an inpainting GAN model and get the error `RuntimeError: Could not infer dtype of slice` when I try to use array slices to select parts of an image. I don’t understand why this error occurs.
My batch data is composed of 4 items: a tensor of images, a tensor of masked images, a tensor of mask patches, and a list of “mask slices”. Each slice denotes a region in the image that the mask patch was cut out from.
Each slice is generated in my dataset’s `__getitem__` method. Using NumPy, a single “mask slice” is just a tuple of ordinary slice objects:
```python
>>> mask_slice = np.s_[:, row:col, row:col]
>>> mask_slice
(slice(None, None, None), slice(31, 47, None), slice(31, 47, None))
```
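For reference, here is a self-contained version of that REPL snippet. The values `row = 31` and `col = 47` are hypothetical, picked only to match the printed output; in the dataset they come from wherever `__getitem__` chooses the patch location. It also shows that indexing a single (C, H, W) array with the tuple on its own works, because the tuple is then the complete index.

```python
import numpy as np

# Hypothetical patch bounds, chosen to reproduce the printed output above
row, col = 31, 47

# np.s_ turns slice syntax into a plain tuple of slice objects
mask_slice = np.s_[:, row:col, row:col]
print(mask_slice)

# A single image laid out as (C, H, W); the tuple selects the 16x16 patch
image = np.zeros((3, 256, 256))
patch = image[mask_slice]
print(patch.shape)  # (3, 16, 16)
```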
I use a custom `collate_fn` to collect all the slices of a batch into a list:
```python
def custom_collate_fn(batch_list):
    ...
    for i, (img, masked_img, true_mask_patch, mask_slice) in enumerate(batch_list):
        mask_slices_batch[i] = mask_slice
    ...
    return imgs, masked_imgs, mask_patches, mask_slices_batch
```
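A runnable sketch of such a collate function, with the elided parts filled in by assumption: the three tensor fields are simply stacked with `torch.stack`, and the slices are kept as a plain Python list of tuples because they are not tensors. The function name matches the post; the sample batch is hypothetical.

```python
import torch

def custom_collate_fn(batch_list):
    # Assumption: each item is (img, masked_img, mask_patch, mask_slice) and the
    # three tensor fields have the same shape across the batch.
    imgs = torch.stack([item[0] for item in batch_list])
    masked_imgs = torch.stack([item[1] for item in batch_list])
    mask_patches = torch.stack([item[2] for item in batch_list])
    # Slices are not tensors, so keep them as a list with one tuple per sample
    mask_slices_batch = [None] * len(batch_list)
    for i, (img, masked_img, true_mask_patch, mask_slice) in enumerate(batch_list):
        mask_slices_batch[i] = mask_slice
    return imgs, masked_imgs, mask_patches, mask_slices_batch

# Hypothetical two-sample batch: 1x8x8 images and 1x2x2 patches
s = (slice(None), slice(2, 4), slice(2, 4))
batch_list = [
    (torch.rand(1, 8, 8), torch.rand(1, 8, 8), torch.rand(1, 2, 2), s),
    (torch.rand(1, 8, 8), torch.rand(1, 8, 8), torch.rand(1, 2, 2), s),
]
imgs, masked_imgs, mask_patches, mask_slices = custom_collate_fn(batch_list)
print(imgs.shape, mask_patches.shape, mask_slices)
```

In a real setup the function would be passed to the loader as `DataLoader(dataset, batch_size=..., collate_fn=custom_collate_fn)`.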
The error occurs in the training loop, when the mask slices are actually used:
```python
imgs, masked_imgs, mask_patches, mask_slices = batch

# discriminator output should be all 1's except where the mask patch was taken
ideal_discrim_output = torch.ones(batch_size, 1, 256, 256)
for i, mask_slice in enumerate(mask_slices):
    ideal_discrim_output[i, mask_slice] = 0  # <----- ERROR
```
The actual error is:

```
RuntimeError: Could not infer dtype of slice
```

Is there a reason why this doesn’t work, or is there a better way to do this? I can’t drop the slicing information, because I need it for a variety of operations in the training loop.
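A minimal reproduction outside the training loop, with the target shrunk to 8x8 and the slice values made up for illustration. With `ideal_discrim_output[i, mask_slice]` the index is the 2-tuple `(i, (slice, slice, slice))`; my working guess (not confirmed against the PyTorch source) is that the nested tuple is treated as a sequence for advanced indexing, which PyTorch then tries to turn into a tensor. The sketch contrasts that with unpacking the stored slices into one flat index tuple.

```python
import torch

batch_size = 2
# Hypothetical small target instead of 256x256 so the result is easy to inspect
ideal_discrim_output = torch.ones(batch_size, 1, 8, 8)
mask_slices = [
    (slice(None), slice(2, 4), slice(2, 4)),
    (slice(None), slice(5, 8), slice(0, 3)),
]

# Reproduce on a copy: the index is (0, (slice, slice, slice)), a nested tuple
demo = ideal_discrim_output.clone()
try:
    demo[0, mask_slices[0]] = 0
except Exception as exc:  # the post reports RuntimeError: Could not infer dtype of slice
    print(type(exc).__name__, exc)

# Variant: unpack the slices so the index is the flat tuple (i, slice, slice, slice),
# equivalent to writing ideal_discrim_output[i, :, 2:4, 2:4] for the first sample
for i, mask_slice in enumerate(mask_slices):
    ideal_discrim_output[(i, *mask_slice)] = 0

print(ideal_discrim_output[0, 0])
```

Another option with the same effect is `ideal_discrim_output[i][mask_slice] = 0`, since `ideal_discrim_output[i]` is a view and the in-place assignment writes through to the original tensor.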