"Could not infer dtype of slice" during training

I am training an inpainting GAN model. I get the error RuntimeError: Could not infer dtype of slice when I try to use array slices to select parts of an image. I don’t understand why this error occurs.

My batch data is composed of 4 items: a tensor of images, a tensor of masked images, a tensor of mask patches, and a list of “mask slices”. Each slice denotes a region in the image that the mask patch was cut out from.

Each slice is generated in my dataset’s __getitem__ method. Using numpy, a single “mask slice” is just a tuple of ordinary slice objects:

>>> mask_slice = np.s_[:, row:col, row:col]
>>> mask_slice
(slice(None, None, None), slice(31, 47, None), slice(31, 47, None))
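For reference, a small sketch of what np.s_ produces and how such a tuple behaves when it is used on its own as an index. The bounds 31 and 47 are taken from the output above; the single-sample shape (1, 256, 256) and the names lo, hi, img and patch are assumptions for illustration only.

```python
import numpy as np
import torch

# Bounds copied from the example output above
lo, hi = 31, 47
mask_slice = np.s_[:, lo:hi, lo:hi]

# np.s_ simply returns what Python passes to __getitem__: a tuple of slices
assert mask_slice == (slice(None), slice(lo, hi), slice(lo, hi))

# Used by itself, the tuple acts as an ordinary multi-dimensional index
img = torch.ones(1, 256, 256)  # assumed single-sample shape (C, H, W)
patch = img[mask_slice]
print(patch.shape)  # torch.Size([1, 16, 16])
```

So the tuple itself is a valid index; the trouble only starts once it is placed inside another index next to the batch position.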

I use a customized collate_fn to collect all slices of a batch together into a list of slices:

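import torch

def custom_collate_fn(batch_list):
    # Transpose the list of samples into one tuple per field
    imgs, masked_imgs, mask_patches, mask_slices_batch = zip(*batch_list)
    # Stack the tensors; keep the slice tuples as a plain Python list
    return (torch.stack(imgs), torch.stack(masked_imgs),
            torch.stack(mask_patches), list(mask_slices_batch))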

The error occurs in the training loop while actually utilizing the mask slices.

imgs, masked_imgs, mask_patches, mask_slices = batch

# discriminator output should be all 1's except where the mask patch was taken
ideal_discrim_output = torch.ones(batch_size, 1, 256, 256)

for i, mask_slice in enumerate(mask_slices):
    ideal_discrim_output[i, mask_slice] = 0    # <----- ERROR

The actual error is RuntimeError: Could not infer dtype of slice. Is there a reason why this doesn’t work? Or is there a better way to do this? I can’t eliminate the slicing information because I need to do a variety of operations in the training loop using it.

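It seems that a tuple or list of slice objects nested inside an index is unsupported: PyTorch appears to treat the nested sequence as index data to convert to a tensor, and that is where the error is raised.
The same seems to apply to NumPy: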

ideal_discrim_output = torch.ones(1, 1, 256, 256)
mask_slice = (slice(None, None, None), slice(31, 47, None), slice(31, 47, None))
ideal_discrim_output[0, [slice(None, None, None), slice(None, None, None)]]
# RuntimeError: Could not infer dtype of slice
ideal_discrim_output[0, slice(None, None, None)] # works

x = np.random.randn(10, 10, 256, 256)
x[0, (slice(None, None, None), slice(None, None, None))]
# IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
x[0, slice(None, None, None)] # works

So you might need to either pass the slices one at a time or build a single flat index.
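A minimal sketch of the "single index" option, assuming the tensor shape from the question. Concatenating the batch position with the slice tuple gives one flat tuple, which should index the same way as writing the slices out by hand. The batch size, the second mask region, and the names flat_index and num_zeros are made up for illustration.

```python
import torch

batch_size = 2  # assumed small batch for illustration
ideal_discrim_output = torch.ones(batch_size, 1, 256, 256)
mask_slices = [
    (slice(None), slice(31, 47), slice(31, 47)),
    (slice(None), slice(100, 116), slice(60, 76)),  # hypothetical second region
]

for i, mask_slice in enumerate(mask_slices):
    # (i, slice, slice, slice): one flat tuple, no nesting
    flat_index = (i,) + tuple(mask_slice)
    ideal_discrim_output[flat_index] = 0

# Each sample has one channel with a 16 x 16 region zeroed: 256 pixels
num_zeros = int((ideal_discrim_output == 0).sum())
print(num_zeros)  # 512
```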

Thanks, you’re right. That’s too bad, because data[i, mask_slice] would be robust and easy to read. Unpacking directly inside the subscript, like data[i, *mask_slice], would be equally nice, but it is a syntax error on Python versions before 3.11.

As you suggest, the following does work:

ideal_discrim_output[i, mask_slice[0], mask_slice[1], mask_slice[2]] = 0
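Two other spellings that should give the same result, sketched with the shape from the question and an assumed batch size of 2. Option A indexes the sample first and applies the slice tuple to the resulting view. Option B unpacks into an explicit parenthesized tuple, which is plain tuple unpacking rather than a starred subscript. The tensor names a, b and ref are illustrative.

```python
import torch

mask_slice = (slice(None), slice(31, 47), slice(31, 47))
i = 0

# Option A: a[i] is a view, so assigning through it updates a in place
a = torch.ones(2, 1, 256, 256)
a[i][mask_slice] = 0

# Option B: unpack into an explicit tuple before indexing
b = torch.ones(2, 1, 256, 256)
b[(i, *mask_slice)] = 0

# Reference: the explicit form shown above
ref = torch.ones(2, 1, 256, 256)
ref[i, mask_slice[0], mask_slice[1], mask_slice[2]] = 0

same_a = torch.equal(a, ref)
same_b = torch.equal(b, ref)
print(same_a, same_b)
```

Both keep the slice tuple intact, so the same mask_slice can be reused for the other operations in the training loop.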