Another alternative, based on the scikit-learn `MinMaxScaler` implementation [1, 2]:
class MinMax(torch.nn.Module):
    """Min-max feature scaler modelled on scikit-learn's ``MinMaxScaler``.

    Scaling statistics are computed from ``fit_tensor`` by reducing over
    dim 0 (rows are treated as samples, columns as features). ``forward``
    maps data into ``feature_range``; ``backward`` undoes that mapping.

    Args:
        fit_tensor: Tensor whose per-column min/max define the scaling.
        feature_range: ``(low, high)`` target range for scaled data.
        inplace: If True, ``forward``/``backward`` modify and return the
            input tensor itself; otherwise a new tensor is returned.
    """

    def __init__(self, fit_tensor: torch.Tensor, feature_range=(0, 1), inplace=False):
        super().__init__()
        self.feature_range = feature_range
        self.inplace = inplace
        # When True, forward() applies the inverse transform (see inverse()).
        self.is_inverse_ = False
        self.fit_(fit_tensor)

    def fit_(self, tensor: torch.Tensor):
        """Compute and store the scaling statistics from ``tensor``."""
        # The statistics are constants of the transform; keep them out of autograd.
        tensor = tensor.detach()
        # Per-column min/max; keepdim so they broadcast against the input rows.
        data_min = tensor.min(dim=0, keepdim=True)[0]
        data_max = tensor.max(dim=0, keepdim=True)[0]
        data_range = data_max - data_min
        # Constant columns would cause a division by zero below; like
        # scikit-learn, treat their range as 1.
        data_range[data_range == 0.0] = 1.0
        low, high = self.feature_range
        scale = (high - low) / data_range
        # Registered as buffers so .to()/.cuda() move them with the module.
        # Non-persistent keeps state_dict() unchanged for existing checkpoints.
        stats = (
            ("scale_", scale),
            ("min_", low - data_min * scale),
            ("data_min_", data_min),
            ("data_max_", data_max),
            ("data_range_", data_range),
        )
        for name, value in stats:
            self.register_buffer(name, value, persistent=False)

    def forward(self, tensor: torch.Tensor):
        """Scale ``tensor`` into ``feature_range`` (or undo it on an inverse copy)."""
        if self.is_inverse_:
            return self.backward(tensor)
        if self.inplace:
            return tensor.mul_(self.scale_).add_(self.min_)
        # Out-of-place so integer inputs are promoted to floating point
        # instead of failing an in-place dtype cast.
        return tensor * self.scale_ + self.min_

    def backward(self, tensor: torch.Tensor):
        """Map scaled data back to the original feature space."""
        if self.inplace:
            return tensor.sub_(self.min_).div_(self.scale_)
        return (tensor - self.min_) / self.scale_

    def inverse(self):
        """Return a deep copy of this scaler that applies the opposite transform.

        Calling ``inverse()`` on an inverse copy yields a forward scaler again.
        """
        clone = copy.deepcopy(self)
        clone.is_inverse_ = not self.is_inverse_
        return clone
It seems to work well in a simple test:
# Float literals give the default floating dtype, matching the legacy
# torch.Tensor(list) constructor this replaces.
arr = torch.tensor([
    [7., 2., 3.],
    [4., 5., 9.],
    [1., 8., 6.],
])

# Default way: forward() scales, backward() undoes the scaling.
temp = MinMax(fit_tensor=arr)
fwd_arr = temp(arr)
bwd_arr = temp.backward(fwd_arr)
print(arr)
print(fwd_arr)
print(bwd_arr)

# Alternative way without calling backward: an inverse copy whose
# forward() applies the inverse transform.
inv_temp = temp.inverse()
bwd_arr = inv_temp(fwd_arr)
print(bwd_arr)
tensor([[7., 2., 3.],
[4., 5., 9.],
[1., 8., 6.]])
tensor([[1.0000, 0.0000, 0.0000],
[0.5000, 0.5000, 1.0000],
[0.0000, 1.0000, 0.5000]])
tensor([[7.0000, 2.0000, 3.0000],
[4.0000, 5.0000, 9.0000],
[1.0000, 8.0000, 6.0000]])
tensor([[7.0000, 2.0000, 3.0000],
[4.0000, 5.0000, 9.0000],
[1.0000, 8.0000, 6.0000]])
I'm not sure whether this follows PyTorch conventions (I'm new to PyTorch), but I wanted to share it in case someone needs a quick solution.