Hi everyone, I am trying to fold my dataset normalization and output denormalization into the PyTorch graph itself. The wrapper normalizes the input before the wrapped module runs and then denormalizes the output using statistics taken from that same input.

```
import torch


class NormDenorm(torch.nn.Module):
    def __init__(self, module):
        super(NormDenorm, self).__init__()
        # Indices of channels for denormalization, registered as a buffer
        # so they move with the module across devices (.to(), .cuda()).
        self.register_buffer('indices', torch.tensor([0]))
        self.add_module('capsule', module)

    def print(self, *args):
        if not self.training:
            print(*args)

    def forward(self, x):
        min_x, _ = torch.min(x, dim=-1, keepdim=True)
        max_x, _ = torch.max(x, dim=-1, keepdim=True)
        max_min_diff = torch.add(max_x - min_x, 1e-4)
        # Normalize input for computation.
        x_norm = torch.add(x - min_x, 1e-4)
        x_norm = x_norm / max_min_diff
        # Perform main module operations.
        y = self.capsule(x_norm)
        # Denormalize output based on selected input.
        min_x = torch.take(min_x, self.indices)
        max_x = torch.take(max_x, self.indices)
        max_min_diff = torch.take(max_min_diff, self.indices)
        y_denorm = y * max_min_diff + min_x
        return y_denorm


model = NormDenorm(MyModel())
```

It takes arbitrary numerical data and normalizes it to roughly the [0, 1] range (the 1e-4 epsilon keeps the division safe when max and min coincide).
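
For reference, here is a minimal sketch of how I exercise the forward pass, using `torch.nn.Identity` as a stand-in for `MyModel` (which is defined elsewhere); with an identity capsule the wrapper should round-trip the input up to the epsilon:

```
import torch

# Smoke test with Identity standing in for MyModel.
model = NormDenorm(torch.nn.Identity())
model.eval()

x = torch.tensor([[3.0, 7.0, 5.0]])
y = model(x)
print(y)  # Approximately [[3.0, 7.0, 5.0]]: the identity capsule round-trips.
```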

Is this generally a bad idea, and if not, how should I approach writing a `backward` function for this case?
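
For context, here is a minimal sketch (again with `torch.nn.Identity` standing in for `MyModel`) of how one could check whether autograd already produces gradients through this graph without a hand-written backward:

```
import torch

# Check whether autograd differentiates through the wrapper on its own.
model = NormDenorm(torch.nn.Identity())

x = torch.tensor([[3.0, 7.0, 5.0]], requires_grad=True)
y = model(x)
y.sum().backward()
print(x.grad)  # Non-None gradients mean no custom backward is needed.
```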