Using scikit-learn's scalers for torchvision

Since you are working in-place on ch, you don’t need the second multiplication with scale in your custom implementation: after ch.mul_(scale), ch.min() already returns the new (scaled) minimal value, so it doesn’t need to be scaled again.

Also, you would need to take the max and min values along dim0 (i.e. per column of each channel), as done in the sklearn implementation, which scales each feature/column separately.
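
If you want to convince yourself of the first point, here is a small check on a single random channel (the 4x3 shape is just for illustration):

import torch

ch = torch.randn(4, 3)
scale = 1.0 / (ch.max(dim=0)[0] - ch.min(dim=0)[0])

# subtract the column minima first and scale once ...
a = (ch - ch.min(dim=0)[0]) * scale
# ... which is the same as scaling first and subtracting the already scaled minima
b = ch * scale - (ch * scale).min(dim=0)[0]

print((a - b).abs().max())  # only floating point noise, if anything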

This implementation should work:

class PyTMinMaxScaler(object):
    """
    Transforms each channel to the range [0, 1].
    """
    def __call__(self, tensor):
        for ch in tensor:
            # column-wise min/max (dim0 of each channel), matching sklearn's per-feature scaling
            scale = 1.0 / (ch.max(dim=0)[0] - ch.min(dim=0)[0])
            # scale in-place first, then subtract the already scaled column minima
            ch.mul_(scale).sub_(ch.min(dim=0)[0])
        return tensor

However, the loop will slow down your code.
To get the most out of PyTorch, you should use vectorized code:

class PyTMinMaxScalerVectorized(object):
    """
    Transforms each channel to the range [0, 1].
    """
    def __call__(self, tensor):
        # per-channel, per-column min/max computed in one go
        scale = 1.0 / (tensor.max(dim=1, keepdim=True)[0] - tensor.min(dim=1, keepdim=True)[0])
        tensor.mul_(scale).sub_(tensor.min(dim=1, keepdim=True)[0])
        return tensor
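
The keepdim=True is important here: the reduced row dimension is kept with size 1, so the per-column statistics broadcast back over the rows of each channel. A quick shape check (assuming a 6x100x100 input as in the comparison below):

import torch

t = torch.randn(6, 100, 100)           # (channels, rows, columns)
col_max = t.max(dim=1, keepdim=True)[0]
print(col_max.shape)                   # torch.Size([6, 1, 100])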

Let’s check if we get the same values:

import torch
from sklearn.preprocessing import MinMaxScaler

img1 = torch.randn(6, 100, 100)
img2 = img1.clone()
img3 = img1.clone()

# sklearn
scaler = MinMaxScaler()
for i in range(img1.size()[0]):
    img1[i] = torch.tensor(scaler.fit_transform(img1[i]))

# PyTorch manual
scaler = PyTMinMaxScaler()
out2 = scaler(img2)

# PyTorch fast
scaler_fast = PyTMinMaxScalerVectorized()
out3 = scaler_fast(img3)

print((img1 - out2).abs().max())
> tensor(1.1921e-07)
print((img1 - out3).abs().max())
> tensor(1.1921e-07)
print((out2 == out3).all())
> tensor(True)

That looks good! The small differences are due to the limited floating point precision.
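
If you prefer a single boolean check that allows for this tolerance, torch.allclose does the same comparison with an explicit tolerance (atol is just chosen to be larger than the difference printed above):

print(torch.allclose(img1, out3, atol=1e-6))  # True, given the max abs. difference above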

Let’s see how fast the PyTorch versions are on the CPU via %timeit (on my old laptop):

%timeit scaler(img2)
> 1.85 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit scaler_fast(img3)
> 529 µs ± 44.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

If you are using a modern CPU, your code should be way faster. :wink:
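
Since the goal is to use this inside a torchvision pipeline, the vectorized scaler can simply be appended after ToTensor. A rough sketch (FakeData just stands in for your dataset; note that a constant column would make max - min zero and produce invalid values):

from torchvision import transforms, datasets

transform = transforms.Compose([
    transforms.ToTensor(),           # PIL image -> float tensor of shape (C, H, W)
    PyTMinMaxScalerVectorized(),     # rescale each column of each channel to [0, 1]
])

dataset = datasets.FakeData(size=10, image_size=(3, 100, 100), transform=transform)
img, target = dataset[0]
print(img.shape, img.min().item(), img.max().item())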