I am training an inpainting GAN model and get the error RuntimeError: Could not infer dtype of slice when trying to use slices to select parts of an image. I don’t understand why this error occurs.
My batch data is composed of 4 items: a tensor of images, a tensor of masked images, a tensor of mask patches, and a list of “mask slices”. Each slice denotes a region in the image that the mask patch was cut out from.
Each slice is generated in my dataset’s __getitem__ method. Using numpy, a single “mask slice” is just a tuple of ordinary slice objects:
>>> mask_slice = np.s_[:, r0:r1, c0:c1]
>>> mask_slice
(slice(None, None, None), slice(31, 47, None), slice(31, 47, None))
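For reference, indexing with the whole tuple works fine when it is the entire index. A quick sanity check on a dummy CHW image (shape assumed):

import numpy as np

img = np.random.rand(3, 256, 256)    # dummy CHW image
mask_slice = np.s_[:, 31:47, 31:47]  # same tuple as above
patch = img[mask_slice]              # tuple used as the complete index
print(patch.shape)                   # (3, 16, 16)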
I use a customized collate_fn to collect all the slices in a batch together into a list:
def custom_collate_fn(batch_list):
    ...
    for i, (img, masked_img, true_mask_patch, mask_slice) in enumerate(batch_list):
        mask_slices_batch[i] = mask_slice
    ...
    return imgs, masked_imgs, mask_patches, mask_slices_batch
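The dataset and loader are wired up in the usual way (the dataset name here is hypothetical):

from torch.utils.data import DataLoader

loader = DataLoader(inpainting_dataset, batch_size=batch_size,
                    shuffle=True, collate_fn=custom_collate_fn)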
The error occurs in the training loop, while actually using the mask slices:
imgs, masked_imgs, mask_patches, mask_slices = batch

# discriminator output should be all 1's except where the mask patch was taken
ideal_discrim_output = torch.ones(batch_size, 1, 256, 256)
for i, mask_slice in enumerate(mask_slices):
    ideal_discrim_output[i, mask_slice] = 0  # <----- ERROR
The actual error is RuntimeError: Could not infer dtype of slice. Is there a reason why this doesn’t work? Or is there a better way to do this? I can’t eliminate the slicing information, because I need it for a variety of operations in the training loop.
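For completeness, here is a minimal standalone reproduction of the failure, with the dataset/loader machinery stripped out (batch size and shapes assumed):

import torch

out = torch.ones(4, 1, 256, 256)
mask_slice = (slice(None), slice(31, 47), slice(31, 47))
out[0][mask_slice] = 0  # fine: the tuple is the complete index
out[0, mask_slice] = 0  # RuntimeError: Could not infer dtype of slice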