It seems tuples or lists of slice objects are not supported as a single index element, and the error is raised here.
The same seems to apply to NumPy, too:
```python
import numpy as np
import torch

ideal_discrim_output = torch.ones(1, 1, 256, 256)
mask_slice = (slice(None, None, None), slice(31, 47, None), slice(31, 47, None))

ideal_discrim_output[0, [slice(None, None, None), slice(None, None, None)]]
# RuntimeError: Could not infer dtype of slice
ideal_discrim_output[0, slice(None, None, None)]  # works

x = np.random.randn(10, 10, 256, 256)
x[0, (slice(None, None, None), slice(None, None, None))]
# IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
x[0, slice(None, None, None)]  # works
```
So you might need to either iterate over the slices or create a single, flat index tuple instead.
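A minimal sketch of both options, assuming `mask_slice` covers the (channel, height, width) dims and you want to index sample `0`. Prepending the integer gives one flat tuple `(0, :, 31:47, 31:47)` instead of a nested one. It uses NumPy, but the same flat tuple also works for a `torch.Tensor`, e.g. `ideal_discrim_output[(0,) + mask_slice]`. The `masks` list and its second region are hypothetical, just to show the iteration case.

```python
import numpy as np

# Same shape as ideal_discrim_output above: (batch, channel, height, width)
x = np.ones((1, 1, 256, 256))

# Slices for the (channel, height, width) dims, as in the question
mask_slice = (slice(None), slice(31, 47), slice(31, 47))

# Prepend the batch index to get one flat tuple: (0, :, 31:47, 31:47)
index = (0,) + mask_slice
patch = x[index]
print(patch.shape)  # (1, 16, 16)

# Hypothetical list of several masks: iterate and index with each flat tuple
masks = [
    (slice(None), slice(31, 47), slice(31, 47)),
    (slice(None), slice(100, 110), slice(0, 8)),  # hypothetical second region
]
for m in masks:
    x[(0,) + m] = 0.0  # zero out each masked region of sample 0
```

The key point is that every element of the index tuple must itself be an int, slice, `None`, `Ellipsis`, or array, so the tuples have to be concatenated rather than nested.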