Backward propagation with custom denormalization layer

Hi everyone, I am trying to move my dataset normalization and output denormalization into the PyTorch graph. The module below normalizes its input and then denormalizes the output using statistics computed from that same input.

import torch

class NormDenorm(torch.nn.Module):
    """Wrap a module with input min-max normalization and output denormalization.

    The input is rescaled along its last dimension before it is passed to the
    wrapped module, and the module's output is mapped back using the minimum
    and range of the input entries selected by ``indices``.

    Only standard torch operations are used, so autograd derives the backward
    pass; no custom ``backward`` is defined here.
    """

    def __init__(self, module):
        """Store the wrapped module and the denormalization indices.

        Args:
            module: The ``torch.nn.Module`` applied to the normalized input;
                registered as the ``capsule`` submodule.
        """
        super().__init__()
        # Indices of channels for denormalization. Registered as a
        # non-persistent buffer so it follows the module across devices
        # (``.to()`` / ``.cuda()``) without adding a key to ``state_dict``.
        self.register_buffer('indices', torch.tensor([0]), persistent=False)
        self.add_module('capsule', module)

    def print(self, *args):
        """Debug helper that forwards ``args`` to the builtin ``print``.

        NOTE(review): the original body was truncated after ``if not``, so the
        intended guard is unknown. This version is a no-op when called without
        arguments and prints otherwise -- confirm the intended condition.
        """
        if not args:
            return
        # Inside the method body, ``print`` resolves to the builtin.
        print(*args)

    def forward(self, x):
        """Normalize ``x``, run the wrapped module, and denormalize its output.

        Args:
            x: Input tensor; the minimum and maximum are taken over its last
                dimension.

        Returns:
            ``capsule(x_norm) * range + minimum``, where ``range`` and
            ``minimum`` are the input statistics selected by ``self.indices``.
        """
        min_x, _ = torch.min(x, dim=-1, keepdim=True)
        max_x, _ = torch.max(x, dim=-1, keepdim=True)
        # The epsilon keeps the divisor non-zero when the input is constant
        # along the last dimension.
        max_min_diff = torch.add(max_x - min_x, 1e-4)

        # Normalize input for computation. The same epsilon is added to the
        # numerator, so normalized values are strictly positive and at most 1.
        x_norm = torch.add(x - min_x, 1e-4)
        x_norm = x_norm / max_min_diff

        # Perform main module operations.
        y = self.capsule(x_norm)

        # Denormalize output based on selected input.
        # NOTE(review): torch.take indexes the *flattened* statistics tensor,
        # so with a leading batch dimension the indices do not address a
        # per-sample channel axis -- confirm the expected input shape.
        # NOTE(review): the epsilon added to the numerator above is not
        # subtracted here, so an identity ``capsule`` yields ``x + 1e-4`` for
        # the selected entry -- confirm this is intended.
        min_x = torch.take(min_x, self.indices)
        max_x = torch.take(max_x, self.indices)
        max_min_diff = torch.take(max_min_diff, self.indices)
        y_denorm = y * max_min_diff + min_x
        return y_denorm

# Wrap the network with the normalization/denormalization module.
# MyModel is not defined in this snippet; it stands for the user's own module.
model = NormDenorm(MyModel())

It takes arbitrary numerical data and normalizes it to the [0, 1] range (up to the small 1e-4 offset used to avoid division by zero).

Is this generally a bad idea? If not, how should I approach writing the backward function for this case?