def __init__(self, roll=False):
    self.roll = roll

def __call__(self, img_group):
    if img_group[0].mode == 'L':
        return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2)
    elif img_group[0].mode == 'RGB':
        if self.roll:
            return np.array([np.array(x)[:, :, ::-1] for x in img_group])
        else:
            return np.array([np.array(x) for x in img_group])

I don’t think you need to convert the inputs to PIL.Images first, as the method seems to return numpy arrays anyway.
Since you are already passing numpy arrays to this method, you would need to check the channel dimension (instead of the `mode` attribute, which only PIL.Images have) and rewrite the logic to work on numpy arrays directly.
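
Something along these lines might work as a starting point. It is only a rough sketch: the class name `StackArrays` is made up, and it assumes every input is a channels-last array, i.e. `(H, W)` for grayscale or `(H, W, 3)` for RGB. Unlike the original method, which silently returns `None` for other modes, it raises an error for unexpected shapes.

```python
import numpy as np


class StackArrays:
    """Hypothetical numpy-only variant of the method above (not the original class)."""

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        # Assumption: channels-last layout, (H, W) or (H, W, 3).
        first = np.asarray(img_group[0])
        if first.ndim == 2:
            # Grayscale: add a trailing axis and concatenate along it -> (H, W, N)
            return np.concatenate(
                [np.expand_dims(np.asarray(x), 2) for x in img_group], axis=2
            )
        if first.ndim == 3 and first.shape[2] == 3:
            if self.roll:
                # Reverse the channel order (RGB -> BGR), then stack -> (N, H, W, 3)
                return np.array([np.asarray(x)[:, :, ::-1] for x in img_group])
            return np.array([np.asarray(x) for x in img_group])
        raise ValueError(f"Unsupported input shape: {first.shape}")
```

If your arrays are channels-first (`(3, H, W)`), the shape check and the channel flip would need to use axis 0 instead.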