Normalizing output from neural network loss function

Hi!

I’ve implemented a neural loss function adapted from this paper: https://arxiv.org/abs/1505.07376
The paper also states:

Finally, for practical reasons, we rescaled the weights in the network such that the mean activation of each filter over images and positions is equal to one. Such re-scaling can always be done without changing the output of a neural network if the non-linearities in the network are rectifying linear.

I would like to apply a similar normalization to my neural loss function. Currently, the output depends entirely on the weights I choose for the different layers. Is it possible to normalize the output so that it always lies between 0 and 1.0, or is that not feasible? If not, how would I at least rescale the weights in the network as they did in the paper?

Here is my implementation:

class NeuralLoss(Loss):
    """Texture loss comparing Gram matrices of VGG19 feature activations.

    Both images are passed through ``vgg19.features``, with every
    ``MaxPool2d`` swapped for an ``AvgPool2d`` of the same geometry. At each
    selected layer, the Gram matrix of the activations is computed for both
    images. The two Gram matrices are compared with ``sse_loss`` and the
    result is scaled by ``1 / (4 N^2 M^2)``, where N is the number of
    channels and M the number of spatial positions. The per-layer terms are
    combined as a weighted sum.

    Args:
        layers: Indices into ``vgg19.features`` whose outputs are compared.
            Defaults to ``[2, 4, 6]``.
        layer_weights: One weight per entry of ``layers``, in the same order.
            Defaults to ``[2e-1, 2e1, 2e1]``.
    """

    def __init__(self, layers: list = None, layer_weights: list = None):
        super().__init__()
        if layers is None:
            layers = [2, 4, 6]
        if layer_weights is None:
            layer_weights = [2e-1, 2e1, 2e1]

        self._layer_indices = layers
        self._layer_weights = torch.tensor(layer_weights)

        self.vgg = models.vgg19(pretrained=True, progress=True)
        # VGG is only used as a fixed feature extractor: gradients are needed
        # w.r.t. the input image, not the network weights.
        self.vgg.requires_grad_(False)
        # modules() yields the Sequential container itself first, so the real
        # layers start at index 1 (hence ``[1:]`` / ``i + 1`` elsewhere).
        self.modulelist = list(self.vgg.features.modules())

        for i, mod in enumerate(self.modulelist):
            # Disable in-place ops so activations saved for the Gram-matrix
            # backward pass are not overwritten by later layers.
            if hasattr(mod, "inplace"):
                mod.inplace = False

            # Swap max pooling for average pooling with identical geometry.
            if isinstance(mod, torch.nn.MaxPool2d):
                self.modulelist[i] = torch.nn.AvgPool2d(kernel_size=mod.kernel_size, stride=mod.stride, padding=mod.padding, ceil_mode=mod.ceil_mode)

        self._normalize = T.Compose([
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __str__(self):
        layers = "\n\t".join([self.modulelist[i + 1].__str__() for i in self._layer_indices])
        return "Neural Loss (\n\tlayers:\n {}, \n\tweights: {}\n)".format(layers, self._layer_weights)

    def _preprocess(self, x: Tensor):
        """Turn an HxWxC image into a normalized 1xCxHxW batch for VGG19."""
        # permute(2, 0, 1) maps HxWxC -> CxHxW. (permute(2, 1, 0) would give
        # CxWxH, i.e. a spatially transposed image.) The permuted view has
        # channels-last strides, so make it contiguous before feeding VGG.
        chw = x.permute(2, 0, 1).contiguous()
        return self._normalize(chw).unsqueeze(0)

    def _gram_matrix(self, activation: Tensor, N: int, M: int):
        """Return the N x N Gram matrix of an activation flattened to N x M."""
        # reshape (not view) so a non-contiguous activation cannot fail here.
        feature_map = activation.reshape(N, M)
        G = torch.mm(feature_map, feature_map.t())
        return G

    def _loss(self, x: Tensor, target: Tensor):
        """Return the weighted sum of per-layer Gram-matrix losses.

        Args:
            x: Image of shape (224, 224, 3).
            target: Image with the same shape as ``x``.
        """
        assert x.shape == target.shape
        assert list(x.shape) == [224, 224, 3]

        x_ = self._preprocess(x)
        target_ = self._preprocess(target)
        wanted = set(self._layer_indices)
        # Deepest layer needed; layers past it are never evaluated.
        last = max(wanted)
        E = {}

        for i, layer in enumerate(self.modulelist[1:]):
            if i > last:
                break
            x_ = layer(x_)
            target_ = layer(target_)
            if i in wanted:
                N = x_.shape[1]
                M = x_.shape[2] * x_.shape[3]
                G1 = self._gram_matrix(x_, N, M)
                G2 = self._gram_matrix(target_, N, M)

                E[i] = sse_loss(G1, G2) * (1 / (4. * (N ** 2) * (M ** 2)))

        # Stack in the order of ``layers`` so each term lines up with its
        # weight even when ``layers`` is not sorted.
        losses = torch.stack([E[i] for i in self._layer_indices])
        # The weights live on the CPU; match the losses' device and dtype.
        weights = self._layer_weights.to(device=losses.device, dtype=losses.dtype)
        L_tot = torch.sum(losses * weights)
        return L_tot